— 全文阅读5分钟 —
在本文中,你将学习到以下内容:
- 将图片数据制作成tfrecord格式
- 将tfrecord格式数据还原成图片
前言
tfrecord是TensorFlow官方推荐的标准格式,能够将图片数据和标签一起存储成二进制文件,在TensorFlow中实现快速地复制、移动、读取和存储操作。训练网络的时候,通过建立队列系统,可以预先将tfrecord格式的数据加载进队列,队列会自动实现数据随机或有序地进出栈,并且队列系统和模型训练是独立进行的,这就加速了我们模型的读取和训练。
准备图片数据
按照图片预处理教程,我们获得了两组resize成224*224大小的商标图片集,把标签分别命名成1和2两类,如下图:



我们现在就将这两个类别的图片集制作成tfrecord格式。
制作tfrecord格式
导入必要的库:
1 2 3 |
import os from PIL import Image import tensorflow as tf |
定义一些路径和参数:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
# 图片路径,两组标签都在该目录下 cwd = r"./brand_picture/" # tfrecord文件保存路径 file_path = r"./" # 每个tfrecord存放图片个数 bestnum = 1000 # 第几个图片 num = 0 # 第几个TFRecord文件 recordfilenum = 0 # 将labels放入到classes中 classes = [] for i in os.listdir(cwd): classes.append(i) # tfrecords格式文件名 ftrecordfilename = ("traindata_63.tfrecords-%.3d" % recordfilenum) writer = tf.python_io.TFRecordWriter(os.path.join(file_path, ftrecordfilename)) |
bestnum控制每个tfrecord的大小,这里使用1000,首先定义tf.python_io.TFRecordWriter
,方便后面写入存储数据。
制作tfrecord格式时,实际上是将图片和标签一起存储在tf.train.Example
中,它包含了一个字典,键是一个字符串,值的类型可以是BytesList
,FloatList
和Int64List
。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
for index, name in enumerate(classes): class_path = os.path.join(cwd, name) for img_name in os.listdir(class_path): num = num + 1 if num > bestnum: #超过1000,写入下一个tfrecord num = 1 recordfilenum += 1 ftrecordfilename = ("traindata_63.tfrecords-%.3d" % recordfilenum) writer = tf.python_io.TFRecordWriter(os.path.join(file_path, ftrecordfilename)) img_path = os.path.join(class_path, img_name) # 每一个图片的地址 img = Image.open(img_path, 'r') img_raw = img.tobytes() # 将图片转化为二进制格式 example = tf.train.Example( features=tf.train.Features(feature={ 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])), })) writer.write(example.SerializeToString()) # 序列化为字符串 writer.close() |
在这里我们保存的label是classes中的编号索引,即0和1,你也可以改成文件名作为label,但是一定是int类型。图片读取以后转化成了二进制格式。最后通过writer写入数据到tfrecord中。
最终我们在当前目录下生成一个tfrecord文件:

读取tfrecord文件
读取tfrecord文件是存储的逆操作,我们定义一个读取tfrecord的函数,方便后面调用。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import tensorflow as tf def read_and_decode_tfrecord(filename): filename_deque = tf.train.string_input_producer(filename) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_deque) features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw': tf.FixedLenFeature([], tf.string)}) label = tf.cast(features['label'], tf.int32) img = tf.decode_raw(features['img_raw'], tf.uint8) img = tf.reshape(img, [224, 224, 3]) img = tf.cast(img, tf.float32) / 255.0 return img, label train_list = ['traindata_63.tfrecords-000'] img, label = read_and_decode_tfrecord(train_list) |
这段代码主要是通过tf.TFRecordReader
读取里面的数据,并且还原数据类型,最后我们对图片矩阵进行归一化。到这里我们就完成了tfrecord输出,可以对接后面的训练网络了。
如果我们想直接还原成原来的图片,就需要先注释掉读取tfrecord函数中的归一化一行,并添加部分代码,完整代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import tensorflow as tf from PIL import Image import matplotlib.pyplot as plt def read_and_decode_tfrecord(filename): filename_deque = tf.train.string_input_producer(filename) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_deque) features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw': tf.FixedLenFeature([], tf.string)}) label = tf.cast(features['label'], tf.int32) img = tf.decode_raw(features['img_raw'], tf.uint8) img = tf.reshape(img, [224, 224, 3]) # img = tf.cast(img, tf.float32) / 255.0 #将矩阵归一化0-1之间 return img, label train_list = ['traindata_63.tfrecords-000'] img, label = read_and_decode_tfrecord(train_list) img_batch, label_batch = tf.train.batch([img, label], num_threads=2, batch_size=2, capacity=1000) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 创建一个协调器,管理线程 coord = tf.train.Coordinator() # 启动QueueRunner,此时文件名队列已经进队 threads = tf.train.start_queue_runners(sess=sess, coord=coord) while True: b_image, b_label = sess.run([img_batch, label_batch]) b_image = Image.fromarray(b_image[0]) plt.imshow(b_image) plt.axis('off') plt.show() coord.request_stop() # 其他所有线程关闭之后,这一函数才能返回 coord.join(threads) |
在后面建立了一个队列tf.train.batch
,通过Session调用顺序队列系统,输出每张图片。Session部分在训练网络的时候还会讲到。我们学习tfrecord过程,能加深对数据结构和类型的理解。到这里我们对tfrecord格式的输入输出有了一定了解,我们训练网络的准备工作已完成,接下来就是我们CNN模型的搭建工作了。