Mxnet2Caffe

Converts an MXNet static-graph symbol into a Caffe prototxt. Most ops are supported; ops that Caffe does not provide have to be added by hand before converting, otherwise building the Caffe network will fail.

  1. Convert the symbol JSON into a prototxt (the command-line arguments both scripts expect are sketched right after this list)
  2. Build the network through Caffe's Python interface and copy the MXNet params into it
  3. Implement the ops that Caffe does not support
  4. Compare the outputs of the two models
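
The script excerpts below reference an args object from argparse; a minimal sketch of the argument parsing they assume (the flag spellings are my guess, only the attribute names mx_json, cf_prototxt, mx_model, mx_epoch and cf_model come from the code):

import argparse

parser = argparse.ArgumentParser(description='Convert an MXNet symbol/params pair to Caffe prototxt/caffemodel')
parser.add_argument('--mx-json',     dest='mx_json',     help='path to the MXNet *-symbol.json')
parser.add_argument('--cf-prototxt', dest='cf_prototxt', help='path of the prototxt to write / read back')
parser.add_argument('--mx-model',    dest='mx_model',    help='MXNet checkpoint prefix')
parser.add_argument('--mx-epoch',    dest='mx_epoch',    type=int, default=0, help='checkpoint epoch to load')
parser.add_argument('--cf-model',    dest='cf_model',    help='path of the .caffemodel to write')
args = parser.parse_args()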

json_2_prototxt

  1. json2prototxt.py / prototxt_basic.py: read the MXNet symbol JSON and convert it to a prototxt
// JSON format: each node mainly has op (operator nodes; auxiliary nodes use "null"), name, attrs (the attribute list, all values are strings) and inputs (a list of input indices); a small parsing sketch follows the excerpt
    {
      "op": "Activation", 
      "name": "part_0_stage1_unit1_relu1", 
      "attrs": {"act_type": "relu"}, 
      "inputs": [[14, 0, 0]]
    }, 
    {
      "op": "null", 
      "name": "part_0_stage1_unit1_conv1_weight", 
      "attrs": {
        "kernel": "(3, 3)", 
        "no_bias": "True", 
        "num_filter": "64", 
        "pad": "(1, 1)", 
        "stride": "(1, 1)", 
        "workspace": "256"
      }, 
      "inputs": []
    }, 
    {
      "op": "Convolution", 
      "name": "part_0_stage1_unit1_conv1", 
      "attrs": {
        "kernel": "(3, 3)", 
        "no_bias": "True", 
        "num_filter": "64", 
        "pad": "(1, 1)", 
        "stride": "(1, 1)", 
        "workspace": "256"
      }, 
      "inputs": [[15, 0, 0], [16, 0, 0]]
    }, 
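
Note that every attrs value is a string, including the tuples; the writers below pull out the first element with string splitting. An equivalent, slightly more robust helper (a sketch, not part of the project) could look like this:

import ast

def attr_int(attrs, key, default=0):
  # "(3, 3)" -> 3, "64" -> 64, missing key -> default
  if key not in attrs:
    return default
  value = ast.literal_eval(attrs[key])
  return value[0] if isinstance(value, tuple) else value

# attr_int({'kernel': '(3, 3)'}, 'kernel') == 3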

Read the JSON file and collect the information needed for each node


import json
import logging

from prototxt_basic import write_node   # per-op prototxt writers, see below

with open(args.mx_json) as json_file:
  jdata = json.load(json_file)

with open(args.cf_prototxt, "w") as prototxt_file:
  for i_node in range(0, len(jdata['nodes'])):
    node_i = jdata['nodes'][i_node]
    # skip auxiliary (parameter) nodes; only operator nodes and the 'data' input are converted
    if str(node_i['op']) == 'null' and str(node_i['name']) != 'data':
      continue
    # information carried by each node: op, name, attrs, inputs
    info = node_i
    info['top'] = info['name']
    info['bottom'] = []
    info['params'] = []

    # walk the inputs of the current node and record bottoms / parameter blobs
    for input_idx_i in node_i['inputs']:
      # input_idx_i[0] is the index of the input node inside jdata['nodes']
      input_i = jdata['nodes'][input_idx_i[0]]
      # operator inputs (and the 'data' node) become Caffe bottoms
      if str(input_i['op']) != 'null' or (str(input_i['name']) == 'data'):
        info['bottom'].append(str(input_i['name']))
      # 'null' inputs are parameters (weights, biases, ...)
      if str(input_i['op']) == 'null':
        info['params'].append(str(input_i['name']))
        # a parameter not prefixed with the node name is shared with another layer
        if not str(input_i['name']).startswith(str(node_i['name'])):
          logging.info(' use shared weight -> %s' % str(input_i['name']))
          info['share'] = True
    write_node(prototxt_file, info)
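
For the Convolution node of the JSON excerpt above (assuming the Activation node is entry 15 and the weight node entry 16 of jdata['nodes']), the loop produces roughly this info dict:

info = {
  'op':     'Convolution',
  'name':   'part_0_stage1_unit1_conv1',
  'attrs':  {'kernel': '(3, 3)', 'no_bias': 'True', 'num_filter': '64',
             'pad': '(1, 1)', 'stride': '(1, 1)', 'workspace': '256'},
  'inputs': [[15, 0, 0], [16, 0, 0]],
  'top':    'part_0_stage1_unit1_conv1',
  'bottom': ['part_0_stage1_unit1_relu1'],
  'params': ['part_0_stage1_unit1_conv1_weight'],
}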
    

Writing the prototxt file

# Convert a Convolution node
def Convolution(txt_file, info):
  if info['attrs']['no_bias'] == 'True':
    bias_term = 'false'
  else:
    bias_term = 'true'  
  txt_file.write('layer {\n')
  txt_file.write(' bottom: "%s"\n'       % info['bottom'][0])
  txt_file.write(' top: "%s"\n'          % info['top'])
  txt_file.write(' name: "%s"\n'         % info['top'])
  txt_file.write(' type: "Convolution"\n')
  txt_file.write(' convolution_param {\n')
  txt_file.write(' num_output: %s\n'   % info['attrs']['num_filter'])
  txt_file.write(' kernel_size: %s\n'  % info['attrs']['kernel'].split('(')[1].split(',')[0]) # TODO
  if 'pad' not in info['attrs']:
    logging.info('Convolution has no pad attr, defaulting pad to 0')
    txt_file.write(' pad: %s\n' % 0)  # TODO
  else:
    txt_file.write(' pad: %s\n'          % info['attrs']['pad'].split('(')[1].split(',')[0]) # TODO
# txt_file.write(' group: %s\n' % info['attrs']['num_group'])
  txt_file.write(' stride: %s\n'       % info['attrs']['stride'].split('(')[1].split(',')[0])
  txt_file.write(' bias_term: %s\n'    % bias_term)
  txt_file.write(' }\n')
  if 'share' in info.keys() and info['share']:  
    txt_file.write(' param {\n')
    txt_file.write(' name: "%s"\n'     % info['params'][0])
    txt_file.write(' }\n')
  txt_file.write('}\n')
  txt_file.write('\n')
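
Applied to the example info dict above, this emits the following prototxt layer (whitespace exactly as the format strings produce it):

layer {
 bottom: "part_0_stage1_unit1_relu1"
 top: "part_0_stage1_unit1_conv1"
 name: "part_0_stage1_unit1_conv1"
 type: "Convolution"
 convolution_param {
 num_output: 64
 kernel_size: 3
 pad: 1
 stride: 1
 bias_term: false
 }
}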

# ------- dispatch each op to its conversion function -----------
# currently covers Conv, Pool, depthwise Conv, BN, Activation, elementwise add, Concat, FC, Reshape, etc.
def write_node(txt_file, info):
    if 'label' in info['name']:
        return        
    if info['op'] == 'null' and info['name'] == 'data':
        data(txt_file, info)
    elif info['op'] == 'Convolution':
        Convolution(txt_file, info)
    elif info['op'] == 'ChannelwiseConvolution':
        ChannelwiseConvolution(txt_file, info)
    elif info['op'] == 'BatchNorm':
        BatchNorm(txt_file, info)
    elif info['op'] == 'Activation':
        Activation(txt_file, info)
# elif info['op'] == 'ElementWiseSum':
    elif info['op'] == 'elemwise_add':
        ElementWiseSum(txt_file, info)
    elif info['op'] == '_Plus':
        ElementWiseSum(txt_file, info)
    elif info['op'] == 'Concat':
        Concat(txt_file, info)
    elif info['op'] == 'Pooling':
# Pooling(txt_file, info)
        Pooling_global(txt_file, info)
    elif info['op'] == 'Flatten':
        Flatten(txt_file, info)
    elif info['op'] == 'FullyConnected':
        FullyConnected(txt_file, info)
    elif info['op'] == 'SoftmaxOutput':
        SoftmaxOutput(txt_file, info)
    elif info['op'] == 'Cast':
        Cast(txt_file, info)
    elif info['op'] == 'SliceChannel':
        SliceChannel(txt_file, info)
    elif info['op'] == 'L2Normalization':
        L2Normalization(txt_file, info)
    elif info['op'] == 'Reshape':
        Reshape(txt_file, info)
    elif info['op'] == 'broadcast_mul':
        broadcast_mul(txt_file, info)
    else:
        logging.warning("Unknown mxnet op: %s" % info['op'])

Build the network through Caffe's Python interface and migrate the MXNet parameters

  1. mxnet2caffe.py: read the MXNet checkpoint parameters and convert them to a .caffemodel

If the model contains ops that Caffe does not support, you have to add custom layers to Caffe yourself, otherwise building the network will fail with an error. This project adds a broadcast_mul layer; adding custom layers to Caffe is widely documented, so it is not covered here (a rough sketch follows below).
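
Purely as an illustration (this is not the project's implementation, which adds its own layer to Caffe), a broadcast multiply can be prototyped as a pycaffe Python layer; it relies on NumPy broadcasting an (N, C, 1, 1) SE gate over an (N, C, H, W) feature map:

import caffe

class BroadcastMulLayer(caffe.Layer):
  """out = bottom[0] * bottom[1], e.g. (N, C, H, W) * (N, C, 1, 1)."""
  def setup(self, bottom, top):
    if len(bottom) != 2:
      raise Exception('BroadcastMul expects exactly two bottoms')
  def reshape(self, bottom, top):
    top[0].reshape(*bottom[0].data.shape)
  def forward(self, bottom, top):
    top[0].data[...] = bottom[0].data * bottom[1].data
  def backward(self, top, propagate_down, bottom):
    pass  # inference-only sketch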

Use MXNet's mx.model.load_checkpoint to load all parameter dicts from the params file:

import sys

import mxnet as mx

try:
    import caffe
except ImportError:
    sys.path.append("/home/***/codes/mx2caffe/caffe/python/")
    import caffe

# load the full parameter dicts from the checkpoint
_, arg_params, aux_params = mx.model.load_checkpoint(args.mx_model, args.mx_epoch)
all_keys = list(arg_params.keys()) + list(aux_params.keys())
# build the network from the freshly converted prototxt via Caffe's Python interface
net = caffe.Net(args.cf_prototxt, caffe.TRAIN)

for i_key, key_i in enumerate(all_keys):
  try:
    if key_i == 'data':
      pass
    # MXNet parameter keys carry suffixes (_weight, _bias, ...) that the Caffe layer names do not;
    # strip them and rely on Caffe's blob order: params[name][0] is the weight, [1] the bias
    # (for other layer types check the corresponding proto definition)
    elif '_weight' in key_i:
      key_caffe = key_i.replace('_weight', '')
      net.params[key_caffe][0].data.flat = arg_params[key_i].asnumpy().flat
    elif '_bias' in key_i:
      key_caffe = key_i.replace('_bias', '')
      net.params[key_caffe][1].data.flat = arg_params[key_i].asnumpy().flat
    # BatchNorm: gamma/beta go to the '<name>_scale' Scale layer, moving mean/var to the BatchNorm layer
    elif '_gamma' in key_i:
      key_caffe = key_i.replace('_gamma', '_scale')
      net.params[key_caffe][0].data.flat = arg_params[key_i].asnumpy().flat
    elif '_beta' in key_i:
      key_caffe = key_i.replace('_beta', '_scale')
      net.params[key_caffe][1].data.flat = arg_params[key_i].asnumpy().flat
    elif '_moving_mean' in key_i:
      key_caffe = key_i.replace('_moving_mean', '')
      net.params[key_caffe][0].data.flat = aux_params[key_i].asnumpy().flat
      net.params[key_caffe][2].data[...] = 1   # moving-average scale factor blob
    elif '_moving_var' in key_i:
      key_caffe = key_i.replace('_moving_var', '')
      net.params[key_caffe][1].data.flat = aux_params[key_i].asnumpy().flat
      net.params[key_caffe][2].data[...] = 1   # moving-average scale factor blob
    else:
      sys.exit("Warning! Unknown mxnet key: {}".format(key_i))

    print("% 3d | %s -> %s, initialized."
          % (i_key, key_i.ljust(40), key_caffe.ljust(30)))

  except KeyError:
    print("\nWarning! key error mxnet:{}".format(key_i))
      
# ------------------------------------------
# Finish
net.save(args.cf_model)
print("\n- Finished.\n")

Verify the conversion by comparing outputs

  1. mxnet_test.py: dump the MXNet output so it can be compared with the converted caffemodel

Run the model in MXNet, dump the values you want to compare, and write out the result of a chosen layer:

import time
from collections import namedtuple

import mxnet as mx
import numpy as np

# data-batch wrapper expected by Module.forward()
Batch = namedtuple('Batch', ['data'])

def load_checkpoint_single(model, param_path):
    arg_params = {}
    aux_params = {}
    save_dict = mx.nd.load(param_path)
    for k, value in save_dict.items():
        arg_type, name = k.split(':', 1)
        if arg_type == 'arg':
            arg_params[name] = value
        elif arg_type == 'aux':
            aux_params[name] = value
    model.set_params(arg_params, aux_params, allow_missing=False)
    arg_params, aux_params = model.get_params()
    return arg_params, aux_params

full_param_path = 'se_resnet34/base-0000.params'
fmodel = mx.sym.load('se_resnet34/base-symbol.json')
# get all internal symbols (layers) of the MXNet network
all_layers = fmodel.get_internals()

# pick the layer to dump by using '<layer_name>_output' here
fmodel = all_layers['flat_output']
fullmodel = mx.mod.Module(symbol=fmodel, data_names=['data'], label_names=[])

# get_image_gray is a project helper that loads and preprocesses a grayscale image
img = get_image_gray('before_forward.jpg')
fullmodel.bind(data_shapes=[('data', (1, 1, 108, 108))], label_shapes=None, for_training=False, force_rebind=False)

arg_params, aux_params = load_checkpoint_single(fullmodel, full_param_path)
fullmodel.set_params(arg_params, aux_params)

file1 = open('se_resnet34.txt', 'w')
tic = time.time()

fullmodel.forward(Batch([mx.nd.array(img)]))

prob = fullmodel.get_outputs()[0].asnumpy()
prob = prob.astype(np.float64)
prob = prob.reshape(-1, 1)
# save the result as fixed-precision text for the comparison
np.savetxt(file1, prob, fmt='%.12f')

file1.close()

Then load the converted network with Caffe, run the same input, print its output, and compare the numerical accuracy (a sketch of the Caffe side follows). If the results differ, debug layer by layer; this project has been tested and works on the SENet network.
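
A minimal sketch of that comparison, assuming the prototxt/caffemodel file names and that the dumped MXNet layer 'flat_output' corresponds to a Caffe top named 'flat':

import numpy as np
import caffe

net = caffe.Net('se_resnet34.prototxt', 'se_resnet34.caffemodel', caffe.TEST)
net.blobs['data'].reshape(1, 1, 108, 108)
net.blobs['data'].data[...] = img            # same preprocessed image as in mxnet_test.py
net.forward()

out = net.blobs['flat'].data.astype(np.float64).reshape(-1, 1)
ref = np.loadtxt('se_resnet34.txt').reshape(-1, 1)
print('max abs diff: %.6e' % np.abs(out - ref).max())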

Project repository

https://github.com/junqiangwu/Mxnet2Caffe-Tensor-RT-SEnet

TODO:

  • add caffe_plugin_layer
  • TensorRT loads the converted caffemodel
  • TensorRT support for Se_Resnet