Mxnet2Caffe
将mxnet静态图symbol转换为caffe的prototxt文本,支持大部分op,caffe不需要的op则需要自己添加,再转换,否则会构建失败
- 将json转换为prototxt
- 利用caffe的python接口构建网络,将mxnet的参数param迁移到caffe网络中
- 构建caffe不支持的op
- 对结果进行比对
json_2_prototxt
json2prototxt.py prototxt_basic.py
Read the mxnet JSON file and convert it to a prototxt
// json 格式,主要是op(操作节点和辅助节点null) name attr(参数列表) inputs(输入列表list)
{
"op": "Activation",
"name": "part_0_stage1_unit1_relu1",
"attrs": {"act_type": "relu"},
"inputs": [[14, 0, 0]]
},
{
"op": "null",
"name": "part_0_stage1_unit1_conv1_weight",
"attrs": {
"kernel": "(3, 3)",
"no_bias": "True",
"num_filter": "64",
"pad": "(1, 1)",
"stride": "(1, 1)",
"workspace": "256"
},
"inputs": []
},
{
"op": "Convolution",
"name": "part_0_stage1_unit1_conv1",
"attrs": {
"kernel": "(3, 3)",
"no_bias": "True",
"num_filter": "64",
"pad": "(1, 1)",
"stride": "(1, 1)",
"workspace": "256"
},
"inputs": [[15, 0, 0], [16, 0, 0]]
},
读取json文件,并存储相应信息
# Read the mxnet symbol JSON and emit a caffe prototxt, one layer per op node.
with open(args.mx_json) as json_file:
    jdata = json.load(json_file)

with open(args.cf_prototxt, "w") as prototxt_file:
    for i_node, node_i in enumerate(jdata['nodes']):
        # Skip auxiliary/parameter nodes (op == 'null'); only real operator
        # nodes and the 'data' input node are converted.
        if str(node_i['op']) == 'null' and str(node_i['name']) != 'data':
            continue

        # Per-node info passed to the writers: op, name, attrs from the JSON
        # plus the derived top / bottom / params fields.
        info = node_i
        info['top'] = info['name']
        info['bottom'] = []
        info['params'] = []

        # Walk the node's inputs: operator outputs (and 'data') become caffe
        # bottoms; 'null' inputs are learnable parameters of this layer.
        for input_idx_i in node_i['inputs']:
            # input_idx_i is [node_index, output_index, version]; only the
            # node index is needed here.
            input_i = jdata['nodes'][input_idx_i[0]]
            if str(input_i['op']) != 'null' or (str(input_i['name']) == 'data'):
                info['bottom'].append(str(input_i['name']))
            if str(input_i['op']) == 'null':
                info['params'].append(str(input_i['name']))
                # A parameter not prefixed with this node's name is shared
                # with another layer (weight sharing).
                if not str(input_i['name']).startswith(str(node_i['name'])):
                    logging.info(' use shared weight -> %s'% str(input_i['name']))
                    info['share'] = True

        write_node(prototxt_file, info)
