超分辨率医疗图像的元分割网络

MSN:Meta Segmentation Network for Ultra-Resolution Medical Images Tong

相关名词:

whole-slide image (WSI):全视野数字切片

ultra-resolution image (URI):超分辨率图像

摘要:

基于多分支结构的方法能够权衡计算代价和分割准确率,然而融合结构需要精妙的设计,这导致计算量增加。用元学习的方式,让融合模块简单有效。

MSN通过元学习的方法很快生成网络的权重,只需要很少的训练样本和epoch就可以收敛。为了防止每个branch都从零开始训练,进一步提出权值共享让不同的分支共享权重,实现快速的知识融合,使得效果提升,参数量减少。

在BACH和ISIC数据集上实现了SOTA。

一、介绍

超分辨率图像size巨大,对计算量要求很高,像Unet,DeepLab这样的网络无法计算。

超分辨率处理两种方式:图片下采样(image downsampling)和滑动窗口(sliding patches)

图片下采样将图像调整为合适的大小,通常512$\times$512喂到模型里,后者将图像切分为很多patch,然后做patch级别的分割,再将这些patch的分割结果结合(combine)起来。这些方法减少计算量的同时,几乎放弃了空间上下文和邻域依赖提供的全局信息,很难获得准确的分割结果。

基于patch最新的是AWMF-CNN,多分支结构,包含了不同分辨率和尺度的目标区域和感受野的patch的上下文信息。

多分支方法挑战:

需要精心设计最后的融合机制,比如融合层有很多大量通道的卷积层的叠加,或者是辅助的加权网络。幸运的是通过元学习的融合方法,只要很简单的结构就能保证很好的效果。

其次,所有的分支都要从零开始各自训练。

元融合机制解决多分支第一个挑战:直接使用branch输出层的负梯度作为元信息去训练一个元学习者,直接预测融合网络的参数,他不像端到端的反向传播那样,因为他收敛更快,所以它不需要精心设计融合结构,元融合结构只包含两个卷积层和一个元学习器。

为了避免所有分支都从零开始训练,提出权值共享机制,尽管输入的分支放大倍数不同,仍然是同一个域中的。在权值共享机制中,采用特殊的记忆机制实现元分支和非元分支的知识适配。

元分支表示要共享到其他分支的参数,此外如果直接权值共享会带来知识鸿沟,为了弥补差距,使用记忆机制存储元分支信息,然后对元分支和非元分支进行记忆转换实现快速知识适配。

图1:采用BiSeNet作为backbone,在\(X_3\)分支训练,固定参数,然后依次输入\(X_1, X_2, X_3\),比较均值和方差。红色的那些层成为gap 层,很明显gap 层\(X_3\)的均值和方差很大。

贡献:

  • 用元学习提出了多分支超分辨率图像分割的元融合机制,融合层的权重能够由元学习器生成。
  • 新颖权值共享能够快速知识适配

二、提出的方法

2.1 MSN架构

图2:mainbody结构没有权值共享,三个分支是分离的,mainbody接收不同分辨率的图像patch,输出各自的分割结果,只有\(X_3\)通过meta分支不通过Mem-FP和Mem-RM,另外两支一起通过带有Mem-FP和Mem-RM的元分支,修复层之间的隔阂,Meta-FM以元学习的方式融合两个分支的结果,两个卷积层的输出通道数都是\(N_c\)\(N_c\)表示类别数量。

Mainbody主要包含三个部分:meta-branch,Mem-FP, Mem-RM。Mem-FP存储meta-branch的meta-feature;Mem-RM部署在非元分支中,目的是用存储在Mem-FP中的特征补充非元分支的特征。

mainbody输出两个初步分割结果\(S_1, S_2\),它通过元学习的方式融合得到最后的结果。

\(X_1, X_2, X_3\)输入的size相同,分辨率不同(16\(\times\), 4\(\times\), 1\(\times\)),送到3个分支,\(X_3\)的感受野最大,分辨率最小,\(X_1\)相反。

不同分辨率的knowledge有共性和差异,将低分辨率的分支作为元分支,用中等分辨率和高分辨率的权重分享给元分支,因为包含更多的信息,用Mem-FP和Mem-RM调整权重。

低分辨率的\(X_3\)送到元分支的non-gap convolution和gap convolution,在gap层中,获得的feature map记录到Mem-FP中。

