代码地址:https://github.com/XiaoMi/mace
NEON优化常见命令:https://blog.csdn.net/fuwenyan/article/details/78811034
UNROLL优化:https://blog.csdn.net/u013625961/article/details/62422097
MACE 1 * N卷积和N * 1卷积:https://blog.csdn.net/XiaoHeiBlack/article/details/81987161

1*7 卷积源码解读

#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif

#include "mace/kernels/arm/conv_2d_neon.h"

namespace mace {
namespace kernels {

// Ho = 1, Wo = 4, Co = 4
// 1x7 convolution kernel (the name suggests stride 1; the input and output
// column offsets below advance together, one input column per output column).
//
// input    - input tensor
// filter   - convolution weights, addressed as
//            [out_channel][in_channel][7 taps]
// in_shape - dimensions of the input tensor
// out_shape- dimensions of the output tensor
// output   - output tensor; results are ACCUMULATED into it (load, multiply-add,
//            store), so its initial contents are part of the result
// Tensors are indexed as batch x channels x image height x image width.
//
// NOTE(review): the NEON inner loops only cover columns with
// w + 3 < out_width and load 12 input floats starting at in_offset;
// presumably callers pad the shapes so that out_width is a multiple of 4 and
// those reads stay inside the input buffer - confirm against the caller.

void Conv2dNeonK1x7S1(const float *input,
                      const float *filter,
                      const index_t *in_shape,
                      const index_t *out_shape,
                      float *output) {
  const index_t in_image_size = in_shape[2] * in_shape[3]; // elements in one input channel plane
  const index_t out_image_size = out_shape[2] * out_shape[3];
  const index_t in_batch_size = in_shape[1] * in_image_size; // elements in one input batch item
  const index_t out_batch_size = out_shape[1] * out_image_size;

#pragma omp parallel for collapse(2)
  for (index_t b = 0; b < out_shape[0]; ++b) { // one batch item per step
    for (index_t m = 0; m < out_shape[1]; m += 4) { // four output channels per step
      const index_t out_channels = out_shape[1]; // number of output channels
      const index_t out_height = out_shape[2]; // output height
      const index_t out_width = out_shape[3];  // output width
      const index_t in_channels = in_shape[1]; // number of input channels
      const index_t in_width = in_shape[3]; // input width
      if (m + 3 < out_channels) { // a full group of 4 output channels is available
        float *out_ptr0_base = output + b * out_batch_size + m * out_image_size; // plane of batch b, output channel m
#if defined(MACE_ENABLE_NEON)  
        float *out_ptr1_base =
            output + b * out_batch_size + (m + 1) * out_image_size; // planes of the other 3 channels in the group
        float *out_ptr2_base =
            output + b * out_batch_size + (m + 2) * out_image_size;
        float *out_ptr3_base =
            output + b * out_batch_size + (m + 3) * out_image_size;
#endif
        for (index_t c = 0; c < in_channels; ++c) {// one input channel per step
          const float *in_ptr_base =
              input + b * in_batch_size + c * in_image_size;
          const float *filter_ptr0 = filter + m * in_channels * 7 + c * 7; // 7 = kernel height * kernel width = 1 * 7;
          // in_channels * 7 is the number of weights belonging to one output channel
#if defined(MACE_ENABLE_NEON)
          const float *filter_ptr1 = filter + (m + 1) * in_channels * 7 + c * 7;
          const float *filter_ptr2 = filter + (m + 2) * in_channels * 7 + c * 7;
          const float *filter_ptr3 = filter + (m + 3) * in_channels * 7 + c * 7;// four filter pointers, one per output channel
          // of the group (each output channel has its own set of weights)
          /* load filter (4 outch x 1 height x 4 width) */
          // Each (output channel, input channel) pair has 7 taps. They are
          // held in two overlapping 4-float vectors: vfX0 = taps 0..3 and
          // vfX1 = taps 3..6, so lanes 1..3 of vfX1 are taps 4..6.
          float32x4_t vf00, vf01;
          float32x4_t vf10, vf11;
          float32x4_t vf20, vf21;
          float32x4_t vf30, vf31;
          vf00 = vld1q_f32(filter_ptr0); // load 4 consecutive floats into a vector
          vf01 = vld1q_f32(filter_ptr0 + 3); 
          vf10 = vld1q_f32(filter_ptr1);
          vf11 = vld1q_f32(filter_ptr1 + 3);
          vf20 = vld1q_f32(filter_ptr2);
          vf21 = vld1q_f32(filter_ptr2 + 3);
          vf30 = vld1q_f32(filter_ptr3);
          vf31 = vld1q_f32(filter_ptr3 + 3);

          for (index_t h = 0; h < out_height; ++h) { // one output row per step
            for (index_t w = 0; w + 3 < out_width; w += 4) { // four output columns per step
              // output (4 outch x 1 height x 4 width): vo_outch_height
              float32x4_t vo0, vo1, vo2, vo3;
              // load output
              index_t out_offset = h * out_width + w;
              vo0 = vld1q_f32(out_ptr0_base + out_offset);
              vo1 = vld1q_f32(out_ptr1_base + out_offset);
              vo2 = vld1q_f32(out_ptr2_base + out_offset);
              vo3 = vld1q_f32(out_ptr3_base + out_offset);

              // input: vi0..vi6 are the seven windows shifted by 0..6 columns
              float32x4_t vi0, vi1, vi2, vi3, vi4, vi5, vi6, vi8;
              // input offset
              index_t in_offset = h * in_width + w;
              // load input
              // Three vld1q_f32 loads fetch 12 consecutive input floats;
              // vextq_f32 then splices neighbouring vectors to build the
              // five remaining shifted windows.
              vi0 = vld1q_f32(in_ptr_base + in_offset);
              vi4 = vld1q_f32(in_ptr_base + in_offset + 4);
              vi8 = vld1q_f32(in_ptr_base + in_offset + 8);
              vi1 = vextq_f32(vi0, vi4, 1);
              vi2 = vextq_f32(vi0, vi4, 2);
              vi3 = vextq_f32(vi0, vi4, 3);
              vi5 = vextq_f32(vi4, vi8, 1);
              vi6 = vextq_f32(vi4, vi8, 2); 
// AArch64 is the 64-bit execution state of the ARMv8-A architecture. It
// identifies an instruction set, not a particular CPU model. On AArch64 the
// lane-indexed fused multiply-add intrinsic is used; otherwise the
// multiply-accumulate intrinsic on 64-bit half vectors is used.
#if defined(__aarch64__)
              /* outch 0 */
              vo0 = vfmaq_laneq_f32(vo0, vi0, vf00, 0);
              vo0 = vfmaq_laneq_f32(vo0, vi1, vf00, 1);
              vo0 = vfmaq_laneq_f32(vo0, vi2, vf00, 2);
              vo0 = vfmaq_laneq_f32(vo0, vi3, vf00, 3);
              vo0 = vfmaq_laneq_f32(vo0, vi4, vf01, 1);
              vo0 = vfmaq_laneq_f32(vo0, vi5, vf01, 2);
              vo0 = vfmaq_laneq_f32(vo0, vi6, vf01, 3);
              /* outch 1 */
              vo1 = vfmaq_laneq_f32(vo1, vi0, vf10, 0);
              vo1 = vfmaq_laneq_f32(vo1, vi1, vf10, 1);
              vo1 = vfmaq_laneq_f32(vo1, vi2, vf10, 2);
              vo1 = vfmaq_laneq_f32(vo1, vi3, vf10, 3);
              vo1 = vfmaq_laneq_f32(vo1, vi4, vf11, 1);
              vo1 = vfmaq_laneq_f32(vo1, vi5, vf11, 2);
              vo1 = vfmaq_laneq_f32(vo1, vi6, vf11, 3);
              /* outch 2 */
              vo2 = vfmaq_laneq_f32(vo2, vi0, vf20, 0);
              vo2 = vfmaq_laneq_f32(vo2, vi1, vf20, 1);
              vo2 = vfmaq_laneq_f32(vo2, vi2, vf20, 2);
              vo2 = vfmaq_laneq_f32(vo2, vi3, vf20, 3);
              vo2 = vfmaq_laneq_f32(vo2, vi4, vf21, 1);
              vo2 = vfmaq_laneq_f32(vo2, vi5, vf21, 2);
              vo2 = vfmaq_laneq_f32(vo2, vi6, vf21, 3);
              /* outch 3 */
              vo3 = vfmaq_laneq_f32(vo3, vi0, vf30, 0);
              vo3 = vfmaq_laneq_f32(vo3, vi1, vf30, 1);
              vo3 = vfmaq_laneq_f32(vo3, vi2, vf30, 2);
              vo3 = vfmaq_laneq_f32(vo3, vi3, vf30, 3);
              vo3 = vfmaq_laneq_f32(vo3, vi4, vf31, 1);
              vo3 = vfmaq_laneq_f32(vo3, vi5, vf31, 2);
              vo3 = vfmaq_laneq_f32(vo3, vi6, vf31, 3);
#else
              /* outch 0 */
              // vmlaq_lane_f32(a, b, c, i) is a multiply-accumulate:
              // a + b * c[i], where c[i] is a single scalar lane.
              vo0 = vmlaq_lane_f32(vo0, vi0, vget_low_f32(vf00), 0);
              vo0 = vmlaq_lane_f32(vo0, vi1, vget_low_f32(vf00), 1);
              vo0 = vmlaq_lane_f32(vo0, vi2, vget_high_f32(vf00), 0);
              vo0 = vmlaq_lane_f32(vo0, vi3, vget_high_f32(vf00), 1);
              vo0 = vmlaq_lane_f32(vo0, vi4, vget_low_f32(vf01), 1);
              vo0 = vmlaq_lane_f32(vo0, vi5, vget_high_f32(vf01), 0);
              vo0 = vmlaq_lane_f32(vo0, vi6, vget_high_f32(vf01), 1);
              /* outch 1 */
              vo1 = vmlaq_lane_f32(vo1, vi0, vget_low_f32(vf10), 0);
              vo1 = vmlaq_lane_f32(vo1, vi1, vget_low_f32(vf10), 1);
              vo1 = vmlaq_lane_f32(vo1, vi2, vget_high_f32(vf10), 0);
              vo1 = vmlaq_lane_f32(vo1, vi3, vget_high_f32(vf10), 1);
              vo1 = vmlaq_lane_f32(vo1, vi4, vget_low_f32(vf11), 1);
              vo1 = vmlaq_lane_f32(vo1, vi5, vget_high_f32(vf11), 0);
              vo1 = vmlaq_lane_f32(vo1, vi6, vget_high_f32(vf11), 1);
              /* outch 2 */
              vo2 = vmlaq_lane_f32(vo2, vi0, vget_low_f32(vf20), 0);
              vo2 = vmlaq_lane_f32(vo2, vi1, vget_low_f32(vf20), 1);
              vo2 = vmlaq_lane_f32(vo2, vi2, vget_high_f32(vf20), 0);
              vo2 = vmlaq_lane_f32(vo2, vi3, vget_high_f32(vf20), 1);
              vo2 = vmlaq_lane_f32(vo2, vi4, vget_low_f32(vf21), 1);
              vo2 = vmlaq_lane_f32(vo2, vi5, vget_high_f32(vf21), 0);
              vo2 = vmlaq_lane_f32(vo2, vi6, vget_high_f32(vf21), 1);
              /* outch 3 */
              vo3 = vmlaq_lane_f32(vo3, vi0, vget_low_f32(vf30), 0);
              vo3 = vmlaq_lane_f32(vo3, vi1, vget_low_f32(vf30), 1);
              vo3 = vmlaq_lane_f32(vo3, vi2, vget_high_f32(vf30), 0);
              vo3 = vmlaq_lane_f32(vo3, vi3, vget_high_f32(vf30), 1);
              vo3 = vmlaq_lane_f32(vo3, vi4, vget_low_f32(vf31), 1);
              vo3 = vmlaq_lane_f32(vo3, vi5, vget_high_f32(vf31), 0);
              vo3 = vmlaq_lane_f32(vo3, vi6, vget_high_f32(vf31), 1);
#endif
              // vst1q_f32 writes the 4 results back to the output buffer.
              // On the next iteration of the in_channels loop the same
              // memory is loaded again and accumulated further, so this also
              // performs the sum over input channels. No bias or activation
              // is applied in this function.
              vst1q_f32(out_ptr0_base + out_offset, vo0);
              vst1q_f32(out_ptr1_base + out_offset, vo1);
              vst1q_f32(out_ptr2_base + out_offset, vo2);
              vst1q_f32(out_ptr3_base + out_offset, vo3);
            }  // w
          }    // h
#else
          for (index_t oc = 0; oc < 4; ++oc) {
            Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0 + oc * in_channels * 7,
                               in_width, 1, 7, out_height, out_width,
                               out_ptr0_base + oc * out_image_size, 1);
          }
#endif
        }  // c
      } else { // leftover output channels (fewer than 4 remain), one at a time
        for (index_t mm = m; mm < out_channels; ++mm) {
          float *out_ptr0_base =
              output + b * out_batch_size + mm * out_image_size;
          for (index_t c = 0; c < in_channels; ++c) {
            const float *in_ptr_base =
                input + b * in_batch_size + c * in_image_size;
            const float *filter_ptr0 = filter + mm * in_channels * 7 + c * 7;
#if defined(MACE_ENABLE_NEON)
            /* load filter (1 outch x 1 height x 4 width) */
            float32x4_t vf00, vf01;
            vf00 = vld1q_f32(filter_ptr0);
            vf01 = vld1q_f32(filter_ptr0 + 3);

            for (index_t h = 0; h < out_height; ++h) {
              for (index_t w = 0; w + 3 < out_width; w += 4) {
                // output (1 outch x 1 height x 4 width): vo_outch_height
                float32x4_t vo0;
                // load output
                index_t out_offset = h * out_width + w;
                vo0 = vld1q_f32(out_ptr0_base + out_offset);

                // input: vi0..vi6 are the seven windows shifted by 0..6 columns
                float32x4_t vi0, vi1, vi2, vi3, vi4, vi5, vi6, vi8;
                // input offset
                index_t in_offset = h * in_width + w;
                // load input
                vi0 = vld1q_f32(in_ptr_base + in_offset);
                vi4 = vld1q_f32(in_ptr_base + in_offset + 4);
                vi8 = vld1q_f32(in_ptr_base + in_offset + 8);
                vi1 = vextq_f32(vi0, vi4, 1);
                vi2 = vextq_f32(vi0, vi4, 2);
                vi3 = vextq_f32(vi0, vi4, 3);
                vi5 = vextq_f32(vi4, vi8, 1);
                vi6 = vextq_f32(vi4, vi8, 2);

#if defined(__aarch64__)
                vo0 = vfmaq_laneq_f32(vo0, vi0, vf00, 0);
                vo0 = vfmaq_laneq_f32(vo0, vi1, vf00, 1);
                vo0 = vfmaq_laneq_f32(vo0, vi2, vf00, 2);
                vo0 = vfmaq_laneq_f32(vo0, vi3, vf00, 3);
                vo0 = vfmaq_laneq_f32(vo0, vi4, vf01, 1);
                vo0 = vfmaq_laneq_f32(vo0, vi5, vf01, 2);
                vo0 = vfmaq_laneq_f32(vo0, vi6, vf01, 3);
#else
                vo0 = vmlaq_lane_f32(vo0, vi0, vget_low_f32(vf00), 0);
                vo0 = vmlaq_lane_f32(vo0, vi1, vget_low_f32(vf00), 1);
                vo0 = vmlaq_lane_f32(vo0, vi2, vget_high_f32(vf00), 0);
                vo0 = vmlaq_lane_f32(vo0, vi3, vget_high_f32(vf00), 1);
                vo0 = vmlaq_lane_f32(vo0, vi4, vget_low_f32(vf01), 1);
                vo0 = vmlaq_lane_f32(vo0, vi5, vget_high_f32(vf01), 0);
                vo0 = vmlaq_lane_f32(vo0, vi6, vget_high_f32(vf01), 1);
#endif

                vst1q_f32(out_ptr0_base + out_offset, vo0);
              }  // w
            }    // h
#else
            Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0, in_width, 1, 7,
                               out_height, out_width, out_ptr0_base, 1);
#endif
          }  // c
        }
      }  // if
    }    // m
  }      // b
}

}  // namespace kernels
}  // namespace mace

