代码地址:https://github.com/XiaoMi/mace
NEON优化常见命令:https://blog.csdn.net/fuwenyan/article/details/78811034
UNROLL优化:https://blog.csdn.net/u013625961/article/details/62422097
Mace1 * N卷积和N * 1卷积:https://blog.csdn.net/XiaoHeiBlack/article/details/81987161
1*7 卷积源码解读
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#include "mace/kernels/arm/conv_2d_neon.h"
namespace mace {
namespace kernels {
// 1x7 convolution kernel (kernel height 1, kernel width 7); the input is
// read with unit stride.
// Register tile per inner iteration: Ho = 1, Wo = 4, Co = 4
// (1 output row x 4 output columns x 4 output channels).
//
// input:     input tensor data
// filter:    convolution weights, indexed below as
//            filter[(out_channel * in_channels + in_channel) * 7 + tap]
// in_shape:  dimensions of the input tensor
// out_shape: dimensions of the output tensor
// output:    output tensor data; the NEON path accumulates into it (+=)
// Tensors are laid out as batch size x channel num x image height x image width.
//
// This function does not zero the output, add a bias or apply an activation.
// NOTE(review): the NEON path only handles output columns in groups of 4
// (w + 3 < out_width), and each group loads input columns w .. w+11 although
// only w .. w+9 are used. Presumably the caller pads the width and the buffers
// accordingly -- confirm against the call site.
void Conv2dNeonK1x7S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output) {
const index_t in_image_size = in_shape[2] * in_shape[3]; //elements in one input channel plane (H * W)
const index_t out_image_size = out_shape[2] * out_shape[3];
const index_t in_batch_size = in_shape[1] * in_image_size; //elements in one input batch item (C * H * W)
const index_t out_batch_size = out_shape[1] * out_image_size;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < out_shape[0]; ++b) { //one batch item per iteration
for (index_t m = 0; m < out_shape[1]; m += 4) { //output channels in groups of 4
const index_t out_channels = out_shape[1]; //number of output channels
const index_t out_height = out_shape[2]; //output height
const index_t out_width = out_shape[3]; //output width
const index_t in_channels = in_shape[1]; //number of input channels
const index_t in_width = in_shape[3]; //input width
if (m + 3 < out_channels) { //a full group of 4 output channels (m .. m+3)
float *out_ptr0_base = output + b * out_batch_size + m * out_image_size; //output plane of batch b, channel m
#if defined(MACE_ENABLE_NEON)
float *out_ptr1_base =
output + b * out_batch_size + (m + 1) * out_image_size; //the NEON path keeps one base pointer per output channel of the group
float *out_ptr2_base =
output + b * out_batch_size + (m + 2) * out_image_size;
float *out_ptr3_base =
output + b * out_batch_size + (m + 3) * out_image_size;
#endif
for (index_t c = 0; c < in_channels; ++c) {//one input channel per iteration
const float *in_ptr_base =
input + b * in_batch_size + c * in_image_size;
const float *filter_ptr0 = filter + m * in_channels * 7 + c * 7; //7 = kernel height * kernel width = 1 * 7;
// in_channels * 7 is the number of weights that belong to one output channel
#if defined(MACE_ENABLE_NEON)
const float *filter_ptr1 = filter + (m + 1) * in_channels * 7 + c * 7;
const float *filter_ptr2 = filter + (m + 2) * in_channels * 7 + c * 7;
const float *filter_ptr3 = filter + (m + 3) * in_channels * 7 + c * 7;//four filter pointers, because four output channels are produced at once
// (each output channel has its own set of weights)
/* load filter (4 outch x 1 height x 4 width) */
// A 1x7 kernel has 7 weights per (output channel, input channel) pair.
// vld1q_f32 loads 4 floats, so the 7 taps are kept in two overlapping vectors:
// vfX0 = taps 0..3 and vfX1 = taps 3..6 (only lanes 1..3 of vfX1 are used below).
float32x4_t vf00, vf01;
float32x4_t vf10, vf11;
float32x4_t vf20, vf21;
float32x4_t vf30, vf31;
vf00 = vld1q_f32(filter_ptr0); //load 4 consecutive floats into a vector register
vf01 = vld1q_f32(filter_ptr0 + 3);
vf10 = vld1q_f32(filter_ptr1);
vf11 = vld1q_f32(filter_ptr1 + 3);
vf20 = vld1q_f32(filter_ptr2);
vf21 = vld1q_f32(filter_ptr2 + 3);
vf30 = vld1q_f32(filter_ptr3);
vf31 = vld1q_f32(filter_ptr3 + 3);
for (index_t h = 0; h < out_height; ++h) { // one output row per iteration
for (index_t w = 0; w + 3 < out_width; w += 4) { //4 output columns per iteration
// output (4 outch x 1 height x 4 width): vo_outch_height
float32x4_t vo0, vo1, vo2, vo3;
// load output
index_t out_offset = h * out_width + w;
vo0 = vld1q_f32(out_ptr0_base + out_offset);
vo1 = vld1q_f32(out_ptr1_base + out_offset);
vo2 = vld1q_f32(out_ptr2_base + out_offset);
vo3 = vld1q_f32(out_ptr3_base + out_offset);
// input vectors: viK holds input columns w+K .. w+K+3 of row h
float32x4_t vi0, vi1, vi2, vi3, vi4, vi5, vi6, vi8;
// input offset
index_t in_offset = h * in_width + w;
// load input
// vld1q_f32 loads 12 consecutive input floats (vi0, vi4, vi8); vextq_f32 then
// builds the five shifted vectors vi1, vi2, vi3, vi5, vi6 from neighbouring pairs.
vi0 = vld1q_f32(in_ptr_base + in_offset);
vi4 = vld1q_f32(in_ptr_base + in_offset + 4);
vi8 = vld1q_f32(in_ptr_base + in_offset + 8);
vi1 = vextq_f32(vi0, vi4, 1);
vi2 = vextq_f32(vi0, vi4, 2);
vi3 = vextq_f32(vi0, vi4, 3);
vi5 = vextq_f32(vi4, vi8, 1);
vi6 = vextq_f32(vi4, vi8, 2);
// AArch64 is the 64-bit execution state of the ARMv8-A architecture (an
// instruction set, not a CPU model). When __aarch64__ is defined the
// lane-indexed intrinsic vfmaq_laneq_f32 is used on full 128-bit weight
// vectors; otherwise vmlaq_lane_f32 is used on their 64-bit halves.
#if defined(__aarch64__)
/* outch 0 */
vo0 = vfmaq_laneq_f32(vo0, vi0, vf00, 0);
vo0 = vfmaq_laneq_f32(vo0, vi1, vf00, 1);
vo0 = vfmaq_laneq_f32(vo0, vi2, vf00, 2);
vo0 = vfmaq_laneq_f32(vo0, vi3, vf00, 3);
vo0 = vfmaq_laneq_f32(vo0, vi4, vf01, 1);
vo0 = vfmaq_laneq_f32(vo0, vi5, vf01, 2);
vo0 = vfmaq_laneq_f32(vo0, vi6, vf01, 3);
/* outch 1 */
vo1 = vfmaq_laneq_f32(vo1, vi0, vf10, 0);
vo1 = vfmaq_laneq_f32(vo1, vi1, vf10, 1);
vo1 = vfmaq_laneq_f32(vo1, vi2, vf10, 2);
vo1 = vfmaq_laneq_f32(vo1, vi3, vf10, 3);
vo1 = vfmaq_laneq_f32(vo1, vi4, vf11, 1);
vo1 = vfmaq_laneq_f32(vo1, vi5, vf11, 2);
vo1 = vfmaq_laneq_f32(vo1, vi6, vf11, 3);
/* outch 2 */
vo2 = vfmaq_laneq_f32(vo2, vi0, vf20, 0);
vo2 = vfmaq_laneq_f32(vo2, vi1, vf20, 1);
vo2 = vfmaq_laneq_f32(vo2, vi2, vf20, 2);
vo2 = vfmaq_laneq_f32(vo2, vi3, vf20, 3);
vo2 = vfmaq_laneq_f32(vo2, vi4, vf21, 1);
vo2 = vfmaq_laneq_f32(vo2, vi5, vf21, 2);
vo2 = vfmaq_laneq_f32(vo2, vi6, vf21, 3);
/* outch 3 */
vo3 = vfmaq_laneq_f32(vo3, vi0, vf30, 0);
vo3 = vfmaq_laneq_f32(vo3, vi1, vf30, 1);
vo3 = vfmaq_laneq_f32(vo3, vi2, vf30, 2);
vo3 = vfmaq_laneq_f32(vo3, vi3, vf30, 3);
vo3 = vfmaq_laneq_f32(vo3, vi4, vf31, 1);
vo3 = vfmaq_laneq_f32(vo3, vi5, vf31, 2);
vo3 = vfmaq_laneq_f32(vo3, vi6, vf31, 3);
#else
/* outch 0 */
// vmlaq_lane_f32(a, b, c, i) is a multiply-accumulate: a + b * c[i], where c[i] is a scalar lane
vo0 = vmlaq_lane_f32(vo0, vi0, vget_low_f32(vf00), 0);
vo0 = vmlaq_lane_f32(vo0, vi1, vget_low_f32(vf00), 1);
vo0 = vmlaq_lane_f32(vo0, vi2, vget_high_f32(vf00), 0);
vo0 = vmlaq_lane_f32(vo0, vi3, vget_high_f32(vf00), 1);
vo0 = vmlaq_lane_f32(vo0, vi4, vget_low_f32(vf01), 1);
vo0 = vmlaq_lane_f32(vo0, vi5, vget_high_f32(vf01), 0);
vo0 = vmlaq_lane_f32(vo0, vi6, vget_high_f32(vf01), 1);
/* outch 1 */
vo1 = vmlaq_lane_f32(vo1, vi0, vget_low_f32(vf10), 0);
vo1 = vmlaq_lane_f32(vo1, vi1, vget_low_f32(vf10), 1);
vo1 = vmlaq_lane_f32(vo1, vi2, vget_high_f32(vf10), 0);
vo1 = vmlaq_lane_f32(vo1, vi3, vget_high_f32(vf10), 1);
vo1 = vmlaq_lane_f32(vo1, vi4, vget_low_f32(vf11), 1);
vo1 = vmlaq_lane_f32(vo1, vi5, vget_high_f32(vf11), 0);
vo1 = vmlaq_lane_f32(vo1, vi6, vget_high_f32(vf11), 1);
/* outch 2 */
vo2 = vmlaq_lane_f32(vo2, vi0, vget_low_f32(vf20), 0);
vo2 = vmlaq_lane_f32(vo2, vi1, vget_low_f32(vf20), 1);
vo2 = vmlaq_lane_f32(vo2, vi2, vget_high_f32(vf20), 0);
vo2 = vmlaq_lane_f32(vo2, vi3, vget_high_f32(vf20), 1);
vo2 = vmlaq_lane_f32(vo2, vi4, vget_low_f32(vf21), 1);
vo2 = vmlaq_lane_f32(vo2, vi5, vget_high_f32(vf21), 0);
vo2 = vmlaq_lane_f32(vo2, vi6, vget_high_f32(vf21), 1);
/* outch 3 */
vo3 = vmlaq_lane_f32(vo3, vi0, vget_low_f32(vf30), 0);
vo3 = vmlaq_lane_f32(vo3, vi1, vget_low_f32(vf30), 1);
vo3 = vmlaq_lane_f32(vo3, vi2, vget_high_f32(vf30), 0);
vo3 = vmlaq_lane_f32(vo3, vi3, vget_high_f32(vf30), 1);
vo3 = vmlaq_lane_f32(vo3, vi4, vget_low_f32(vf31), 1);
vo3 = vmlaq_lane_f32(vo3, vi5, vget_high_f32(vf31), 0);
vo3 = vmlaq_lane_f32(vo3, vi6, vget_high_f32(vf31), 1);
#endif
// vst1q_f32 writes the 4 results of each output channel back to memory. On the
// next iteration of the input-channel loop (c) the same memory is loaded again
// and accumulated into, so this also sums the contributions of all input
// channels. Bias and activation are not applied in this function.
vst1q_f32(out_ptr0_base + out_offset, vo0);
vst1q_f32(out_ptr1_base + out_offset, vo1);
vst1q_f32(out_ptr2_base + out_offset, vo2);
vst1q_f32(out_ptr3_base + out_offset, vo3);
} // w
} // h
#else
// Without NEON: call the generic helper once per output channel of the group;
// the offsets oc * in_channels * 7 and oc * out_image_size select channel m + oc.
// NOTE(review): presumably Conv2dCPUKHxKWCalc accumulates a KHxKW (here 1x7)
// convolution into the output plane -- its definition is not visible here.
for (index_t oc = 0; oc < 4; ++oc) {
Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0 + oc * in_channels * 7,
in_width, 1, 7, out_height, out_width,
out_ptr0_base + oc * out_image_size, 1);
}
#endif
} // c
} else { //tail: fewer than 4 output channels remain (m .. out_channels-1), handled one at a time
for (index_t mm = m; mm < out_channels; ++mm) {
float *out_ptr0_base =
output + b * out_batch_size + mm * out_image_size;
for (index_t c = 0; c < in_channels; ++c) {
const float *in_ptr_base =
input + b * in_batch_size + c * in_image_size;
const float *filter_ptr0 = filter + mm * in_channels * 7 + c * 7;
#if defined(MACE_ENABLE_NEON)
/* load filter (1 outch x 1 height x 4 width) */
// Same two-vector weight layout as above: vf00 = taps 0..3, vf01 = taps 3..6.
float32x4_t vf00, vf01;
vf00 = vld1q_f32(filter_ptr0);
vf01 = vld1q_f32(filter_ptr0 + 3);
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w + 3 < out_width; w += 4) {
// output (1 outch x 1 height x 4 width): vo_outch_height
float32x4_t vo0;
// load output
index_t out_offset = h * out_width + w;
vo0 = vld1q_f32(out_ptr0_base + out_offset);
// input vectors: viK holds input columns w+K .. w+K+3 of row h
float32x4_t vi0, vi1, vi2, vi3, vi4, vi5, vi6, vi8;
// input offset
index_t in_offset = h * in_width + w;
// load input
vi0 = vld1q_f32(in_ptr_base + in_offset);
vi4 = vld1q_f32(in_ptr_base + in_offset + 4);
vi8 = vld1q_f32(in_ptr_base + in_offset + 8);
vi1 = vextq_f32(vi0, vi4, 1);
vi2 = vextq_f32(vi0, vi4, 2);
vi3 = vextq_f32(vi0, vi4, 3);
vi5 = vextq_f32(vi4, vi8, 1);
vi6 = vextq_f32(vi4, vi8, 2);
#if defined(__aarch64__)
vo0 = vfmaq_laneq_f32(vo0, vi0, vf00, 0);
vo0 = vfmaq_laneq_f32(vo0, vi1, vf00, 1);
vo0 = vfmaq_laneq_f32(vo0, vi2, vf00, 2);
vo0 = vfmaq_laneq_f32(vo0, vi3, vf00, 3);
vo0 = vfmaq_laneq_f32(vo0, vi4, vf01, 1);
vo0 = vfmaq_laneq_f32(vo0, vi5, vf01, 2);
vo0 = vfmaq_laneq_f32(vo0, vi6, vf01, 3);
#else
vo0 = vmlaq_lane_f32(vo0, vi0, vget_low_f32(vf00), 0);
vo0 = vmlaq_lane_f32(vo0, vi1, vget_low_f32(vf00), 1);
vo0 = vmlaq_lane_f32(vo0, vi2, vget_high_f32(vf00), 0);
vo0 = vmlaq_lane_f32(vo0, vi3, vget_high_f32(vf00), 1);
vo0 = vmlaq_lane_f32(vo0, vi4, vget_low_f32(vf01), 1);
vo0 = vmlaq_lane_f32(vo0, vi5, vget_high_f32(vf01), 0);
vo0 = vmlaq_lane_f32(vo0, vi6, vget_high_f32(vf01), 1);
#endif
vst1q_f32(out_ptr0_base + out_offset, vo0);
} // w
} // h
#else
// Without NEON: one generic helper call per (output channel, input channel) pair.
Conv2dCPUKHxKWCalc(in_ptr_base, filter_ptr0, in_width, 1, 7,
out_height, out_width, out_ptr0_base, 1);
#endif
} // c
}
} // if
} // m
} // b
}
} // namespace kernels
} // namespace mace
1*1 卷积源码阅读
1*1卷积的思路就是对于每一个batch调用一次gemm矩阵乘法运算。这里面的关键点在于卷积如何变成了矩阵乘法,请看注释。这里举个例子:C1=2,C2=3,W=2,H=3,
可以看到,这两个矩阵做完gemm乘法之后就得到了多通道1*1卷积的结果。那么是不是需要把结果中长度为H*W的行向量reshape一下,变成三维的呢?实际上不需要,因为内存排布没有变,直接这样访问即可,省掉reshape还可以降低开销。
#include "mace/kernels/arm/conv_2d_neon.h"
//思路是在每一个batch中调用了gemm矩阵乘法运算
//假设输入通道数为C1,输出通道数为C2。则一般卷积核参数为C1xC2xkhxkw,
//因此卷积核大小为1*1时,卷积核就从四维变成了两维矩阵K,大小为C1*C2。
//在单batch下,假设上一次输入数据大小为 C1*H*W,把它reshape成一个C1*(H*W)
//的矩阵F,这样多通道分别卷积再求和的过程就可以用这两个矩阵乘积来表示:
//$Z = K^t * F$得到了大小为C2*(H*W)的矩阵Z。其实就是单通道的卷积运算
//退化成了一个矩阵和一个标量的点乘运算了。
namespace mace {
namespace kernels {
void Conv2dNeonK1x1S1(const float *input,
const float *filter, //卷积核
const index_t batch, //batch的个数
const index_t height,
const index_t width,
const index_t in_channels,
const index_t out_channels,
float *output,
SGemm *sgemm,
ScratchBuffer *scratch_buffer) {
for (index_t b = 0; b < batch; ++b) {
sgemm->Run(filter,
input + b * in_channels * height * width,
1,
out_channels,
in_channels,
in_channels,
height * width,
false,
false,
true,
false,
output + b * out_channels * height * width,
scratch_buffer);
}
}
} // namespace kernels
} // namespace mace
Mace把大矩阵的运算分成2级的矩阵分块乘法。第一级的实现名字都是GemmXYZ这种形式,表示大小为[X,Y]和[Y,Z]的矩阵相乘,主要的NEON优化也是在这些函数中。这一级的矩阵计算规模都很小,最大为Gemm688,所以多数情况下变量可以保持在寄存器中,避免寄存器溢出到栈上带来的时间开销。这一级的分块矩阵乘法叫register tiling。第二级优化则是把若干register tiling组成一个block,保证一个block内的内存开销(2个输入矩阵+1个输出矩阵)在L1 cache的大小范围内,提高cache命中率,称为cache tiling。
register tiling
// Register-tile kernel: C(1x4) += A(1x4) * B(4x4).
// A fits in one 4-float vector and B in four (one per row); the product is
// accumulated into the single 4-float vector of C with multiply-accumulate
// instructions. Larger kernels such as Gemm884 load more vectors per row of A.
//
// a_ptr:    the 4 floats of A
// b_ptr:    first row of B; row r starts at b_ptr + r * stride_b
// c_ptr:    the 4 floats of C (read, updated, written back)
// stride_a, stride_c: only used by the non-NEON fallback
inline void Gemm144(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
  MACE_UNUSED(stride_a);
  MACE_UNUSED(stride_c);

  // The variable names a0, b0..b3 and c0 are required by MACE_GEMM_PART_CAL_4,
  // which builds them by token pasting.
  const float32x4_t a0 = vld1q_f32(a_ptr);  // 4 consecutive floats of A
  const float32x4_t b0 = vld1q_f32(b_ptr);
  const float32x4_t b1 = vld1q_f32(b_ptr + 1 * stride_b);
  const float32x4_t b2 = vld1q_f32(b_ptr + 2 * stride_b);
  const float32x4_t b3 = vld1q_f32(b_ptr + 3 * stride_b);
  float32x4_t c0 = vld1q_f32(c_ptr);

  // The argument 0 is the row index: there is only one row (a0 / c0).
  MACE_GEMM_PART_CAL_4(0);

  // Store the 4 accumulated floats back to c_ptr.
  vst1q_f32(c_ptr, c0);
#else
  GemmBlock(a_ptr, b_ptr, 1, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
其中使用了一个宏MACE_GEMM_PART_CAL_4(0),这个宏有什么用呢?先来看一下它的定义:
// NEON helpers used below:
// vget_low_<type>:  returns the low half of a 128-bit vector as a 64-bit vector
//                   with the same element type.
// vget_high_<type>: returns the high half of a 128-bit vector as a 64-bit vector
//                   with the same element type.
// vget_lane_<type>: returns one selected element of a vector.
// vmlaq_lane_f32(acc, v, k, i): acc + v * k[i], i.e. all four lanes of v are
//                   multiplied by the single scalar lane k[i] and added to acc.
//
// MACE_GEMM_PART_CAL_4(RC): c<RC> += a<RC>[0]*b0 + a<RC>[1]*b1
//                                  + a<RC>[2]*b2 + a<RC>[3]*b3.
// Expects float32x4_t variables named a<RC>, c<RC> and b0..b3 to be in scope
// (the names are formed by token pasting). The expansion already ends with ';'.
#define MACE_GEMM_PART_CAL_4(RC) \
c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RC), 0); \
c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RC), 1); \
c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RC), 0); \
c##RC = vmlaq_lane_f32(c##RC, b3, vget_high_f32(a##RC), 1);
同理,还有一个叫做MACE_GEMM_PART_CAL_8(RC, RA, RAN)的宏,实现如下:
// MACE_GEMM_PART_CAL_8(RC, RA, RAN): one output row with 8 multiply-accumulates.
//   c<RC> += a<RA>[0]*b0  + a<RA>[1]*b1  + a<RA>[2]*b2  + a<RA>[3]*b3
//          + a<RAN>[0]*b4 + a<RAN>[1]*b5 + a<RAN>[2]*b6 + a<RAN>[3]*b7
// RC selects the accumulator c<RC>; RA and RAN select the two vectors that hold
// the first and second group of 4 coefficients of the row.
// Expects float32x4_t variables named c<RC>, a<RA>, a<RAN> and b0..b7 to be in
// scope (the names are formed by token pasting). The expansion already ends
// with ';'.
#define MACE_GEMM_PART_CAL_8(RC, RA, RAN) \
c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RA), 0); \
c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RA), 1); \
c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RA), 0); \
c##RC = vmlaq_lane_f32(c##RC, b3, vget_high_f32(a##RA), 1); \
c##RC = vmlaq_lane_f32(c##RC, b4, vget_low_f32(a##RAN), 0); \
c##RC = vmlaq_lane_f32(c##RC, b5, vget_low_f32(a##RAN), 1); \
c##RC = vmlaq_lane_f32(c##RC, b6, vget_high_f32(a##RAN), 0); \
c##RC = vmlaq_lane_f32(c##RC, b7, vget_high_f32(a##RAN), 1);
这个和CAL_4十分类似,就不解释了。我们再来看一段Gemm884的源码:这段代码解决的是8*8的矩阵A和8*4的矩阵B相乘,最后得到8*4的矩阵C的问题。
// Register-tile kernel: C(8x4) += A(8x8) * B(8x4).
//
// a_ptr: first row of A; row r starts at a_ptr + r * stride_a (8 floats used)
// b_ptr: first row of B; row r starts at b_ptr + r * stride_b (4 floats used)
// c_ptr: first row of C; row r starts at c_ptr + r * stride_c (4 floats,
//        read, updated and written back)
//
// The vector names a0..a15, b0..b7 and c0..c7 are required by
// MACE_GEMM_PART_CAL_8, which builds them by token pasting.
inline void Gemm884(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
  // A: 16 vectors, two per row (columns 0..3 and 4..7).
  const float32x4_t a0 = vld1q_f32(a_ptr);
  const float32x4_t a1 = vld1q_f32(a_ptr + 4);
  const float32x4_t a2 = vld1q_f32(a_ptr + 1 * stride_a);
  const float32x4_t a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
  const float32x4_t a4 = vld1q_f32(a_ptr + 2 * stride_a);
  const float32x4_t a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
  const float32x4_t a6 = vld1q_f32(a_ptr + 3 * stride_a);
  const float32x4_t a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
  const float32x4_t a8 = vld1q_f32(a_ptr + 4 * stride_a);
  const float32x4_t a9 = vld1q_f32(a_ptr + 4 * stride_a + 4);
  const float32x4_t a10 = vld1q_f32(a_ptr + 5 * stride_a);
  const float32x4_t a11 = vld1q_f32(a_ptr + 5 * stride_a + 4);
  const float32x4_t a12 = vld1q_f32(a_ptr + 6 * stride_a);
  const float32x4_t a13 = vld1q_f32(a_ptr + 6 * stride_a + 4);
  const float32x4_t a14 = vld1q_f32(a_ptr + 7 * stride_a);
  const float32x4_t a15 = vld1q_f32(a_ptr + 7 * stride_a + 4);

  // B: 8 vectors, one per row.
  const float32x4_t b0 = vld1q_f32(b_ptr);
  const float32x4_t b1 = vld1q_f32(b_ptr + 1 * stride_b);
  const float32x4_t b2 = vld1q_f32(b_ptr + 2 * stride_b);
  const float32x4_t b3 = vld1q_f32(b_ptr + 3 * stride_b);
  const float32x4_t b4 = vld1q_f32(b_ptr + 4 * stride_b);
  const float32x4_t b5 = vld1q_f32(b_ptr + 5 * stride_b);
  const float32x4_t b6 = vld1q_f32(b_ptr + 6 * stride_b);
  const float32x4_t b7 = vld1q_f32(b_ptr + 7 * stride_b);

  // C: 8 accumulators, one per output row, seeded with the current contents.
  float32x4_t c0 = vld1q_f32(c_ptr);
  float32x4_t c1 = vld1q_f32(c_ptr + 1 * stride_c);
  float32x4_t c2 = vld1q_f32(c_ptr + 2 * stride_c);
  float32x4_t c3 = vld1q_f32(c_ptr + 3 * stride_c);
  float32x4_t c4 = vld1q_f32(c_ptr + 4 * stride_c);
  float32x4_t c5 = vld1q_f32(c_ptr + 5 * stride_c);
  float32x4_t c6 = vld1q_f32(c_ptr + 6 * stride_c);
  float32x4_t c7 = vld1q_f32(c_ptr + 7 * stride_c);

  // Row r of C uses the vector pair (a[2r], a[2r+1]) of A.
  MACE_GEMM_PART_CAL_8(0, 0, 1);
  MACE_GEMM_PART_CAL_8(1, 2, 3);
  MACE_GEMM_PART_CAL_8(2, 4, 5);
  MACE_GEMM_PART_CAL_8(3, 6, 7);
  MACE_GEMM_PART_CAL_8(4, 8, 9);
  MACE_GEMM_PART_CAL_8(5, 10, 11);
  MACE_GEMM_PART_CAL_8(6, 12, 13);
  MACE_GEMM_PART_CAL_8(7, 14, 15);

  // Write the 8 accumulated rows (4 floats each) back to C.
  vst1q_f32(c_ptr, c0);
  vst1q_f32(c_ptr + 1 * stride_c, c1);
  vst1q_f32(c_ptr + 2 * stride_c, c2);
  vst1q_f32(c_ptr + 3 * stride_c, c3);
  vst1q_f32(c_ptr + 4 * stride_c, c4);
  vst1q_f32(c_ptr + 5 * stride_c, c5);
  vst1q_f32(c_ptr + 6 * stride_c, c6);
  vst1q_f32(c_ptr + 7 * stride_c, c7);
#else
  GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
这就是第一级矩阵乘法的构成:很多个GemmXYZ函数,而它们的调用就发生在cache tiling中。先来看一个函数:
#cache tiling
// Dispatches to the first-level kernel Gemm<row>84 according to the number of
// rows (1..8) of A that are to be multiplied; all other arguments are forwarded
// unchanged. Any other row count reaches MACE_NOT_IMPLEMENTED.
inline void GemmX84(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr,
                    int row) {
  switch (row) {
    case 1:
      Gemm184(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      return;
    case 2:
      Gemm284(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      return;
    case 3:
      Gemm384(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      return;
    case 4:
      Gemm484(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      return;
    case 5:
      Gemm584(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      return;
    case 6:
      Gemm684(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      return;
    case 7:
      Gemm784(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      return;
    case 8:
      Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      return;
    default:
      MACE_NOT_IMPLEMENTED;
  }
}
然后我们来看GemmTile函数:
这个地方为了突出GemmTile的主体逻辑,将aarch64和clang宏控制的部分代码删除了。可以看出这个函数里面有一些边界处理,实际上如果不看边界,并且把每次GemmXYZ调用看成对单个元素的运算,这里的3层for循环就相当于矩阵乘法的3层for循环。
inline void GemmTile(const float *A,
const float *B,
const index_t height,
const index_t K,
const index_t width,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *C) {
index_t h = 0;
index_t w = 0;
index_t k = 0;
int reg_height_tile = 6;
int reg_K_tile = 4;
for (h = 0; h < height - reg_height_tile + 1; h += reg_height_tile) {
for (k = 0; k < K - reg_K_tile + 1; k += reg_K_tile) {
const float *a_ptr = A + (h * stride_a + k);
for (w = 0; w + 3 < width; w += 4) {
const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w);
Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
}
if (w < width) {
const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w);
GemmBlock(a_ptr, b_ptr, reg_height_tile, reg_K_tile, width - w,
stride_a, stride_b, stride_c, c_ptr);
}
}
if (k < K) {
const float *a_ptr = A + (h * stride_a + k);
const float *b_ptr = B + k * stride_b;
float *c_ptr = C + h * stride_c;
GemmBlock(a_ptr, b_ptr, reg_height_tile, K - k, width, stride_a, stride_b,
stride_c, c_ptr);
}
}
if (h < height) {
index_t remain_h = height - h;
for (k = 0; k < K - reg_K_tile; k += reg_K_tile) {
const float *a_ptr = A + (h * stride_a + k);
index_t w;
for (w = 0; w + 3 < width; w += 4) {
const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w);
GemmX44(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr, remain_h);
}
if (w < width) {
const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w);
GemmBlock(a_ptr, b_ptr, remain_h, reg_K_tile, width - w, stride_a,
stride_b, stride_c, c_ptr);
}
}
if (k < K) {
const float *a_ptr = A + (h * stride_a + k);
const float *b_ptr = B + k * stride_b;
float *c_ptr = C + h * stride_c;
GemmBlock(a_ptr, b_ptr, remain_h, K - k, width, stride_a, stride_b,
stride_c, c_ptr);
}
}
}
最后就是Gemm函数了,这是最高级的封装了,这个函数去掉细枝末节之后长啥样呢?
// A: height x K, B: K x width, C: height x width
void Gemm(const float *A,
const float *B,
const index_t batch,
const index_t height,
const index_t K,
const index_t width,
float *C,
const bool transpose_a,
const bool transpose_b) {
if (width == 1) {
for (index_t b = 0; b < batch; ++b) {
Gemv(A + b * height * K, B + b * K, 1, K, height, C + b * height);
}
return;
}
memset(C, 0, sizeof(float) * batch * height * width);
// It is better to use large block size if it fits for fast cache.
// Assume l1 cache size is 32k, we load three blocks at a time (A, B, C),
// the block size should be sqrt(32k / sizeof(T) / 3).
// As number of input channels of convolution is normally power of 2, and
// we have not optimized tiling remains, we use the following magic number
const index_t block_size = 64;
const index_t block_tile_height = RoundUpDiv(height, block_size);
const index_t block_tile_width = RoundUpDiv(width, block_size);
const index_t block_tile_k = RoundUpDiv(K, block_size);
const index_t block_tile[3] = {block_tile_height, block_tile_width,
block_tile_k};
const index_t remain_height = height % block_size;
const index_t remain_width = width % block_size;
const index_t remain_k = K % block_size;
const index_t remain[3] = {remain_height, remain_width, remain_k};
#pragma omp parallel for collapse(3)
for (index_t n = 0; n < batch; ++n) {
for (index_t bh = 0; bh < block_tile[0]; ++bh) {
for (index_t bw = 0; bw < block_tile[1]; ++bw) {
const float *a_base = A + n * height * K;
const float *b_base = B + n * K * width;
float *c_base = C + n * height * width;
const index_t ih_begin = bh * block_size;
const index_t ih_end =
bh * block_size +
(bh == block_tile[0] - 1 && remain[0] > 0 ? remain[0] : block_size);
const index_t iw_begin = bw * block_size;
const index_t iw_end =
bw * block_size +
(bw == block_tile[1] - 1 && remain[1] > 0 ? remain[1] : block_size);
for (index_t bk = 0; bk < block_tile[2]; ++bk) {
const index_t ik_begin = bk * block_size;
const index_t ik_end =
bk * block_size + (bk == block_tile[2] - 1 && remain[2] > 0
? remain[2]
: block_size);
Tensor trans_a;
Tensor trans_b;
const float *real_a = nullptr;
const float *real_b = nullptr;
float *real_c = c_base + (ih_begin * width + iw_begin);
index_t stride_a;
index_t stride_b;
index_t stride_c = width;
if (transpose_a) {
trans_a.Resize({block_size, block_size});
float *trans_a_data = trans_a.mutable_data<float>();
// A[K, H] -> A[H, K]
Transpose(a_base + (ik_begin * height + ih_begin),
ik_end - ik_begin, ih_end - ih_begin, height,
trans_a_data);
real_a = trans_a_data;
stride_a = ik_end - ik_begin;
} else {
real_a = a_base + (ih_begin * K + ik_begin);
stride_a = K;
}
if (transpose_b) {
trans_b.Resize({block_size, block_size});
float *trans_b_data = trans_b.mutable_data<float>();
// B[W, K] -> B[K, W]
Transpose(b_base + (iw_begin * K + ik_begin), iw_end - iw_begin,
ik_end - ik_begin, K, trans_b_data);
real_b = trans_b_data;
stride_b = iw_end - iw_begin;
} else {
real_b = b_base + (ik_begin * width + iw_begin);
stride_b = width;
}
// inside block:
// calculate C[bh, bw] += A[bh, bk] * B[bk, bw] for one k
GemmTile(real_a, real_b, ih_end - ih_begin, ik_end - ik_begin,
iw_end - iw_begin, stride_a, stride_b, stride_c, real_c);
} // bk
} // bw
} // bh
} // n
}
这个函数的主体仍是3个for循环,只是这次基础元素从GemmXYZ计算变成了对block的计算。这里的block默认大小为64,这样做的原因是让该block涉及到的内存可以放进L1 cache中。可以看到,在实现上对于不足步长的部分,不仅会引入逻辑分支,而且不能使用NEON优化,所以网络设计时,长、宽、通道数都尽量取4、64的整数倍,以取得最好的性能。