"笨方法"学习CNN图像识别(三)—— ResNet网络训练及预测

— 全文阅读8分钟 —

在本文中,你将学习到以下内容:


  • TensorFlow中调用ResNet网络
  • 训练网络并保存模型
  • 加载模型预测结果

前言

在深度学习中,随着网络深度的增加,模型优化会变得越来越困难,甚至会发生梯度爆炸,导致整个网络训练无法收敛。ResNet(Residual Networks)的提出解决了这个问题。在这里我们直接调用ResNet网络进行训练,讲解ResNet细节的文章有很多,这里找了一篇供参考

搭建训练网络

如果你看过了前面的准备工作,图片预处理制作tfrecord格式,默认已经有tfrecord格式的数据文件了。我们接着搭建网络,来处理100类商标图片的分类问题。将制作好的tfrecord数据通过队列系统传入ResNet网络进行训练。

《
100类

首先导入必要的库:

nets库里面集成了现有的很多网络(AlexNet,Inception,ResNet,VGG)可以直接调用,我们在这里使用ResNet_50,即50层的网络训练。
接下来我们先定义一个读取tfrecord文件的函数:

定义模型保存地址,batch_sizes设置的小一点训练效果更好,将当前目录下的tfrecord文件放入列表中:

注意这里使用了tf.train.shuffle_batch随机打乱队列里面的数据顺序,num_threads表示线程数,capacity表示队列的容量,在这里设置成10000, min_after_dequeue队列里保留的最小数据量,并且控制着随机的程度,设置成9900的意思是,当队列中的数据出列100个,剩下9900个的时候,就要重新补充100个数据进来并打乱顺序。如果你要按顺序导入队列,改成tf.train.batch函数,并删除min_after_dequeue参数。这些参数都要根据自己的电脑配置进行相应的设置。
接下来将label值进行onehot编码,直接调用tf.one_hot函数。因为我们这里有100类,depth设置成100:

我们通过nets.resnet_v2.resnet_v2_50直接调用ResNet_50网络,同样num_classes等于类别总数,is_training表示我们是否要训练网络里面固定层的参数,True表示所有参数都重新训练,False表示只训练后面几层的参数。
网络搭好后,我们继续定义损失函数和优化器,损失函数选择sigmoid交叉熵,优化器选择Adam:

定义准确率函数,tf.argmax函数返回最大值所在位置:

最后我们构建Session,让网络跑起来:

当我们使用队列系统时,在Session部分一定要创建一个协调器管理线程。我们每20步输出一次准确率,在200000,300000,400000步的时候自动保存模型。
训练结束后会得到如下模型文件,我在这里只保留了300000步的模型:

《
模型文件

附上训练网络完整代码:

预测结果

我们利用1000张测试数据评估我们的模型,直接放代码:

需要注意的是test数据集并没有处理成tfrecord格式,在这里直接将图片一张张导入用模型预测,生成的结果文件主要是为了提交比赛使用。原始数据和模型我会放在这里,密码:8xbi。有兴趣自提。
至此,我们就完成了一个CNN图像识别项目。

点赞