1*1 卷积源码阅读

1*1卷积的思路就是对于每一个batch,调用一次gemm矩阵乘法运算。这里的关键点在于卷积如何变成了矩阵乘法,请看注释,这里举个例子:C1=2,C2=3,W=2,H=3,

可以看到,这两个矩阵做完gemm乘法之后就得到了多通道1*1卷积的结果,那么是不是需要把H*W这个行向量reshape一下,成为3维的呢?实际上不需要,因为内存排布没变,所以直接这样访问,省去reshape还可以降低复杂度。

#include "mace/kernels/arm/conv_2d_neon.h"
//思路是在每一个batch中调用了gemm矩阵乘法运算
//假设输入通道数为C1,输出通道数为C2。则一般卷积核参数为C1xC2xkhxkw,
//因此卷积核大小为1*1时,卷积核就从四维变成了两维矩阵K,大小为C1*C2。
//在单batch下,假设上一次输入数据大小为 C1*H*W,把它reshape成一个C1*(H*W)
//的矩阵F,这样多通道分别卷积再求和的过程就可以用这两个矩阵乘积来表示:
//$Z = K^t * F$得到了大小为C2*(H*W)的矩阵Z。其实就是单通道的卷积运算
//退化成了一个矩阵和一个标量的点乘运算了。