写prototxt文件
# 转换 Convolution 节点操作
def Convolution(txt_file, info):
if info['attrs']['no_bias'] == 'True':
bias_term = 'false'
else:
bias_term = 'true'
txt_file.write('layer {\n')
txt_file.write(' bottom: "%s"\n' % info['bottom'][0])
txt_file.write(' top: "%s"\n' % info['top'])
txt_file.write(' name: "%s"\n' % info['top'])
txt_file.write(' type: "Convolution"\n')
txt_file.write(' convolution_param {\n')
txt_file.write(' num_output: %s\n' % info['attrs']['num_filter'])
txt_file.write(' kernel_size: %s\n' % info['attrs']['kernel'].split('(')[1].split(',')[0]) # TODO
if 'pad' not in info['attrs']:
logging.info('miss Conv_pad, make pad default: 0 ')
txt_file.write(' pad: %s\n' % 0) # TODO
else:
txt_file.write(' pad: %s\n' % info['attrs']['pad'].split('(')[1].split(',')[0]) # TODO
# txt_file.write(' group: %s\n' % info['attrs']['num_group'])
txt_file.write(' stride: %s\n' % info['attrs']['stride'].split('(')[1].split(',')[0])
txt_file.write(' bias_term: %s\n' % bias_term)
txt_file.write(' }\n')
if 'share' in info.keys() and info['share']:
txt_file.write(' param {\n')
txt_file.write(' name: "%s"\n' % info['params'][0])
txt_file.write(' }\n')
txt_file.write('}\n')
txt_file.write('\n')
# -------根据op操作,完善相应的转换函数-----------
# 目前包含Conv Pool DepthConv BN Act ele_add Concat FC Reshape etc.
# -------根据op操作,完善相应的转换函数-----------
# Dispatch table covering Conv, Pool, DepthConv, BN, Act, eltwise add,
# Concat, FC, Reshape, etc. Add a branch here for any new op.
def write_node(txt_file, info):
    """Dispatch one mxnet node to the matching caffe-layer writer.

    Training-only label nodes are skipped entirely; unknown ops are logged
    as warnings but do not abort the conversion (the resulting prototxt may
    then be incomplete and fail to build in caffe).
    """
    if 'label' in info['name']:
        return
    op = info['op']
    if op == 'null' and info['name'] == 'data':
        data(txt_file, info)
    elif op == 'Convolution':
        Convolution(txt_file, info)
    elif op == 'ChannelwiseConvolution':
        ChannelwiseConvolution(txt_file, info)
    elif op == 'BatchNorm':
        BatchNorm(txt_file, info)
    elif op == 'Activation':
        Activation(txt_file, info)
    elif op in ('elemwise_add', '_Plus'):
        # Both mxnet spellings map to the same caffe eltwise-sum layer
        # (historically also named ElementWiseSum).
        ElementWiseSum(txt_file, info)
    elif op == 'Concat':
        Concat(txt_file, info)
    elif op == 'Pooling':
        # Pooling(txt_file, info)
        Pooling_global(txt_file, info)
    elif op == 'Flatten':
        Flatten(txt_file, info)
    elif op == 'FullyConnected':
        FullyConnected(txt_file, info)
    elif op == 'SoftmaxOutput':
        SoftmaxOutput(txt_file, info)
    elif op == 'Cast':
        Cast(txt_file, info)
    elif op == 'SliceChannel':
        SliceChannel(txt_file, info)
    elif op == 'L2Normalization':
        L2Normalization(txt_file, info)
    elif op == 'Reshape':
        Reshape(txt_file, info)
    elif op == 'broadcast_mul':
        broadcast_mul(txt_file, info)
    else:
        # logging.warn is a deprecated alias for logging.warning.
        logging.warning("Unknown mxnet op: %s" % info['op'])
利用caffe的python接口,构建网络,并迁移mxnet的网络参数
1.mxnet2caffe.py
Read the mxnet model params dict and convert it to a .caffemodel
转换的时候如果存在caffe不支持的op,需要自己添加自定义层,否则在构建网络时,会error,本工程添加了broadcast_mul
层caffe添加自定义层的介绍比较多,就跳过了
根据mxnet的API (load) 加载param文件的所有参数字典
import sys  # needed unconditionally below (sys.exit), not only on ImportError

# Make caffe importable: fall back to a local build's python bindings.
try:
    import caffe
except ImportError:
    import os
    sys.path.append("/home/***/codes/mx2caffe/caffe/python/")
    import caffe

# Load every parameter dict saved in the mxnet checkpoint.
_, arg_params, aux_params = mx.model.load_checkpoint(args.mx_model, args.mx_epoch)
# dict_keys objects cannot be concatenated with + in Python 3;
# materialize both key views as lists first.
all_keys = list(arg_params) + list(aux_params)

