Author: Sijin Yu
[1] Ali Hatamizadeh, Vishwesh Nath, Yucheng Tang, Dong Yang, Holger R. Roth, and Daguang Xu. Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images. MICCAI, 2022.
# [Paper Notes] Swin UNETR: Semantic Segmentation of Brain Tumors in MRI Images

Contents:

1. Abstract
2. Motivation & Contribution
   - 2.1 Motivation
   - 2.2 Contribution
3. Model
4. Experiment
   - 4.1 Dataset
   - 4.2 Comparison Experiments
5. Code
   - 5.1 Data Preprocessing and Augmentation
   - 5.2 Swin UNETR Model Architecture
     - 5.2.1 SwinTransformer
     - 5.2.2 UnetrBasicBlock
     - 5.2.3 UnetrUpBlock
     - 5.2.4 UnetOutBlock
## 1. Abstract

Semantic segmentation of brain tumors is a fundamental medical image analysis task involving multiple MRI modalities that can assist clinicians in diagnosing patients and subsequently studying the progression of malignant entities. In recent years, Fully Convolutional Neural Networks (FCNNs) have become the de facto standard for 3D medical image segmentation. The popular "U-shaped" architecture has achieved state-of-the-art benchmarks on a range of 2D and 3D semantic segmentation tasks and imaging modalities. However, because of the limited kernel size of the convolutional layers in FCNNs, their ability to model long-range information is sub-optimal, which can lead to deficiencies when segmenting tumors of varying sizes. Transformer models, on the other hand, have demonstrated an exceptional ability to capture long-range information in several domains, including natural language processing and computer vision. Inspired by the success of ViT and its variants, we propose a novel segmentation model termed Swin UNEt TRansformers (Swin UNETR). Specifically, the task of 3D brain tumor semantic segmentation is reformulated as a sequence-to-sequence prediction problem, in which multi-modal input data are projected into a 1D sequence of embeddings and used as input to a hierarchical Swin Transformer encoder. The Swin Transformer encoder computes self-attention with shifted windows, extracts features at five different resolutions, and is connected to an FCNN-based decoder at each resolution via skip connections. We participated in the BraTS 2021 segmentation challenge, and the proposed model ranked among the top-performing methods in the validation phase.
## 2. Motivation & Contribution

### 2.1 Motivation

- In AI for healthcare, and in brain tumor analysis in particular, more advanced segmentation techniques are needed to accurately delineate tumors for diagnosis and pre-operative planning.
- Current CNN-based brain tumor segmentation methods struggle to capture long-range dependencies because of their small receptive fields.
- ViTs have shown promise in capturing long-range information across various domains, suggesting they could also improve medical image segmentation.

### 2.2 Contribution

- A novel architecture, Swin UNEt TRansformers (Swin UNETR), is proposed, which couples a Swin Transformer encoder with a U-shaped CNN decoder for multi-modal 3D brain tumor segmentation.
- The effectiveness of Swin UNETR is demonstrated on the 2021 Multimodal Brain Tumor Segmentation Challenge (BraTS), where it achieves a top ranking in the validation phase and competitive performance in testing.

## 3. Model

- The input image is split into patches. For an input volume $X \in \mathbb{R}^{H \times W \times D \times S}$ (with $S$ MRI modalities), the image becomes a sequence of 3D tokens of length $\lceil H/H' \rceil \times \lceil W/W' \rceil \times \lceil D/D' \rceil$, where $(H', W', D')$ is the patch size. In this paper the patch size is $2 \times 2 \times 2$, and each patch is linearly projected into an embedding space of dimension $C = 48$.
- A Swin Transformer is applied to the 3D tokens. One Swin Transformer block consists of two sub-layers, one using W-MSA and one using SW-MSA. After each Swin Transformer stage, a patch-merging layer halves the token resolution along every spatial direction (i.e., it becomes $1/2$ of the original).
- W-MSA and SW-MSA are window-based multi-head self-attention with regular and cyclically shifted window partitioning, respectively; a minimal sketch of the patching and the cyclic shift is given below.
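To make the patching and shifted-window ideas concrete, here is a minimal, self-contained sketch. It is not MONAI's implementation; the 4-modality 128³ volume, the 2³ patch size, the 7³ window and $C = 48$ are illustrative values consistent with the paper's configuration:

```python
import torch

# Illustrative sizes (assumptions): a 4-modality 128^3 MRI volume,
# 2x2x2 patches and an embedding dimension of C = 48.
B, S, H, W, D = 1, 4, 128, 128, 128
patch, C = 2, 48

x = torch.randn(B, S, H, W, D)

# Patch embedding can be realized as a strided 3D convolution:
# each non-overlapping 2x2x2 (x4-channel) patch is linearly projected to a C-dim token.
patch_embed = torch.nn.Conv3d(S, C, kernel_size=patch, stride=patch)
tokens = patch_embed(x)                         # (1, 48, 64, 64, 64)
num_tokens = tokens.shape[2] * tokens.shape[3] * tokens.shape[4]
print(tokens.shape, num_tokens)                 # sequence length = (H/2)*(W/2)*(D/2) = 262144

# SW-MSA: the cyclic shift that precedes window partitioning can be done with torch.roll,
# here shifting by half of a 7x7x7 window along each spatial axis.
shift = 7 // 2
shifted = torch.roll(tokens, shifts=(-shift, -shift, -shift), dims=(2, 3, 4))
print(shifted.shape)                            # same shape; windows now straddle the old window borders
```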

## 4. Experiment

### 4.1 Dataset

BraTS 2021: multi-modal MRI with four modalities (T1, T1Gd, T2, FLAIR); the evaluated tumor sub-regions are the enhancing tumor (ET), tumor core (TC), and whole tumor (WT).

## 5. Code

The MONAI tutorials repository provides a tutorial on BraTS21 brain tumor segmentation with the Swin UNETR model.
Below are annotated excerpts of the core code (the data transforms from the tutorial and the model definition from MONAI's `SwinUNETR` implementation):
### 5.1 Data Preprocessing and Augmentation

```python
from monai import transforms

roi = (128, 128, 128)  # random-crop ROI size (the value used in the MONAI BraTS21 tutorial)

train_transform = transforms.Compose(
    [
        # Load image and label
        transforms.LoadImaged(keys=["image", "label"]),
        # Convert the single-channel label image into a multi-channel format, one channel per
        # tumor class (before conversion, all class labels share a single-channel image)
        transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        # Crop away the background region around the image
        transforms.CropForegroundd(
            keys=["image", "label"],
            source_key="image",
            k_divisible=[roi[0], roi[1], roi[2]],
        ),
        # Randomly crop the image to the specified size
        transforms.RandSpatialCropd(
            keys=["image", "label"],
            roi_size=[roi[0], roi[1], roi[2]],
            random_size=False,
        ),
        # Random flips along axes 0, 1 and 2
        transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        # Per-channel intensity normalization, ignoring zero (background) values
        transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        # Randomly scale the intensity: img = img * (1 + eps)
        transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
        # Randomly shift the intensity: img = img + eps
        transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
    ]
)

val_transform = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image", "label"]),
        transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)
```
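As a hedged sketch of how these transforms would be used, the sample below wraps them in a MONAI `Dataset`/`DataLoader`. The file names are placeholders (in the tutorial, `"image"` is a list of the four modality files, which `LoadImaged` stacks into channels):

```python
from monai.data import DataLoader, Dataset

# Hypothetical file layout; replace the placeholder paths with real BraTS21 files.
train_files = [
    {
        "image": [
            "BraTS2021_00000_flair.nii.gz",
            "BraTS2021_00000_t1ce.nii.gz",
            "BraTS2021_00000_t1.nii.gz",
            "BraTS2021_00000_t2.nii.gz",
        ],
        "label": "BraTS2021_00000_seg.nii.gz",
    },
]

train_ds = Dataset(data=train_files, transform=train_transform)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=2)

batch = next(iter(train_loader))
print(batch["image"].shape, batch["label"].shape)  # e.g. (1, 4, 128, 128, 128), (1, 3, 128, 128, 128)
```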
### 5.2 Swin UNETR Model Architecture

The `forward` method of `SwinUNETR` (from MONAI):

```python
def forward(self, x_in):
    if not torch.jit.is_scripting():
        self._check_input_size(x_in.shape[2:])
    # Swin Transformer encoder: five feature maps at decreasing resolutions
    hidden_states_out = self.swinViT(x_in, self.normalize)
    # Convolutional blocks applied to the input and to the Swin features (skip connections)
    enc0 = self.encoder1(x_in)
    enc1 = self.encoder2(hidden_states_out[0])
    enc2 = self.encoder3(hidden_states_out[1])
    enc3 = self.encoder4(hidden_states_out[2])
    dec4 = self.encoder10(hidden_states_out[4])
    # Decoder: upsample step by step, concatenating the matching skip connection
    dec3 = self.decoder5(dec4, hidden_states_out[3])
    dec2 = self.decoder4(dec3, enc3)
    dec1 = self.decoder3(dec2, enc2)
    dec0 = self.decoder2(dec1, enc1)
    out = self.decoder1(dec0, enc0)
    logits = self.out(out)
    return logits
```
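For context, a minimal sketch of running the full model end to end, assuming a MONAI 1.x version whose `SwinUNETR` constructor still accepts `img_size`, and the BraTS configuration of 4 input modalities, 3 output channels and `feature_size=48`:

```python
import torch
from monai.networks.nets import SwinUNETR

# Assumed configuration: 4 input modalities, 3 output channels (TC/WT/ET),
# feature_size=48 as in the BraTS21 tutorial.
model = SwinUNETR(
    img_size=(128, 128, 128),
    in_channels=4,
    out_channels=3,
    feature_size=48,
    use_checkpoint=True,
)

x = torch.randn(1, 4, 128, 128, 128)  # a dummy 4-modality MRI crop
with torch.no_grad():
    logits = model(x)
print(logits.shape)                    # torch.Size([1, 3, 128, 128, 128])
```

The logits keep the input resolution; in the tutorial, the three output channels correspond to TC, WT and ET and are trained with a sigmoid-based Dice loss.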
These components are defined in `__init__` as follows:

```python
self.normalize = normalize

self.swinViT = SwinTransformer(
    in_chans=in_channels,
    embed_dim=feature_size,
    window_size=window_size,
    patch_size=patch_sizes,
    depths=depths,
    num_heads=num_heads,
    mlp_ratio=4.0,
    qkv_bias=True,
    drop_rate=drop_rate,
    attn_drop_rate=attn_drop_rate,
    drop_path_rate=dropout_path_rate,
    norm_layer=nn.LayerNorm,
    use_checkpoint=use_checkpoint,
    spatial_dims=spatial_dims,
    downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample,
    use_v2=use_v2,
)

self.encoder1 = UnetrBasicBlock(
    spatial_dims=spatial_dims,
    in_channels=in_channels,
    out_channels=feature_size,
    kernel_size=3,
    stride=1,
    norm_name=norm_name,
    res_block=True,
)

self.encoder2 = UnetrBasicBlock(
    spatial_dims=spatial_dims,
    in_channels=feature_size,
    out_channels=feature_size,
    kernel_size=3,
    stride=1,
    norm_name=norm_name,
    res_block=True,
)

self.encoder3 = UnetrBasicBlock(
    spatial_dims=spatial_dims,
    in_channels=2 * feature_size,
    out_channels=2 * feature_size,
    kernel_size=3,
    stride=1,
    norm_name=norm_name,
    res_block=True,
)

self.encoder4 = UnetrBasicBlock(
    spatial_dims=spatial_dims,
    in_channels=4 * feature_size,
    out_channels=4 * feature_size,
    kernel_size=3,
    stride=1,
    norm_name=norm_name,
    res_block=True,
)

self.encoder10 = UnetrBasicBlock(
    spatial_dims=spatial_dims,
    in_channels=16 * feature_size,
    out_channels=16 * feature_size,
    kernel_size=3,
    stride=1,
    norm_name=norm_name,
    res_block=True,
)

self.decoder5 = UnetrUpBlock(
    spatial_dims=spatial_dims,
    in_channels=16 * feature_size,
    out_channels=8 * feature_size,
    kernel_size=3,
    upsample_kernel_size=2,
    norm_name=norm_name,
    res_block=True,
)

self.decoder4 = UnetrUpBlock(
    spatial_dims=spatial_dims,
    in_channels=feature_size * 8,
    out_channels=feature_size * 4,
    kernel_size=3,
    upsample_kernel_size=2,
    norm_name=norm_name,
    res_block=True,
)

self.decoder3 = UnetrUpBlock(
    spatial_dims=spatial_dims,
    in_channels=feature_size * 4,
    out_channels=feature_size * 2,
    kernel_size=3,
    upsample_kernel_size=2,
    norm_name=norm_name,
    res_block=True,
)

self.decoder2 = UnetrUpBlock(
    spatial_dims=spatial_dims,
    in_channels=feature_size * 2,
    out_channels=feature_size,
    kernel_size=3,
    upsample_kernel_size=2,
    norm_name=norm_name,
    res_block=True,
)

self.decoder1 = UnetrUpBlock(
    spatial_dims=spatial_dims,
    in_channels=feature_size,
    out_channels=feature_size,
    kernel_size=3,
    upsample_kernel_size=2,
    norm_name=norm_name,
    res_block=True,
)

self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)
```
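For orientation, with `feature_size = 48` and a 128³ input (the tutorial configuration): `encoder1` maps the raw 4-channel input to 48 channels at 128³; `encoder2`, `encoder3` and `encoder4` refine `hidden_states_out[0..2]` (48, 96 and 192 channels at 64³, 32³ and 16³); and `encoder10` refines the bottleneck `hidden_states_out[4]` (768 channels at 4³). Each `UnetrUpBlock` then doubles the spatial size while halving the channel count, concatenating the matching skip (`hidden_states_out[3]`, then `enc3`, `enc2`, `enc1`, `enc0`), and `self.out` maps the final 48 channels to `out_channels` with a 1×1×1 convolution.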
#### 5.2.1 SwinTransformer

```python
class SwinTransformer(nn.Module):
    """
    Swin Transformer based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self,
        in_chans: int,
        embed_dim: int,
        window_size: Sequence[int],
        patch_size: Sequence[int],
        depths: Sequence[int],
        num_heads: Sequence[int],
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: type[LayerNorm] = nn.LayerNorm,
        patch_norm: bool = False,
        use_checkpoint: bool = False,
        spatial_dims: int = 3,
        downsample="merging",
        use_v2=False,
    ) -> None:
        """
        Args:
            in_chans: dimension of input channels.
            embed_dim: number of linear projection output channels.
            window_size: local window size.
            patch_size: patch size.
            depths: number of layers in each stage.
            num_heads: number of attention heads.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            drop_path_rate: stochastic depth rate.
            norm_layer: normalization layer.
            patch_norm: add normalization after patch embedding.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
            spatial_dims: spatial dimension.
            downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
                user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
                The default is currently `"merging"` (the original version defined in v0.9.0).
            use_v2: using swinunetr_v2, which adds a residual convolution block at the beginning of each swin stage.
        """
        super().__init__()
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.window_size = window_size
        self.patch_size = patch_size
        self.patch_embed = PatchEmbed(
            patch_size=self.patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,  # type: ignore
            spatial_dims=spatial_dims,
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.use_v2 = use_v2
        self.layers1 = nn.ModuleList()
        self.layers2 = nn.ModuleList()
        self.layers3 = nn.ModuleList()
        self.layers4 = nn.ModuleList()
        if self.use_v2:
            self.layers1c = nn.ModuleList()
            self.layers2c = nn.ModuleList()
            self.layers3c = nn.ModuleList()
            self.layers4c = nn.ModuleList()
        down_sample_mod = look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2**i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=self.window_size,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                norm_layer=norm_layer,
                downsample=down_sample_mod,
                use_checkpoint=use_checkpoint,
            )
            if i_layer == 0:
                self.layers1.append(layer)
            elif i_layer == 1:
                self.layers2.append(layer)
            elif i_layer == 2:
                self.layers3.append(layer)
            elif i_layer == 3:
                self.layers4.append(layer)
            if self.use_v2:
                layerc = UnetrBasicBlock(
                    spatial_dims=3,
                    in_channels=embed_dim * 2**i_layer,
                    out_channels=embed_dim * 2**i_layer,
                    kernel_size=3,
                    stride=1,
                    norm_name="instance",
                    res_block=True,
                )
                if i_layer == 0:
                    self.layers1c.append(layerc)
                elif i_layer == 1:
                    self.layers2c.append(layerc)
                elif i_layer == 2:
                    self.layers3c.append(layerc)
                elif i_layer == 3:
                    self.layers4c.append(layerc)
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))

    def proj_out(self, x, normalize=False):
        if normalize:
            x_shape = x.size()
            if len(x_shape) == 5:
                n, ch, d, h, w = x_shape
                x = rearrange(x, "n c d h w -> n d h w c")
                x = F.layer_norm(x, [ch])
                x = rearrange(x, "n d h w c -> n c d h w")
            elif len(x_shape) == 4:
                n, ch, h, w = x_shape
                x = rearrange(x, "n c h w -> n h w c")
                x = F.layer_norm(x, [ch])
                x = rearrange(x, "n h w c -> n c h w")
        return x

    def forward(self, x, normalize=True):
        x0 = self.patch_embed(x)
        x0 = self.pos_drop(x0)
        x0_out = self.proj_out(x0, normalize)
        if self.use_v2:
            x0 = self.layers1c[0](x0.contiguous())
        x1 = self.layers1[0](x0.contiguous())
        x1_out = self.proj_out(x1, normalize)
        if self.use_v2:
            x1 = self.layers2c[0](x1.contiguous())
        x2 = self.layers2[0](x1.contiguous())
        x2_out = self.proj_out(x2, normalize)
        if self.use_v2:
            x2 = self.layers3c[0](x2.contiguous())
        x3 = self.layers3[0](x2.contiguous())
        x3_out = self.proj_out(x3, normalize)
        if self.use_v2:
            x3 = self.layers4c[0](x3.contiguous())
        x4 = self.layers4[0](x3.contiguous())
        x4_out = self.proj_out(x4, normalize)
        return [x0_out, x1_out, x2_out, x3_out, x4_out]
```
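The `proj_out` helper applies a channel-wise LayerNorm to each multi-resolution output before it is handed to the decoder. A minimal equivalent sketch (it uses `einops`, which MONAI's Swin UNETR implementation also relies on; the 96-channel, 32³ tensor is an arbitrary example):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

# A hypothetical stage output: batch 1, 96 channels, 32^3 feature map.
x = torch.randn(1, 96, 32, 32, 32)

# proj_out(x, normalize=True) is equivalent to: move channels last,
# layer-normalize over the channel dimension, then move channels back.
y = rearrange(x, "n c d h w -> n d h w c")
y = F.layer_norm(y, [x.shape[1]])
y = rearrange(y, "n d h w c -> n c d h w")
print(y.shape)  # torch.Size([1, 96, 32, 32, 32]); per-voxel features now have zero mean and unit variance
```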
#### 5.2.2 UnetrBasicBlock

```python
class UnetrBasicBlock(nn.Module):
    """
    A CNN module that can be used for UNETR, based on: "Hatamizadeh et al.,
    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Sequence[int] | int,
        stride: Sequence[int] | int,
        norm_name: tuple | str,
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            stride: convolution stride.
            norm_name: feature normalization type and arguments.
            res_block: bool argument to determine if residual block is used.
        """
        super().__init__()

        if res_block:
            self.layer = UnetResBlock(
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                norm_name=norm_name,
            )
        else:
            self.layer = UnetBasicBlock(  # type: ignore
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                norm_name=norm_name,
            )

    def forward(self, inp):
        return self.layer(inp)
```
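A quick usage sketch of `UnetrBasicBlock` as it is configured above (assuming it can be imported from `monai.networks.blocks`, as in recent MONAI releases; the 48-channel, 64³ input is an arbitrary example):

```python
import torch
from monai.networks.blocks import UnetrBasicBlock  # assumed export path

# A residual conv block as used for the skip-connection encoders (48 -> 48 channels, stride 1).
block = UnetrBasicBlock(
    spatial_dims=3,
    in_channels=48,
    out_channels=48,
    kernel_size=3,
    stride=1,
    norm_name="instance",
    res_block=True,
)
x = torch.randn(1, 48, 64, 64, 64)
print(block(x).shape)  # torch.Size([1, 48, 64, 64, 64]); spatial size is preserved
```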
#### 5.2.3 UnetrUpBlock

```python
class UnetrUpBlock(nn.Module):
    """
    An upsampling module that can be used for UNETR: "Hatamizadeh et al.,
    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Sequence[int] | int,
        upsample_kernel_size: Sequence[int] | int,
        norm_name: tuple | str,
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            upsample_kernel_size: convolution kernel size for transposed convolution layers.
            norm_name: feature normalization type and arguments.
            res_block: bool argument to determine if residual block is used.
        """
        super().__init__()
        upsample_stride = upsample_kernel_size
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            conv_only=True,
            is_transposed=True,
        )

        if res_block:
            self.conv_block = UnetResBlock(
                spatial_dims,
                out_channels + out_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=1,
                norm_name=norm_name,
            )
        else:
            self.conv_block = UnetBasicBlock(  # type: ignore
                spatial_dims,
                out_channels + out_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=1,
                norm_name=norm_name,
            )

    def forward(self, inp, skip):
        # the number of channels of `skip` should equal out_channels
        out = self.transp_conv(inp)
        out = torch.cat((out, skip), dim=1)
        out = self.conv_block(out)
        return out
```
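Likewise, a sketch of one `UnetrUpBlock` decoder step under the same import assumption; the channel counts mirror `decoder2` (96 to 48), and the skip tensor has `out_channels` channels at twice the spatial size:

```python
import torch
from monai.networks.blocks import UnetrUpBlock  # assumed export path

# One decoder step: upsample 2x via transposed convolution, concatenate the skip, then convolve.
up = UnetrUpBlock(
    spatial_dims=3,
    in_channels=96,
    out_channels=48,
    kernel_size=3,
    upsample_kernel_size=2,
    norm_name="instance",
    res_block=True,
)
low_res = torch.randn(1, 96, 32, 32, 32)  # coarser decoder feature
skip = torch.randn(1, 48, 64, 64, 64)     # matching encoder feature (channels == out_channels)
print(up(low_res, skip).shape)            # torch.Size([1, 48, 64, 64, 64])
```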
#### 5.2.4 UnetOutBlock

```python
class UnetOutBlock(nn.Module):
    def __init__(
        self, spatial_dims: int, in_channels: int, out_channels: int, dropout: tuple | str | float | None = None
    ):
        super().__init__()
        self.conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            dropout=dropout,
            bias=True,
            act=None,
            norm=None,
            conv_only=False,
        )

    def forward(self, inp):
        return self.conv(inp)
```
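Finally, a sketch of the output head (again assuming `UnetOutBlock` is importable from `monai.networks.blocks`); it is simply a 1×1×1 convolution mapping `feature_size` channels to the number of segmentation channels:

```python
import torch
from monai.networks.blocks import UnetOutBlock  # assumed export path

head = UnetOutBlock(spatial_dims=3, in_channels=48, out_channels=3)
feat = torch.randn(1, 48, 64, 64, 64)  # dummy decoder output (reduced spatial size for a quick check)
print(head(feat).shape)                # torch.Size([1, 3, 64, 64, 64])
```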