namespace mace {
namespace kernels {

void Conv2dNeonK1x1S1(const float *input,
                      const float *filter, //卷积核
                      const index_t batch, //batch的个数
                      const index_t height, 
                      const index_t width,
                      const index_t in_channels,
                      const index_t out_channels,
                      float *output,
                      SGemm *sgemm,
                      ScratchBuffer *scratch_buffer) {
  for (index_t b = 0; b < batch; ++b) {
    sgemm->Run(filter,
               input + b * in_channels * height * width,
               1,
               out_channels,
               in_channels,
               in_channels,
               height * width,
               false,
               false,
               true,
               false,
               output + b * out_channels * height * width,
               scratch_buffer);
  }
}

}  // namespace kernels
}  // namespace mace

Mace把大矩阵的运算分成2级的矩阵分块乘法,第一级的实现名字都是GemmXYZ这种形式,表示大小为[X,Y]和[Y,Z]的矩阵相乘,主要的NEON优化也是在这些函数中。这一级的矩阵计算大小都很小,最大为Gemm884,所以多数情况下变量可以保持在寄存器中,避免寄存器溢出到栈上带来的时间开销。这一级的分块矩阵乘法叫register tiling。第二级优化则是把若干register tiling组成一个block,保证一个block内的内存开销(2个输入矩阵+1个输出矩阵)在L1 cache的大小范围内,提高cache命中率,称为cache tiling。