# Build the caffe net from the freshly converted prototxt.
net = caffe.Net(args.cf_prototxt, caffe.TRAIN)
for i_key, key_i in enumerate(all_keys):
    try:
        # String equality, not identity ('is'); 'data' carries no weights,
        # and `continue` (not `pass`) avoids printing with key_caffe unbound.
        if key_i == 'data':
            continue
        # mxnet key suffixes map onto caffe blob slots: params[layer][0] is
        # the weight and [1] the bias; BatchNorm keeps mean/var in slots 0/1
        # with the scale factor in slot 2 (check the caffe proto for others).
        elif '_weight' in key_i:
            key_caffe = key_i.replace('_weight','')
            net.params[key_caffe][0].data.flat = arg_params[key_i].asnumpy().flat
        elif '_bias' in key_i:
            key_caffe = key_i.replace('_bias','')
            net.params[key_caffe][1].data.flat = arg_params[key_i].asnumpy().flat
        elif '_gamma' in key_i:
            key_caffe = key_i.replace('_gamma','_scale')
            net.params[key_caffe][0].data.flat = arg_params[key_i].asnumpy().flat
        elif '_beta' in key_i:
            key_caffe = key_i.replace('_beta','_scale')
            net.params[key_caffe][1].data.flat = arg_params[key_i].asnumpy().flat
        elif '_moving_mean' in key_i:
            key_caffe = key_i.replace('_moving_mean','')
            net.params[key_caffe][0].data.flat = aux_params[key_i].asnumpy().flat
            net.params[key_caffe][2].data[...] = 1
        elif '_moving_var' in key_i:
            key_caffe = key_i.replace('_moving_var','')
            net.params[key_caffe][1].data.flat = aux_params[key_i].asnumpy().flat
            net.params[key_caffe][2].data[...] = 1
        else:
            sys.exit("Warning! Unknown mxnet:{}".format(key_i))
        print("% 3d | %s -> %s, initialized."
              %(i_key, key_i.ljust(40), key_caffe.ljust(30)))
    except KeyError:
        # The converted prototxt has no layer with this name.
        print("\nWarning! key error mxnet:{}".format(key_i))

# ------------------------------------------
# Finish
net.save(args.cf_model)
print("\n- Finished.\n")
对转换结果进行比对确认
mxnet_test.py
Debug mxnet output and you can compare the result with the converted caffemodel
使用mxnet debug, 打印需要对比的参数,并且输出指定层的结果
import mxnet as mx
def load_checkpoint_single(model, param_path):
    """Load a single mxnet .params file into a bound Module.

    Splits the saved 'arg:'/'aux:'-prefixed tensors into the two parameter
    dicts, installs them on *model*, and returns the module's own copies.

    Args:
        model: a bound mx.mod.Module.
        param_path: path to the .params checkpoint file.

    Returns:
        (arg_params, aux_params) as returned by model.get_params().
    """
    arg_params = {}
    aux_params = {}
    save_dict = mx.nd.load(param_path)
    for k, value in save_dict.items():
        arg_type, name = k.split(':', 1)
        # elif: each key is either 'arg' or 'aux', never both (the original
        # second `if ... else: pass` made the else branch dead code).
        if arg_type == 'arg':
            arg_params[name] = value
        elif arg_type == 'aux':
            aux_params[name] = value
    model.set_params(arg_params, aux_params, allow_missing=False)
    arg_params, aux_params = model.get_params()
    return arg_params, aux_params
full_param_path = 'se_resnet34/base-0000.params'
fmodel = mx.sym.load('se_resnet34/base-symbol.json')
# All internal symbols of the network, addressable by name.
all_layers = fmodel.get_internals()
# Select which layer to dump via '<layer_name>_output'; change this name to
# inspect any other intermediate output.
fmodel = all_layers['flat_output']
fullmodel = mx.mod.Module(symbol=fmodel, data_names=['data'], label_names=[])
# (dead `img = []` removed: it was immediately overwritten)
img = get_image_gray('before_forward.jpg')
fullmodel.bind(data_shapes=[('data', (1, 1, 108, 108))], label_shapes=None,
               for_training=False, force_rebind=False)
arg_params, aux_params = load_checkpoint_single(fullmodel, full_param_path)
fullmodel.set_params(arg_params, aux_params)

tic = time.time()
fullmodel.forward(Batch([mx.nd.array(img)]))
prob = fullmodel.get_outputs()[0].asnumpy()
prob = prob.astype(np.float64)
prob = prob.reshape(-1, 1)
# Save with fixed precision so the result can be diffed against caffe's
# output; `with` guarantees the file is closed even if savetxt raises.
with open('se_resnet34.txt', 'w') as file1:
    np.savetxt(file1, prob, fmt='%.12f')
然后利用Caffe 加载刚才转换的网络,打印输出,对比结果精度,如果出现问题,则需要逐层排查,本工程在SENet
网络上测试正常
https://github.com/junqiangwu/Mxnet2Caffe-Tensor-RT-SEnet
TODO:
- add caffe plugin layer; TensorRT loads the caffe_model; TensorRT support for SE-ResNet