Tensorflow 训练自己的目标检测与识别模型(四)

前面对数据集的创建以及图像数据增强进行了叙述,见链接:https://blog.csdn.net/weixin_41644725/article/details/85678348https://blog.csdn.net/weixin_41644725/article/details/85687049
接下来对项目中的配置进行叙述:

(1)下载所需要的预训练模型。

见链接:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
如图下图所示。

下载好模型之后,解压该文件之后得到3个model.ckpt文件,将其考到存放预训练模型的文件夹下,如下图所示。

(2)配置label_map.pbtxt文件。

在当前文件夹下创建并打开该文件,在文件中写入以下信息:

item{
  id:1
  name:'class1’
}
item{
  id:2
  name:'class2'
}
item{
  id:3
  name:'class3'
}
item{
  id: 4
  name:'class4'
}
item{
  id:5
  name:'class5'
}
item{
  id:6
  name:'class6'
}

(3)配置config文件

在object_detection/samples/configs下找到模型多对应的.config文件,如图所示:

将其拷贝到自己的目录中,打开该文件,对以下地方进行修改。

第一处:在第10行处将num_classes: 96中的96修改为自己索要训练的类别数目,若训练的类别有6类,则num_classes: 6

第二处:在108行处的fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt 的路径设置为自己的训练模型所存放的路径,例如 fine_tune_checkpoint: ./data_set/fine_tune_model/faster_rcnn_inception_resnet_v2/model.ckpt

第三处:在123行的input_path的路径设置为自己的训练数据集的路径,例如input_path: "./new_data_set/all_train_1.record"

第四处:在125行处的label_map_path的路径设置刚才创建label_map.pbtxt所在的路径,例如 label_map_path:"label_map.pbtxt"

第五处:在129行num_examples: 8000 修改为自己数据集的验证集的大小,例如num_examples: 3770

第六处:在137行处 input_path设置为验证集的路径,例如: input_path"./data_set/all_vaild_1.record"

第七处:在139行的修改同第四处修改一样。

(4)将train.py文件拷贝到自己的项目中。

前面对如何安装Tensorflow Object Detection API(https://blog.csdn.net/weixin_41644725/article/details/83007901) 进行了叙述。在该环境搭建好之后即可训练自己的物体检测和识别模型。在最新的Tensorflow Object Detection API中,train.py文件在以下目录中(和以前的有点不同),如图所示。

打开该文件,对该文件进行修改:
第一:在70行处的下列代码进行修改,设置训练模型存放的文件夹:

flags.DEFINE_string('train_dir', ' 训练模型存放的文件夹',
                    'Directory to save the checkpoints and training summaries.')
                    
例如:flags.DEFINE_string('train_dir', './training_model',
                    'Directory to save the checkpoints and training summaries.')

第二处:在73行处设置config文件的路径。

flags.DEFINE_string('pipeline_config_path', '设置config文件的路径',
                        'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
                        'file. If provided, other configs are ignored')
                        
例如:flags.DEFINE_string('pipeline_config_path', 'faster_rcnn_inception_v2_coco.config',
                    'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
                    'file. If provided, other configs are ignored')

模型配置完毕,开始训练模型。其结果如图下图所示。

使用tensorboard --logdir='./training_faster_rcnn_inception_v2_model'来查看训练情况,然后在浏览器中输入:0.0.0.0:6006