register tiling

// Register-tile kernel for A (1x4) times B (4x4), accumulated into C (1x4).
// A fits in one 4-float vector and B in four (one per row, rows stride_b
// floats apart); the multiply-accumulate macro folds them into the single
// result vector. Larger kernels such as Gemm884 load a second vector per
// row of A.
inline void Gemm144(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
  // A and C have a single row, so their row strides are not needed here.
  MACE_UNUSED(stride_a);
  MACE_UNUSED(stride_c);

  // The variable names a0, b0..b3 and c0 are required by
  // MACE_GEMM_PART_CAL_4, which builds them by token pasting.
  const float32x4_t a0 = vld1q_f32(a_ptr);  // the 4 floats of A's only row
  const float32x4_t b0 = vld1q_f32(b_ptr);
  const float32x4_t b1 = vld1q_f32(b_ptr + 1 * stride_b);
  const float32x4_t b2 = vld1q_f32(b_ptr + 2 * stride_b);
  const float32x4_t b3 = vld1q_f32(b_ptr + 3 * stride_b);
  float32x4_t c0 = vld1q_f32(c_ptr);        // current contents of C's row

  // The argument 0 is the index of the (only) row being computed.
  MACE_GEMM_PART_CAL_4(0);

  // Write the 4 accumulated floats back to C.
  vst1q_f32(c_ptr, c0);
#else
  GemmBlock(a_ptr, b_ptr, 1, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}

其中调用了一个MACE_GEMM_PART_CAL_4(0)宏,这个宏有什么用呢?先来看一下它的定义长啥样?

// Intrinsics used below:
//   vget_low_<type>  - returns the low half of a 128-bit vector as a 64-bit
//                      vector of the same element type.
//   vget_high_<type> - returns the high half of a 128-bit vector as a 64-bit
//                      vector of the same element type.
//   vget_lane_<type> - returns one selected element of a vector.
//   vmlaq_lane_f32(acc, v, s, i) - acc + v * s[i]: every lane of v is
//                      multiplied by the single scalar s[i] and added to acc.
//
// MACE_GEMM_PART_CAL_4(RC): one output row of a 4-deep product,
//   c<RC> += a<RC>[0] * b0 + a<RC>[1] * b1 + a<RC>[2] * b2 + a<RC>[3] * b3
// The expansion site must have float32x4_t variables named c<RC>, a<RC> and
// b0..b3 in scope; the names are built by token pasting.
#define MACE_GEMM_PART_CAL_4(RC)                              \
  c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RC), 0);  \
  c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RC), 1);  \
  c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RC), 0); \
  c##RC = vmlaq_lane_f32(c##RC, b3, vget_high_f32(a##RC), 1);

