TensorFlow 构建自己的目标检测与识别模型之数据增强(三)

上一篇博客介绍了如何对图像进行数据增强,见链接:https://blog.csdn.net/weixin_41644725/article/details/85678348
在本章中,主要将经过数据增强后的图像保存下来,并将边界框信息写入 csv 文件,以便后面生成 tfrecord 时使用(后面会提到)。
例如,以下是未采用数据增强时所生成的csv文件:

以下是未采用数据增强时存放图像的文件夹:

这里采用上一篇博客(https://blog.csdn.net/weixin_41644725/article/details/85678348)中所提到的调整图像亮度、裁剪、cutout 和旋转四种方法。
代码如下:

def creat_image_DA(img_name, img, bboxs, csv_path, img_class):
    """Create augmented copies of one image and append their boxes to a CSV.

    For every box in ``bboxs`` the image is passed through four augmentation
    helpers (brightness change, cutout, rotation, crop), in that order. Each
    result is saved as ``./images/<prefix><img_name>``, and one row per
    result is appended to ``csv_path`` in the form
    ``filename,width,height,class,xmin,ymin,xmax,ymax``.

    Args:
        img_name: file name (without directory) of the source image.
        img: source image array; ``load_train`` passes it in RGB channel order.
        bboxs: iterable of boxes, or ``None`` (then nothing is done).
        csv_path: CSV file that the annotation rows are appended to.
        img_class: class label written into every row.

    NOTE(review): the helpers are called once per box, but the output file
    name does not depend on the box. For an image with several boxes, each
    box's augmented image overwrites the previous one on disk, while the CSV
    still gets a row for every box. Confirm whether the helpers can process
    all boxes in a single call.
    """
    if bboxs is None:
        return
    # (output-name prefix, augmentation helper), applied in this order.
    augmentations = (
        ('change_light_', changeLight),
        ('cut_out_', cutout),
        ('rotate_', rotate_img_bbox),
        ('crop_', crop_img_bboxes),
    )
    for bbox in bboxs:
        list_box = [bbox]
        rows = [
            _augment_and_save(augment, prefix, img_name, img, list_box, img_class)
            for prefix, augment in augmentations
        ]
        # ``with`` guarantees the CSV is closed even if the write fails.
        with open(csv_path, 'a+') as f:  # write to the csv file
            f.write(''.join(rows))


def _augment_and_save(augment, prefix, img_name, img, list_box, img_class):
    """Apply one augmentation helper, save the image and return its CSV row."""
    out_img, x_min, y_min, x_max, y_max = augment(img=img, bboxes=list_box)
    height, width = out_img.shape[:2]
    # Reverse the channel order (the input is RGB, see load_train) so that
    # cv2.imwrite, which expects BGR, stores the colours correctly.
    b, g, r = cv2.split(out_img)
    out_img = cv2.merge([r, g, b])
    out_img = cv2.GaussianBlur(out_img, (3, 3), 0)
    cv2.imwrite('./images/' + prefix + img_name, out_img)
    fields = (prefix + img_name, width, height, img_class,
              x_min, y_min, x_max, y_max)
    return ",".join(str(field) for field in fields) + "\n"

加载图像数据集时使用如下代码:

def load_train(train_path, csv_path):
    """Augment every original image matched by the glob ``train_path``.

    Images whose file name contains one of the augmentation markers (for
    example output from a previous run) are skipped, so they are not
    augmented again. Each remaining image is:
    1. read with OpenCV and converted to RGB order;
    2. lightly blurred;
    3. given its boxes/class by ``get_bbox`` (only the first four values of
       each box are kept);
    4. handed to ``creat_image_DA``.

    Args:
        train_path: glob pattern, e.g. ``'images/*g'``.
        csv_path: annotation CSV read by ``get_bbox`` and appended to by
            ``creat_image_DA``.
    """
    import os  # local import: this script section has no top-level ``import os``

    print('Going to read training images')
    # Name fragments of images produced by augmentation.
    markers = ('change_light', 'cut_out', 'rotate', 'crop', 'shift')
    for fl in glob.glob(train_path):  # every image path matched by the pattern
        name = os.path.basename(fl)
        # Skip if ANY marker matches (the original chained ``!=True or ...``
        # only skipped files containing all five markers).
        if any(marker in name for marker in markers):
            continue
        img = cv2.imread(fl)
        if img is None:
            # cv2.imread returns None for files it cannot decode.
            print('Skipping unreadable image: ' + fl)
            continue
        b, g, r = cv2.split(img)
        img = cv2.merge([r, g, b])
        img = cv2.GaussianBlur(img, (3, 3), 0)
        coords, img_class = get_bbox(name, csv_path)
        coords = [coord[:4] for coord in coords]
        creat_image_DA(name, img, coords, csv_path, img_class)
def main():
    """Augment the images in ./images using the annotations in ./csv/class.csv."""
    csv_path = './csv/class.csv'
    # '*g' matches any file name ending in 'g' (e.g. .jpg, .png, .jpeg).
    train_path = 'images/*g'
    load_train(train_path, csv_path)


main()

结果如图所示:

将图像数据集分为训练集和验证集,代码如下:

def split_train_vaild(csv_path, train_ratio=0.8,
                      train_csv='data_set/all_train.csv',
                      valid_csv='data_set/all_vaild.csv',
                      seed=None):
    """Split annotation rows into training and validation CSVs, by image.

    All rows that share a ``filename`` stay together, so an image (and all
    of its boxes) never appears in both splits.

    Args:
        csv_path: input CSV with a ``filename`` column.
        train_ratio: fraction of images (not rows) put in the training split;
            the image count is truncated with ``int``.
        train_csv: output path of the training CSV.
        valid_csv: output path of the validation CSV.
        seed: optional seed for a reproducible split. ``None`` keeps using
            NumPy's global random state, as before.
    """
    full_labels = pd.read_csv(csv_path)
    gb = full_labels.groupby('filename')
    grouped_list = [gb.get_group(x) for x in gb.groups]
    len_imge = len(grouped_list)
    rng = np.random if seed is None else np.random.RandomState(seed)
    train_index = rng.choice(len_imge, size=int(len_imge * train_ratio), replace=False)
    test_index = np.setdiff1d(np.arange(len_imge), train_index)
    # pd.concat raises on an empty list (e.g. a single image leaves the
    # training split empty), so fall back to an empty frame with the same columns.
    empty = full_labels.iloc[0:0]
    train = pd.concat([grouped_list[i] for i in train_index]) if len(train_index) else empty
    test = pd.concat([grouped_list[i] for i in test_index]) if len(test_index) else empty
    print(len(train_index), len(test_index))
    train.to_csv(train_csv, index=False)
    test.to_csv(valid_csv, index=False)
# Split the annotation CSV produced during augmentation into train/validation CSVs.
csv_path = 'csv/class.csv'
split_train_vaild(csv_path=csv_path)

然后生成tfrecord格式,代码如下:

from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import os
import io
import pandas as pd
import tensorflow as tf
from PIL import Image
#from object_detection.utils import dataset_util
from object_detection.utils import dataset_util
from collections import namedtuple, OrderedDict
# Command-line flags via ``tf.app.flags`` (TF 1.x API).
# NOTE(review): FLAGS is not read by the code shown here -- ``main`` takes
# explicit arguments and the ``__main__`` block hard-codes them -- so these
# definitions currently have no effect on the script.
flags = tf.app.flags
flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
FLAGS = flags.FLAGS

 def class_text_to_int(row_label):
    if row_label == 'class1':
        return 1
    elif row_label == class2':
        return 2
    elif row_label == 'class3':
        return 3
    elif row_label == 'class4':
        return 4
    elif row_label == 'class5':
        return 5
    elif row_label == 'class6':
        return 6
    else:
        print('NONE: ' + row_label)
        # None
def split(df, group):
    """Group ``df`` by column ``group``, returning one tuple per distinct key.

    Each element is a ``data(filename, object)`` namedtuple. ``filename`` is
    the group key, and ``object`` is the sub-DataFrame holding the rows with
    that key. Elements follow the iteration order of ``groupby(...).groups``.
    """
    data = namedtuple('data', ['filename', 'object'])
    grouped = df.groupby(group)
    result = []
    for key in grouped.groups:
        result.append(data(key, grouped.get_group(key)))
    return result
def create_tf_example(group, path):
    """Build a ``tf.train.Example`` for one image and all of its boxes.

    Args:
        group: a ``data(filename, object)`` tuple from ``split``. ``object``
            holds the CSV rows (``xmin``, ``ymin``, ``xmax``, ``ymax``,
            ``class``) for this image.
        path: directory that contains the image files.

    Returns:
        A ``tf.train.Example``. Box coordinates are divided by the image
        width/height read from the encoded image.
    """
    img_path = os.path.join(path, '{}'.format(group.filename))
    print(img_path)
    with tf.gfile.GFile(img_path, 'rb') as fid:
        encoded_jpg = fid.read()
    image = Image.open(io.BytesIO(encoded_jpg))
    width, height = image.size
    # The CSV filename already includes its extension (the file is opened
    # with it as-is above), so no extra '.jpg' is appended.
    filename = '{}'.format(group.filename).encode('utf8')
    # NOTE(review): hard-coded format; if any inputs are PNG this value is
    # wrong -- consider deriving it from ``image.format``.
    image_format = b'jpg'
    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []

    for _, row in group.object.iterrows():
        # float() keeps sub-pixel coordinates that int() would truncate.
        xmins.append(float(row['xmin']) / width)
        xmaxs.append(float(row['xmax']) / width)
        ymins.append(float(row['ymin']) / height)
        ymaxs.append(float(row['ymax']) / height)
        classes_text.append(row['class'].encode('utf8'))
        classes.append(class_text_to_int(row['class']))

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example
def main(csv_input, output_path, imgPath):
    """Write a TFRecord file at ``output_path`` from the CSV ``csv_input``.

    Rows are grouped by ``filename``; each image becomes one
    ``tf.train.Example`` built by ``create_tf_example``, with the image
    files read from the directory ``imgPath``.
    """
    examples = pd.read_csv(csv_input)
    grouped = split(examples, 'filename')
    writer = tf.python_io.TFRecordWriter(output_path)
    try:
        for group in grouped:
            tf_example = create_tf_example(group, imgPath)
            writer.write(tf_example.SerializeToString())
    finally:
        # Close (and flush) the record file even if an example fails to build.
        writer.close()
    print('Successfully created the TFRecords: {}'.format(output_path))


if __name__ == '__main__':
    imgPath = 'images/all_images'
    # Build the training record first (all_train), then the validation
    # record (all_vaild); each reads data_set/<name>.csv and writes
    # data_set/<name>.record.
    for split_name in ('all_train', 'all_vaild'):
        output_path = 'data_set/{}.record'.format(split_name)
        csv_input = 'data_set/{}.csv'.format(split_name)
        main(csv_input, output_path, imgPath)

此处需要注意下面这部分:数据集中有几个类别,就在这里设置几个(类别名要与 CSV 中 class 列的取值一致,编号从 1 开始)。

def class_text_to_int(row_label):
    """Map a class name from the CSV ``class`` column to its integer id.

    Ids start at 1. Edit the table so it lists every class in your dataset.
    For an unknown label a warning is printed and ``None`` is returned.
    """
    label_ids = {
        'class1': 1,
        'class2': 2,
        'class3': 3,
        'class4': 4,
        'class5': 5,
        'class6': 6,
    }
    label_id = label_ids.get(row_label)
    if label_id is None:
        print('NONE: ' + row_label)
    return label_id