6. 多分类图像识别案例
CIFAR-10
CIFAR-10数据集由10个类别的60000 32x32彩色图像组成,每个类别有6000张图像。有50000个训练图像和10000个测试图像。数据集分为五个训练集和一个测试集,每个集有10000个图像。测试集包含来自每个类的正好1000个随机选择的图像。训练集的每个类别5000个图像。图像类别如下:
下载数据集
可以去官网下载,https://www.cs.toronto.edu/~kriz/cifar.html
里面有很多种版本我们下载 CIFAR-10二进制版本。
二进制版本格式
二进制版本包含文件data_batch_1.bin,data_batch_2.bin,data_batch_4.bin,data_batch_5.bin以及test_batch.bin。这些文件的格式如下:
<1 x label> <3072 x像素>
...
<1 x label> <3072 x像素>
第一个字节是第一个图像的标签,它是0-9范围内的数字。接下来的3072个字节是图像像素的值。前1024个字节是红色通道值,接下来是1024个绿色,最后1024个是蓝色。
所以每个文件包含10000个这样的3073字节的“行”的图像,还有一个名为batches.meta.txt的文件。这是一个ASCII文件,将范围为0-9的数字标签映射到有意义的类名。
6.1. 图片信息的读取与写入
二进制文件的读取
使用tf.FixedLengthRecordReader去读取,我们将其保存到TFRecords文件当中,以这种文件格式保存当作模型训练数据的来源
在这里我们设计一个CifarRead类去完成。将会初始化每个图片的大小数据
def __init__(self, filelist=None):
# 文件列表
self.filelist = filelist
# 每张图片大小数据初始化
self.height = 32
self.width = 32
self.channel = 3
self.label_bytes = 1
self.image_bytes = self.height * self.width * self.channel
self.bytes = self.image_bytes + self.label_bytes
读取代码:
def read_decode(self):
""" 读取数据并转换成张量 :return: 图片数据,标签值 """
# 1、构造文件队列
file_queue = tf.train.string_input_producer(self.filelist)
# 2、构造二进制文件的阅读器,解码成张量
reader = tf.FixedLengthRecordReader(self.bytes)
key, value = reader.read(file_queue)
# 解码成张量
image_label = tf.decode_raw(value, tf.uint8)
# 分割标签与数据
label_tensor = tf.cast(tf.slice(image_label, [0], [self.label_bytes]), tf.int32)
image = tf.slice(image_label, [self.label_bytes], [self.image_bytes])
print(image)
# 3、图片数据格式转换
image_tensor = tf.reshape(image, [self.height, self.width, self.channel])
# 4、图片数据批处理,一次从二进制文件中读取多少数据出来
image_batch, label_batch = tf.train.batch([image_tensor, label_tensor], batch_size=5000, num_threads=1, capacity=50000)
return image_batch, label_batch
保存和读取TFRecords文件当中
我们通过两个接口实现
def write_to_tfrecords(self, image_batch, label_batch):
""" 把读取出来的数据进行存储(tfrecords) :param image_batch: 图片RGB值 :param label_batch: 图片标签 :return: """
# 1、构造存储器
writer = tf.python_io.TFRecordWriter(FLAGS.image_dir)
# 2、每张图片进行example协议化,存储
for i in range(5000):
print(i)
# 图片的张量要转换成字符串才能写进去,否则大小格式不对
image = image_batch[i].eval().tostring()
# 标签值
label = int(label_batch[i].eval())
# 构造example协议快,存进去的名字是提供给取的时候使用
example = tf.train.Example(features=tf.train.Features(feature={
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}))
writer.write(example.SerializeToString())
return None
def read_tfrecords(self):
# 1、构建文件的队列
file_queue = tf.train.string_input_producer([FLAGS.image_dir])
# 2、构建tfrecords文件阅读器
reader = tf.TFRecordReader()
key, value = reader.read(file_queue)
# 3、解析example协议块,返回字典数据,feature["image"],feature["label"]
feature = tf.parse_single_example(value, features={
"image": tf.FixedLenFeature([], tf.string),
"label": tf.FixedLenFeature([], tf.int64)
})
# 4、解码图片数据,标签数据不用
# 图片数据处理
image = tf.decode_raw(feature["image"], tf.uint8)
# 处理一下形状
image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
# 改变数据类型
image_tensor = tf.cast(image_reshape, tf.float32)
# 标签数据处理
label_tensor = tf.cast(feature["label"], tf.int32)
# 批处理图片数据,训练数据每批次读取多少
image_batch, label_batch = tf.train.batch([image_tensor, label_tensor], batch_size=10, num_threads=1, capacity=10)
return image_batch, label_batch
我们将数据读取的代码放入cifar_data.py文件当中,当作我们的原始数据读取,完整代码如下
import tensorflow as tf
import os
"""用于获取Cifar TFRecords数据文件的程序"""
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("image_dir", "./cifar10.tfrecords","数据文件目录")
class CifarRead(object):
def __init__(self, filelist=None):
# 文件列表
self.filelist = filelist
# 每张图片大小数据初始化
self.height = 32
self.width = 32
self.channel = 3
self.label_bytes = 1
self.image_bytes = self.height * self.width * self.channel
self.bytes = self.image_bytes + self.label_bytes
def read_decode(self):
""" 读取数据并转换成张量 :return: 图片数据,标签值 """
# 1、构造文件队列
file_queue = tf.train.string_input_producer(self.filelist)
# 2、构造二进制文件的阅读器,解码成张量
reader = tf.FixedLengthRecordReader(self.bytes)
key, value = reader.read(file_queue)
# 解码成张量
image_label = tf.decode_raw(value, tf.uint8)
# 分割标签与数据
label_tensor = tf.cast(tf.slice(image_label, [0], [self.label_bytes]), tf.int32)
image = tf.slice(image_label, [self.label_bytes], [self.image_bytes])
print(image)
# 3、图片数据格式转换
image_tensor = tf.reshape(image, [self.height, self.width, self.channel])
# 4、图片数据批处理
image_batch, label_batch = tf.train.batch([image_tensor, label_tensor], batch_size=5000, num_threads=1, capacity=50000)
return image_batch, label_batch
def write_to_tfrecords(self, image_batch, label_batch):
""" 把读取出来的数据进行存储(tfrecords) :param image_batch: 图片RGB值 :param label_batch: 图片标签 :return: """
# 1、构造存储器
writer = tf.python_io.TFRecordWriter(FLAGS.image_dir)
# 2、每张图片进行example协议化,存储
for i in range(5000):
print(i)
# 图片的张量要转换成字符串才能写进去,否则大小格式不对
image = image_batch[i].eval().tostring()
# 标签值
label = int(label_batch[i].eval())
# 构造example协议快,存进去的名字是提供给取的时候使用
example = tf.train.Example(features=tf.train.Features(feature={
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}))
writer.write(example.SerializeToString())
return None
def read_tfrecords(self):
# 1、构建文件的队列
file_queue = tf.train.string_input_producer([FLAGS.image_dir])
# 2、构建tfrecords文件阅读器
reader = tf.TFRecordReader()
key, value = reader.read(file_queue)
# 3、解析example协议块,返回字典数据,feature["image"],feature["label"]
feature = tf.parse_single_example(value, features={
"image": tf.FixedLenFeature([], tf.string),
"label": tf.FixedLenFeature([], tf.int64)
})
# 4、解码图片数据,标签数据不用
# 图片数据处理
image = tf.decode_raw(feature["image"], tf.uint8)
# 处理一下形状
image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
# 改变数据类型
image_tensor = tf.cast(image_reshape, tf.float32)
# 标签数据处理
label_tensor = tf.cast(feature["label"], tf.int32)
# 批处理图片数据
image_batch, label_batch = tf.train.batch([image_tensor, label_tensor], batch_size=10, num_threads=1, capacity=10)
return image_batch, label_batch
if __name__ == "__main__":
# 生成文件名列表(路径+文件名)
filename = os.listdir("./cifar10/cifar-10-batches-bin")
filelist = [os.path.join("./cifar10/cifar-10-batches-bin", file) for file in filename if file[-3:] == "bin"]
# 实例化
cfr = CifarRead(filelist)
# 生成张量
image_batch, label_batch = cfr.read_decode()
# image_batch, label_batch = cfr.read_tfrecords()
# 会话
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord, start=True)
print(sess.run([image_batch, label_batch]))
# 写进tfrecords文件
cfr.write_to_tfrecords(image_batch, label_batch)
coord.request_stop()
coord.join(threads)