同理,还有一个叫做MACE_GEMM_PART_CAL_8(RC, RA, RAN)的宏,实现如下:

// MACE_GEMM_PART_CAL_8(RC, RA, RAN): one output row of an 8-deep product.
// a<RA> supplies the first four scalars of the row of A and a<RAN> the next
// four; they scale b0..b3 and b4..b7 respectively:
//   c<RC> += a<RA>[0] * b0 + ... + a<RA>[3] * b3
//          + a<RAN>[0] * b4 + ... + a<RAN>[3] * b7
// The expansion site must have float32x4_t variables named c<RC>, a<RA>,
// a<RAN> and b0..b7 in scope; the names are built by token pasting.
#define MACE_GEMM_PART_CAL_8(RC, RA, RAN)                      \
  c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RA), 0);   \
  c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RA), 1);   \
  c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RA), 0);  \
  c##RC = vmlaq_lane_f32(c##RC, b3, vget_high_f32(a##RA), 1);  \
  c##RC = vmlaq_lane_f32(c##RC, b4, vget_low_f32(a##RAN), 0);  \
  c##RC = vmlaq_lane_f32(c##RC, b5, vget_low_f32(a##RAN), 1);  \
  c##RC = vmlaq_lane_f32(c##RC, b6, vget_high_f32(a##RAN), 0); \
  c##RC = vmlaq_lane_f32(c##RC, b7, vget_high_f32(a##RAN), 1);

这个和CAL_4十分类似,就不解释了。我们再来看一段Gemm884的源码:这段源码解决的是8*8的A矩阵和8*4的B矩阵相乘,最后得到8*4的C矩阵的问题。