至于高分辨率的图像\(X_1\),在固定meta分支的所有层之后,像\(X_3\)一样通过meta-branch的non-gap层,当遇到gap层,从Mem-FP中调出Memory,作为当前gap层输出的feature map送到Mem-RM中调整学习权重,\(X_2\)也是一样。此后,用meta-FM融合两个非元分支的输出结果,

🎁Memory Feature Pooling

就是个存储池,在CNN的一些层中\(X_1\)\(X_2, X_3\)的输出有差距,当处理低分辨率的\(X_3\),meta-branch的gap层输出的feature map称为meta-feature,保存在Mem-FP中,用来传给其他分支。实际上只有\(X_3\)通过meta-branch的时候,获得的meta-feature才被保存下来,另外两个是不保存的。

🎯Memory Recall Module

为了在权重共享机制中使得元分支适应其他分支,非元分支应该recall在Mem-FP的gap layer的缺失的特征。因此构造Mem-RM嵌入到meta-branch帮助meta-branch的gap layer找回(recall)memory。

\(X_1, X_2\)喂到非元分支,直到遇到gap layer,在gap layer中,将预先保留的meta-feature和来自非元分支对应的特征输入到Mem-RM中。

图3:Mem-RM结构,在meta-branch中的gap layer,Mem-RM使用Mem-FP中存的meta-feature实现memory recall。

输入包含两个部分:顶部输入\(X_3\)的meta-feature(A),底部输入\(X_2\)或者\(X_1\)的feature map,为了使A,B对齐,裁剪了meta-feature的中间区域,放大到和B相同的大小,然后拼接起来一起卷积。

\[\hat{B} = f(cat(B, up(crop(A)))) \]

\(\hat{B}\)表示Mem-RM最后的输出,\(f()\)是非线性转换函数,\(cat(), up(), crop()\)分别对应拼接,上采样,和裁剪操作。

🎨Meta Fusion Module

\(X_1, X_2\)捕捉到了\(X_3\)的特征,所以只需要融合\(X_1, X_2\)就好了,最常见的做法是用几十个精心设计的卷积层,通过优化器如SGD调整网络参数,但是这个过程要很多次迭代才能收敛。

负梯度包含预测值和真实值差异信息,借鉴负梯度理论,构造Meta-FM直接预测卷积层权值。Meta-FM两个分支输出的负梯度作为元信息,通过两个全连接层输出预测的权值。

\[W = [W_1, W_2] = f(\sigma) \]

\(W_1, W_2\)表示两个卷积层的参数,注意\(W_1, W_2\)应该被调整为权值矩阵,因为全连接层的输出是一个向量,\(f()\)是一个非线性函数,包含FC-ReLU-FC结构。\(\sigma\)是两个branch输出的梯度向量,

