The overall architecture of a Feature Pyramid Network (FPN) has four parts:

  1. Bottom-up pathway
  2. Top-down pathway
  3. Lateral connections
  4. Convolutional fusion

1. Bottom-up pathway

As shown in the figure above, the leftmost column is an ordinary convolutional network; here a ResNet is used to extract semantic information. C1 comes from the first few layers of the ResNet, while C2 through C5 come from stages built with different numbers of ResNet blocks. Feature maps within a stage share the same spatial size, and the size halves from one stage to the next.
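
As a quick illustration (a sketch assuming torchvision's ResNet-50, which matches the backbone built from scratch below), the stages can be read off like this:

import torch
from torchvision.models import resnet50

backbone = resnet50()
x = torch.randn(1, 3, 224, 224)
c1 = backbone.maxpool(backbone.relu(backbone.bn1(backbone.conv1(x))))  # 56x56, 64 channels
c2 = backbone.layer1(c1)  # 56x56, 256 channels
c3 = backbone.layer2(c2)  # 28x28, 512 channels
c4 = backbone.layer3(c3)  # 14x14, 1024 channels
c5 = backbone.layer4(c4)  # 7x7, 2048 channels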

2. Top-down pathway

First, a 1×1 convolution is applied to C5 to reduce the channel count, giving P5; P4, P3, and P2 are then obtained by successive upsampling. The goal is to produce features with the same height and width as C4, C3, and C2, so that element-wise addition becomes possible. Here 2x nearest-neighbor upsampling is used: it simply copies the nearest elements rather than interpolating linearly.

In code:

F.interpolate(x, size=(H, W), mode='nearest')
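
For intuition, a tiny made-up example shows that nearest-neighbor upsampling only repeats elements:

import torch
import torch.nn.functional as F

t = torch.tensor([[1., 2.], [3., 4.]]).view(1, 1, 2, 2)
print(F.interpolate(t, scale_factor=2, mode='nearest').squeeze())
# tensor([[1., 1., 2., 2.],
#         [1., 1., 2., 2.],
#         [3., 3., 4., 4.],
#         [3., 3., 4., 4.]])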

3. Lateral connections

The goal is to fuse the upsampled, semantically strong features with the shallow features that carry precise localization detail. After upsampling, the high-level features have a fixed channel count of 256, so C2-C4 must first pass through 1×1 convolutions to match that channel count before element-wise addition is possible. C1 is left out of the lateral connections because its feature map is large and its semantic content is weak.
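
A minimal sketch of a single lateral connection (the shapes are assumed from the ResNet-50 stages above):

import torch
import torch.nn as nn
import torch.nn.functional as F

c4 = torch.randn(1, 1024, 14, 14)  # bottom-up feature from the C4 stage
p5 = torch.randn(1, 256, 7, 7)     # top-down feature, already 256 channels
lat = nn.Conv2d(1024, 256, 1)      # 1x1 conv to bring C4 to 256 channels
p4 = F.interpolate(p5, size=c4.shape[-2:], mode='nearest') + lat(c4)
print(p4.shape)  # torch.Size([1, 256, 14, 14])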

4. Convolutional fusion

After the element-wise addition, a 3×3 convolution is applied to the resulting P2-P4 to fuse them once more. This suppresses the aliasing effect introduced by upsampling and produces the final feature maps.
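
With stride 1 and padding 1, the 3×3 fusion convolution leaves both the channel count and the spatial size unchanged; a quick check:

import torch
import torch.nn as nn

smooth = nn.Conv2d(256, 256, 3, 1, 1)
p4 = torch.randn(1, 256, 14, 14)
print(smooth(p4).shape)  # torch.Size([1, 256, 14, 14])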

5. Code

The full FPN code; the main parameter settings follow the figure at the top:

# Feature Pyramid Network: FPN
import torch
import torch.nn as nn
import torch.nn.functional as F
# Basic ResNet Bottleneck block
class Bottleneck(nn.Module):
    expansion = 4  # channel expansion factor
    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_planes, planes, 1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(planes, planes, 3, stride, 1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(planes, self.expansion * planes, 1, bias=False),
            nn.BatchNorm2d(self.expansion * planes),
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
    def forward(self, x):
        identity = x
        out = self.bottleneck(x)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out

# The FPN class; the constructor takes a list giving the number of
# Bottleneck blocks in each ResNet stage
class FPN(nn.Module):
    def __init__(self, layers):
        super(FPN, self).__init__()
        self.inplanes = 64

        # C1: the stem that processes the input
        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(3, 2, 1)

        # Bottom-up stages C2, C3, C4, C5
        self.layer1 = self._make_layer(64, layers[0])
        self.layer2 = self._make_layer(128, layers[1], 2)
        self.layer3 = self._make_layer(256, layers[2], 2)
        self.layer4 = self._make_layer(512, layers[3], 2)

        # Reduce the channel count of C5 to get P5
        self.toplayer = nn.Conv2d(2048, 256, 1, 1, 0)

        # 3x3 convolutions for fusing (smoothing) the merged features
        self.smooth1 = nn.Conv2d(256, 256, 3, 1, 1)
        self.smooth2 = nn.Conv2d(256, 256, 3, 1, 1)
        self.smooth3 = nn.Conv2d(256, 256, 3, 1, 1)

        # Lateral connections: 1x1 convolutions to match channel counts
        self.latlayer1 = nn.Conv2d(1024, 256, 1, 1, 0)
        self.latlayer2 = nn.Conv2d(512, 256, 1, 1, 0)
        self.latlayer3 = nn.Conv2d(256, 256, 1, 1, 0)


    # Build C2-C5; stride is 1 for the first stage and 2 for the others
    def _make_layer(self, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != Bottleneck.expansion * planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, Bottleneck.expansion * planes, 1, stride, bias=False),
                nn.BatchNorm2d(Bottleneck.expansion * planes)
            )
        layers = []
        layers.append(Bottleneck(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self.inplanes, planes))
        return nn.Sequential(*layers)

    # Top-down step: upsample x to y's size, then add element-wise
    def _upsample_add(self, x, y):
        _, _, H, W = y.shape
        return F.interpolate(x, size=(H, W), mode='nearest') + y
    def forward(self, x):
        # Bottom-up
        c1 = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        # Top-down
        p5 = self.toplayer(c5)
        p4 = self._upsample_add(p5, self.latlayer1(c4))
        p3 = self._upsample_add(p4, self.latlayer2(c3))
        p2 = self._upsample_add(p3, self.latlayer3(c2))

        # Convolutional fusion (smoothing)
        p4 = self.smooth1(p4)
        p3 = self.smooth2(p3)
        p2 = self.smooth3(p2)
        return p2, p3, p4, p5

Testing the network:

net_fpn = FPN([3,4,6,3]).cuda()
print(net_fpn.layer2)
# Print the structure of layer2
>>>
Sequential(
  (0): Bottleneck(
    (bottleneck): Sequential(
      (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): Bottleneck(
    (bottleneck): Sequential(
      (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
  )
  (2): Bottleneck(
    (bottleneck): Sequential(
      (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
  )
  (3): Bottleneck(
    (bottleneck): Sequential(
      (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
  )
)


Network output:

net_fpn = FPN([3,4,6,3]).cuda()
input = torch.randn(1,3,224,224).cuda()
output = net_fpn(input)
print(output[0].shape)
# Print the shape of P2
>>>
torch.Size([1, 256, 56, 56])
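
Continuing the same test, the remaining levels can be checked in one loop; for a 224×224 input each level should halve the resolution of the one below it:

for i, p in enumerate(output):
    print(f'P{i + 2}:', p.shape)
# Expected:
# P2: torch.Size([1, 256, 56, 56])
# P3: torch.Size([1, 256, 28, 28])
# P4: torch.Size([1, 256, 14, 14])
# P5: torch.Size([1, 256, 7, 7])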