// Register-tile kernel: C (8x4) += A (8x8) * B (8x4).
// Rows of A, B and C are stride_a, stride_b and stride_c floats apart.
// Without MACE_ENABLE_NEON the work is delegated to GemmBlock.
inline void Gemm884(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
  // The names a0..a15, b0..b7 and c0..c7 are required by
  // MACE_GEMM_PART_CAL_8, which builds them by token pasting.

  // A: 16 vectors, two per row (columns 0..3 and columns 4..7).
  const float32x4_t a0 = vld1q_f32(a_ptr);
  const float32x4_t a1 = vld1q_f32(a_ptr + 4);
  const float32x4_t a2 = vld1q_f32(a_ptr + 1 * stride_a);
  const float32x4_t a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
  const float32x4_t a4 = vld1q_f32(a_ptr + 2 * stride_a);
  const float32x4_t a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
  const float32x4_t a6 = vld1q_f32(a_ptr + 3 * stride_a);
  const float32x4_t a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
  const float32x4_t a8 = vld1q_f32(a_ptr + 4 * stride_a);
  const float32x4_t a9 = vld1q_f32(a_ptr + 4 * stride_a + 4);
  const float32x4_t a10 = vld1q_f32(a_ptr + 5 * stride_a);
  const float32x4_t a11 = vld1q_f32(a_ptr + 5 * stride_a + 4);
  const float32x4_t a12 = vld1q_f32(a_ptr + 6 * stride_a);
  const float32x4_t a13 = vld1q_f32(a_ptr + 6 * stride_a + 4);
  const float32x4_t a14 = vld1q_f32(a_ptr + 7 * stride_a);
  const float32x4_t a15 = vld1q_f32(a_ptr + 7 * stride_a + 4);

  // B: 8 vectors, one per row.
  const float32x4_t b0 = vld1q_f32(b_ptr);
  const float32x4_t b1 = vld1q_f32(b_ptr + 1 * stride_b);
  const float32x4_t b2 = vld1q_f32(b_ptr + 2 * stride_b);
  const float32x4_t b3 = vld1q_f32(b_ptr + 3 * stride_b);
  const float32x4_t b4 = vld1q_f32(b_ptr + 4 * stride_b);
  const float32x4_t b5 = vld1q_f32(b_ptr + 5 * stride_b);
  const float32x4_t b6 = vld1q_f32(b_ptr + 6 * stride_b);
  const float32x4_t b7 = vld1q_f32(b_ptr + 7 * stride_b);

  // C: 8 accumulators, one per output row, seeded with the current contents.
  float32x4_t c0 = vld1q_f32(c_ptr);
  float32x4_t c1 = vld1q_f32(c_ptr + 1 * stride_c);
  float32x4_t c2 = vld1q_f32(c_ptr + 2 * stride_c);
  float32x4_t c3 = vld1q_f32(c_ptr + 3 * stride_c);
  float32x4_t c4 = vld1q_f32(c_ptr + 4 * stride_c);
  float32x4_t c5 = vld1q_f32(c_ptr + 5 * stride_c);
  float32x4_t c6 = vld1q_f32(c_ptr + 6 * stride_c);
  float32x4_t c7 = vld1q_f32(c_ptr + 7 * stride_c);

  // Row r of C uses the vector pair (a[2r], a[2r + 1]) of A.
  MACE_GEMM_PART_CAL_8(0, 0, 1);
  MACE_GEMM_PART_CAL_8(1, 2, 3);
  MACE_GEMM_PART_CAL_8(2, 4, 5);
  MACE_GEMM_PART_CAL_8(3, 6, 7);
  MACE_GEMM_PART_CAL_8(4, 8, 9);
  MACE_GEMM_PART_CAL_8(5, 10, 11);
  MACE_GEMM_PART_CAL_8(6, 12, 13);
  MACE_GEMM_PART_CAL_8(7, 14, 15);

  // Store each accumulator back to its own row of C (4 floats per row).
  vst1q_f32(c_ptr, c0);
  vst1q_f32(c_ptr + 1 * stride_c, c1);
  vst1q_f32(c_ptr + 2 * stride_c, c2);
  vst1q_f32(c_ptr + 3 * stride_c, c3);
  vst1q_f32(c_ptr + 4 * stride_c, c4);
  vst1q_f32(c_ptr + 5 * stride_c, c5);
  vst1q_f32(c_ptr + 6 * stride_c, c6);
  vst1q_f32(c_ptr + 7 * stride_c, c7);
#else
  GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}

这就是第一级矩阵乘法的构成,很多个GemmXYZ,而它们的调用就是在cache tiling中。先来看一个函数:
# cache tiling

