Tensorflow 训练自己的目标检测与识别模型(四)
前面对数据集的创建以及图像数据增强进行了叙述,见链接:https://blog.csdn.net/weixin_41644725/article/details/85678348 和 https://blog.csdn.net/weixin_41644725/article/details/85687049
接下来对项目中的配置进行叙述:
(1)下载所需要的预训练模型。
见链接:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
如图下图所示。
下载好模型之后,解压该文件之后得到3个model.ckpt文件,将其考到存放预训练模型的文件夹下,如下图所示。
(2)配置label_map.pbtxt文件。
在当前文件夹下创建并打开该文件,在文件中写入以下信息:
item{
id:1
name:'class1’
}
item{
id:2
name:'class2'
}
item{
id:3
name:'class3'
}
item{
id: 4
name:'class4'
}
item{
id:5
name:'class5'
}
item{
id:6
name:'class6'
}
(3)配置config文件
在object_detection/samples/configs下找到模型多对应的.config文件,如图所示:
将其拷贝到自己的目录中,打开该文件,对以下地方进行修改。
第一处:在第10行处将num_classes: 96
中的96修改为自己索要训练的类别数目,若训练的类别有6类,则num_classes: 6
第二处:在108行处的fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt
的路径设置为自己的训练模型所存放的路径,例如 fine_tune_checkpoint: ./data_set/fine_tune_model/faster_rcnn_inception_resnet_v2/model.ckpt
第三处:在123行的input_path
的路径设置为自己的训练数据集的路径,例如input_path: "./new_data_set/all_train_1.record"
第四处:在125行处的label_map_path
的路径设置刚才创建label_map.pbtxt所在的路径,例如 label_map_path:"label_map.pbtxt"
第五处:在129行num_examples: 8000
修改为自己数据集的验证集的大小,例如num_examples: 3770
第六处:在137行处 input_path设置为验证集的路径,例如: input_path"./data_set/all_vaild_1.record"
第七处:在139行的修改同第四处修改一样。
(4)将train.py文件拷贝到自己的项目中。
前面对如何安装Tensorflow Object Detection API(https://blog.csdn.net/weixin_41644725/article/details/83007901) 进行了叙述。在该环境搭建好之后即可训练自己的物体检测和识别模型。在最新的Tensorflow Object Detection API中,train.py文件在以下目录中(和以前的有点不同),如图所示。
打开该文件,对该文件进行修改:
第一:在70行处的下列代码进行修改,设置训练模型存放的文件夹:
flags.DEFINE_string('train_dir', ' 训练模型存放的文件夹',
'Directory to save the checkpoints and training summaries.')
例如:flags.DEFINE_string('train_dir', './training_model',
'Directory to save the checkpoints and training summaries.')
第二处:在73行处设置config文件的路径。
flags.DEFINE_string('pipeline_config_path', '设置config文件的路径',
'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
'file. If provided, other configs are ignored')
例如:flags.DEFINE_string('pipeline_config_path', 'faster_rcnn_inception_v2_coco.config',
'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
'file. If provided, other configs are ignored')
模型配置完毕,开始训练模型。其结果如图下图所示。
使用tensorboard --logdir='./training_faster_rcnn_inception_v2_model'
来查看训练情况,然后在浏览器中输入:0.0.0.0:6006