\[\sigma = cat(v(-\frac{\partial L(S_1, Y_1)}{\partial W_{o1}}), v(-\frac{\partial L(S_2', Y_1)}{\partial W_{o2}})) \]

L是loss函数,\(S_1\)\(X_1\)的预分割结果,\(Y_1\)是ground truth,\(S_2' = up(crop(S_2))\)\(W_{o1}, W_{o2}\)\(X_1, X_2\)分支输出层的权重,\(v()\)操作是将梯度矩阵转为列向量。

2.2 Loss Function

用交叉熵作为损失函数,

\[L(P, Y) = -\sum_i^N \sum_{j \in {P_i}} Y_{i, j} \; log P_{i, j} \]

P是预测的分割结果,Y记为对应的ground truth,N是样本总数,j表示\(P_i\)的第j个像素,这个损失函数用作多个分割结果,\(S_1, S_2, S_3\)以及融合结果S,来训练模型。

2.3 训练

采用3步训练MSN,

  1. 训练mainbody中的meta-branch获得元参数共享到其他分支中。

  2. 训练non-meta-branch的Mem-RM修补知识差距。

  3. Meta-FM学习多种分辨率的分割结果。

训练数据分为2部分,子训练集很小,训练集用于第1步,子训练集用于2,3步

训练meta-branch

低分辨率\(X_3\)喂进去获得分割结果\(S_3\),branch的参数用交叉熵\(S_3, Y_3\)

训练Mem-RM

在第1步之后,有了元参数,固定住;用子训练集训练Mem-RM减轻gap layer对meta branch的影响,将\(X_1, X_2\)输入到固定的元分支,用他们特定的Mem-RM得到对应的分割结果\(S_1, S_2\),然后用交叉熵\(L(S_1, Y_1)和L(S_2, Y_2)\)更新每个Mem-RM(对应\(X_1和X_2\)分支)。

训练Meta-FM

首先固定好训练的meta-branch和Mem-RM,用步骤2中的方式获得分割结果\(S_1, S_2\),在前面提到的操作之后(裁剪和拼接),喂到融合层,融合层的权重是Meta-FM生成的,最后获得融合结果S,由于填充到融合层之前对权重向量的reshape是可微的,因此可以在子训练集上最小化loss \(L(S, Y_1)\)来调整Meta-FM的参数。

三、实验

在BACH和ISIC上评估,指标是平均交并比(mIoU)和模型参数量

3.1 数据集

BACH有10张WSI,多种不同分辨率(1倍,4倍,16倍),4种类别,正常,良性,原位癌,浸润癌(后两种是恶性),将10张WSI切分为7:1:2对应训练集,子训练集和测试集。

ISIC包含2596张图片,包含两类标注,病变和正常,随机切分为训练(2077),子训练(360)和测试(157).

3.2 实现细节

用BiSeNet作为backbone,用patch 为256x256喂到MSN中,从左到右裁剪\(X_1\),除了最后一层的patch外没有重叠,对目标区域的中心对齐,裁剪出\(X_2,X_3\),如果裁剪patch超过了边界,那么填充0。对于BACH,使用专业工具OpenSlide读取WSI,最后对每个分辨率,获得训练,子训练和测试9520, 2379, 75603个patch。

ISIC,三种分辨率,4倍,2倍,1倍,原图有最大的分辨率,每个分辨率的patch的数量分别为21471(训练), 3001(子训练), 52166(测试),batch size为32, 元分支训练30个epoch,非元分支和Meta-FM只训练10个epoch,优化器Adam,初始学习率1e-4。

3.3 与SOTA比较

在BACH的结果

与5种方法比较,Unet,PSPNet,BiSeNet,DeepLab V3+, AWMF-CNN(前4种是通用语义分割任务,后1个是处理超分辨率医疗图的多分支方法)。

前4种方法训练30个epoch,AWMF-CNN两种方式:一是:先在3个分支上训10个epoch,然后训练融合部分,此后多分辨率分支和融合分支交替训练20个epoch。二是:固定训练分支,在融合分支上训30个epoch。

通过权值共享机制,非元分支提升了10%的mIoU,并且参数量和单网络差不多

在ISIC的结果

快速实现又不损失一般性,与最新的三种方法做比对

可视化

图4:BACH和ISIC数据集的可视化,第一行包括图像patch和对应的标签,第2行是backbone BiSeNet结果,三种分支分开训练,第3行是MSN的结果,前两列是非元分支,第3列是融合结果。

由于特殊的权值共享机制,非元分支的输出比BiSeNet的所有输出都好。

3.4 消融实验

权值共享机制的有效性

不仅能减少多分支结构的参数,也能实现分支之间的知识迁移,非元分支的结果能够由元分支提高。

比较了4种方法:

  • Meta-branch:只用meta-branch获得结果,没有修复gap layer
  • Multi-branch:所有分支都是单独训练
  • \(\mathrm{MSN}^{\dagger}\)
  • \(\mathrm { MSN }^{\dagger *}\)

所有的backbone都是BiSeNet,为了公平比较,还对比了前两种方法的融合机制

表4:gap 层权值共享的影响,on non-gap指的是只添加Mem-FP和Mem-RM到non-gap层。

权值共享机制能够消除元分支和非元分支的差距,用已有知识提高表现,融合结果也是以分支为基础的,因此最终的表现也会提升。

其次比较了非元分支和多分支的收敛性,如图5

不仅在分割上表现好,收敛得也更快,用训练集而不是子训练集能够取得更好的效果。

最后,为了探究gap 层的影响,尝试只在non-gap层添加Mem-RM,结果掉了一大半,证明修补元分支和非元分支是必要的。

Meta-fusion的有效性

为了验证meta-fusion的有效性,比较了三种方法:

  • w/o meta:没有meta融合机制,是相同的融合结构,端到端丛零开始训练
  • AWMF-CNN:AWMF-CNN的融合机制,介绍了一种重要性加权的网络,分支的权重不同,卷积层相同,从零开始训练。
  • MSN:将BiSeNet作为backbone,固定三种分辨率,在子训练集上对比结果

能看出融合机制是有效的,AWMF-CNN的融合结构很复杂,参数量很大。

进一步在BACH和ISIC说明融合的重要性,如料想的那样很快就收敛了,几乎只用了1个epoch。