// Dispatches to the register-tile kernel Gemm<row>84 that matches the number
// of rows of A to process (1..8). Any other row count hits
// MACE_NOT_IMPLEMENTED.
inline void GemmX84(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr,
                    int row) {
  switch (row) {
    case 1: {
      Gemm184(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      return;
    }
    case 2: {
      Gemm284(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      return;
    }
    case 3: {
      Gemm384(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      return;
    }
    case 4: {
      Gemm484(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      return;
    }
    case 5: {
      Gemm584(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      return;
    }
    case 6: {
      Gemm684(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      return;
    }
    case 7: {
      Gemm784(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      return;
    }
    case 8: {
      Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      return;
    }
    default:
      break;
  }
  // Only reached for a row count outside 1..8.
  MACE_NOT_IMPLEMENTED;
}

然后我们来看GemmTile函数:
这个地方为了突出GemmTile的主体逻辑代码,将aarch64和clang宏控制的部分代码删除。可以看出这个函数里面有一些边界处理,实际上不看边界的话,并且把Gemm看成单个元素的话,这里的3层for循环就相当于矩阵乘法的3层for循环。

inline void GemmTile(const float *A,
                     const float *B,
                     const index_t height,
                     const index_t K,
                     const index_t width,
                     const index_t stride_a,
                     const index_t stride_b,
                     const index_t stride_c,
                     float *C) {
  index_t h = 0;
  index_t w = 0;
  index_t k = 0;
  int reg_height_tile = 6;
  int reg_K_tile = 4;

  for (h = 0; h < height - reg_height_tile + 1; h += reg_height_tile) {
    for (k = 0; k < K - reg_K_tile + 1; k += reg_K_tile) {
      const float *a_ptr = A + (h * stride_a + k);
      for (w = 0; w + 3 < width; w += 4) {
        const float *b_ptr = B + (k * stride_b + w);
        float *c_ptr = C + (h * stride_c + w);
        Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      }
      if (w < width) {
          const float *b_ptr = B + (k * stride_b + w);
          float *c_ptr = C + (h * stride_c + w);
          GemmBlock(a_ptr, b_ptr, reg_height_tile, reg_K_tile, width - w,
              stride_a, stride_b, stride_c, c_ptr);
      }
    }
    if (k < K) {
        const float *a_ptr = A + (h * stride_a + k);
        const float *b_ptr = B + k * stride_b;
        float *c_ptr = C + h * stride_c;
        GemmBlock(a_ptr, b_ptr, reg_height_tile, K - k, width, stride_a, stride_b,
            stride_c, c_ptr);
    }
  }
  if (h < height) {
      index_t remain_h = height - h;
      for (k = 0; k < K - reg_K_tile; k += reg_K_tile) {
          const float *a_ptr = A + (h * stride_a + k);
          index_t w;
          for (w = 0; w + 3 < width; w += 4) {
              const float *b_ptr = B + (k * stride_b + w);
              float *c_ptr = C + (h * stride_c + w);
              GemmX44(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr, remain_h);
          }
          if (w < width) {
              const float *b_ptr = B + (k * stride_b + w);
              float *c_ptr = C + (h * stride_c + w);
              GemmBlock(a_ptr, b_ptr, remain_h, reg_K_tile, width - w, stride_a,
                  stride_b, stride_c, c_ptr);
          }
      }
      if (k < K) {
          const float *a_ptr = A + (h * stride_a + k);
          const float *b_ptr = B + k * stride_b;
          float *c_ptr = C + h * stride_c;
          GemmBlock(a_ptr, b_ptr, remain_h, K - k, width, stride_a, stride_b,
              stride_c, c_ptr);
      }
  }
}

最后就是Gemm函数了,这是最高级的封装了,这个函数去掉细枝末节之后长啥样呢?

// Batched matrix multiply: for every batch item n, C[n] = A[n] * B[n].
// A: height x K, B: K x width, C: height x width
//
// A, B, C     - contiguous batches of matrices (batch item n starts at
//               n * height * K, n * K * width and n * height * width)
// transpose_a - A is stored as K x height and is transposed block by block
// transpose_b - B is stored as width x K and is transposed block by block
//
// The matrices are cut into block_size x block_size cache blocks; each block
// product is handed to GemmTile, which accumulates into C.
//
// NOTE(review): the width == 1 path calls Gemv without looking at
// transpose_a / transpose_b, and returns before C is zeroed - presumably Gemv
// overwrites C and callers never request a transpose there; confirm.
void Gemm(const float *A,
          const float *B,
          const index_t batch,
          const index_t height,
          const index_t K,
          const index_t width,
          float *C,
          const bool transpose_a,
          const bool transpose_b) {
  // Matrix-vector special case: one Gemv call per batch item.
  if (width == 1) {
    for (index_t b = 0; b < batch; ++b) {
      Gemv(A + b * height * K, B + b * K, 1, K, height, C + b * height);
    }
    return;
  }
  // GemmTile accumulates (C += A * B), so C has to start at zero.
  memset(C, 0, sizeof(float) * batch * height * width);

  // It is better to use large block size if it fits for fast cache.
  // Assume l1 cache size is 32k, we load three blocks at a time (A, B, C),
  // the block size should be sqrt(32k / sizeof(T) / 3).
  // As number of input channels of convolution is normally power of 2, and
  // we have not optimized tiling remains, we use the following magic number
  const index_t block_size = 64;
  // Number of blocks along each dimension (the last one may be partial).
  const index_t block_tile_height = RoundUpDiv(height, block_size);
  const index_t block_tile_width = RoundUpDiv(width, block_size);
  const index_t block_tile_k = RoundUpDiv(K, block_size);
  const index_t block_tile[3] = {block_tile_height, block_tile_width,
                                 block_tile_k};
  // Extent of the partial last block along each dimension (0 = no partial
  // block, the last block is full).
  const index_t remain_height = height % block_size;
  const index_t remain_width = width % block_size;
  const index_t remain_k = K % block_size;
  const index_t remain[3] = {remain_height, remain_width, remain_k};

  // The three loops below are collapsed into one parallel iteration space, so
  // they must stay perfectly nested; each (n, bh, bw) owns one block of C.
#pragma omp parallel for collapse(3)
  for (index_t n = 0; n < batch; ++n) {
    for (index_t bh = 0; bh < block_tile[0]; ++bh) {
      for (index_t bw = 0; bw < block_tile[1]; ++bw) {
        // Start of batch item n in each matrix.
        const float *a_base = A + n * height * K;
        const float *b_base = B + n * K * width;
        float *c_base = C + n * height * width;

        // Row range [ih_begin, ih_end) of this block; the last block is
        // shortened to the remainder when there is one.
        const index_t ih_begin = bh * block_size;
        const index_t ih_end =
            bh * block_size +
            (bh == block_tile[0] - 1 && remain[0] > 0 ? remain[0] : block_size);
        // Column range [iw_begin, iw_end) of this block.
        const index_t iw_begin = bw * block_size;
        const index_t iw_end =
            bw * block_size +
            (bw == block_tile[1] - 1 && remain[1] > 0 ? remain[1] : block_size);

        // Walk the shared dimension K block by block, accumulating into the
        // same block of C.
        for (index_t bk = 0; bk < block_tile[2]; ++bk) {
          const index_t ik_begin = bk * block_size;
          const index_t ik_end =
              bk * block_size + (bk == block_tile[2] - 1 && remain[2] > 0
                                     ? remain[2]
                                     : block_size);

          // Scratch tensors for the transposed copies; only resized and used
          // when the corresponding transpose flag is set.
          Tensor trans_a;
          Tensor trans_b;
          const float *real_a = nullptr;
          const float *real_b = nullptr;
          float *real_c = c_base + (ih_begin * width + iw_begin);
          index_t stride_a;
          index_t stride_b;
          index_t stride_c = width;

          if (transpose_a) {
            trans_a.Resize({block_size, block_size});
            float *trans_a_data = trans_a.mutable_data<float>();
            // A[K, H] -> A[H, K]
            Transpose(a_base + (ik_begin * height + ih_begin),
                      ik_end - ik_begin, ih_end - ih_begin, height,
                      trans_a_data);
            real_a = trans_a_data;
            // The transposed copy is packed: its row length is the block's
            // K extent.
            stride_a = ik_end - ik_begin;
          } else {
            // Use A in place; rows are K floats apart.
            real_a = a_base + (ih_begin * K + ik_begin);
            stride_a = K;
          }

          if (transpose_b) {
            trans_b.Resize({block_size, block_size});
            float *trans_b_data = trans_b.mutable_data<float>();
            // B[W, K] -> B[K, W]
            Transpose(b_base + (iw_begin * K + ik_begin), iw_end - iw_begin,
                      ik_end - ik_begin, K, trans_b_data);
            real_b = trans_b_data;
            // The transposed copy is packed: its row length is the block's
            // width extent.
            stride_b = iw_end - iw_begin;
          } else {
            // Use B in place; rows are width floats apart.
            real_b = b_base + (ik_begin * width + iw_begin);
            stride_b = width;
          }

          // inside block:
          // calculate C[bh, bw] += A[bh, bk] * B[bk, bw] for one k
          GemmTile(real_a, real_b, ih_end - ih_begin, ik_end - ik_begin,
                   iw_end - iw_begin, stride_a, stride_b, stride_c, real_c);
        }  // bk
      }    // bw
    }      // bh
  }        // n
}

这个函数的主体仍是3个for循环,只是这次基础元素从gemm计算变成了对block的计算,这里的block默认大小为64,这样做的原因就是让该block涉及到的内存可以在一个L1 cache的大小内存放下来。可以看到实现上对于不足步长的部分,不仅会导致逻辑分支,而且不能使用NEON优化,所以网络设计时,长、宽、通道数都尽量取4、64的整数倍,以取得最好的性能。