Hyperparameters: epoch = 200, batch_size = 24 (as in the original paper). If you do not have enough GPU memory, you can reduce batch_size to 12 or 6 to save memory. A quick usage sketch follows the full model definition below.
import torch
from torch import nn
from timm.models.layers import trunc_normal_  # truncated-normal init used by _init_weights

# PatchEmbed, BasicLayer, PatchMerging, PatchExpand, BasicLayer_up and
# FinalPatchExpand_X4 are the Swin-Unet building blocks defined earlier.
# parameter names / keywords the optimizer should exclude from weight decay
def no_weight_decay():
    return {'absolute_pos_embed'}

def no_weight_decay_keywords():
    return {'relative_position_bias_table'}
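These two sets are what the training script feeds to the optimizer so that the position embedding and the relative position bias table are excluded from weight decay. Below is a minimal sketch of that skip-list pattern; build_param_groups is an illustrative helper name of my own, not part of the original code.

def build_param_groups(model):
    skip = no_weight_decay()
    skip_keywords = no_weight_decay_keywords()
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # 1-D params (norm/bias), the pos embed, and the bias table get no weight decay
        if p.ndim == 1 or name.endswith('.bias') or name in skip \
                or any(k in name for k in skip_keywords):
            no_decay.append(p)
        else:
            decay.append(p)
    return [{'params': decay},
            {'params': no_decay, 'weight_decay': 0.}]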
# Linear layers: truncated-normal weights and zero bias; LayerNorm: unit weight, zero bias
def _init_weights(m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
class Swin_Unet(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, num_classes,
                 embed_dim, depths, num_heads,
                 window_size, mlp_ratio, qkv_bias, qk_scale,
                 drop_rate, attn_drop_rate, drop_path_rate, ape=False):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.num_features_up = int(embed_dim * 2)
self.mlp_ratio = mlp_ratio
        # patch partition and linear embedding:
        # split the image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
norm_layer=nn.LayerNorm)
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
        # absolute position embedding (disabled by default, as in the original paper)
        self.ape = ape
        if self.ape:
            num_patches = self.patch_embed.num_patches
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
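        # dpr assigns each block a linearly increasing drop-path rate, from 0 up to drop_path_rate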
        # build the encoder stages and the bottleneck; each BasicLayer contains
        # two Swin Transformer blocks and one downsampling step
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=nn.LayerNorm,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None)  # the bottleneck has no downsampling
self.layers.append(layer)
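        # stage i works on dim embed_dim * 2**i at resolution patches_resolution / 2**i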
        # build the decoder layers, one stage per encoder stage
self.layers_up = nn.ModuleList()
self.concat_back_dim = nn.ModuleList()
for i_layer in range(self.num_layers):
concat_linear = nn.Linear(2 * int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
int(embed_dim * 2 ** (
self.num_layers - 1 - i_layer))) if i_layer > 0 else nn.Identity()
            if i_layer == 0:
                # the first decoder step only upsamples the bottleneck output
                layer_up = PatchExpand(
                    input_resolution=(patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
                                      patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))),
                    dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), dim_scale=2,
                    norm_layer=nn.LayerNorm)
else:
layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
input_resolution=(
patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))),
depth=depths[(self.num_layers - 1 - i_layer)],
num_heads=num_heads[(self.num_layers - 1 - i_layer)],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:(self.num_layers - 1 - i_layer)]):sum(
depths[:(self.num_layers - 1 - i_layer) + 1])],
norm_layer=nn.LayerNorm,
upsample=PatchExpand if (i_layer < self.num_layers - 1) else None)
self.layers_up.append(layer_up)
self.concat_back_dim.append(concat_linear)
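        # concat_back_dim projects the 2C channels after each skip concatenation back to C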
self.norm = nn.LayerNorm(self.num_features)
self.norm_up = nn.LayerNorm(self.embed_dim)
        # expand patches 4x back to the full input resolution, then a 1x1 conv
        # maps embed_dim channels to per-pixel class scores
        self.up = FinalPatchExpand_X4(input_resolution=(img_size // patch_size, img_size // patch_size),
                                      dim_scale=4, dim=embed_dim)
        self.output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False)
self.apply(_init_weights)
# Encoder and Bottleneck
def forward_features(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
x_down_sample = []
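        # cache the features entering each encoder stage; the decoder reuses them as skips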
for layer in self.layers:
x_down_sample.append(x)
x = layer(x)
x = self.norm(x) # B L C
return x, x_down_sample
    # Decoder and skip connections
    def forward_up_features(self, x, x_down_sample):
        for inx, layer_up in enumerate(self.layers_up):
            if inx == 0:
                x = layer_up(x)
            else:
                # skip connection: concatenate the matching encoder features
                # and project the doubled channels back down
                x = torch.cat([x, x_down_sample[self.num_layers - 1 - inx]], -1)
                x = self.concat_back_dim[inx](x)
                x = layer_up(x)
x = self.norm_up(x) # B L C
return x
def up_x4(self, x):
H, W = self.patches_resolution
B, L, C = x.shape
        assert L == H * W, "input features have the wrong size"
x = self.up(x)
x = x.view(B, 4 * H, 4 * W, -1)
x = x.permute(0, 3, 1, 2) # B,C,H,W
x = self.output(x)
return x
def forward(self, x):
x, x_down_sample = self.forward_features(x)
x = self.forward_up_features(x, x_down_sample)
x = self.up_x4(x)
return x
def flops(self):
flops = 0
flops += self.patch_embed.flops()
        for layer in self.layers:
            flops += layer.flops()
        # final norm + head terms, kept from the original Swin flops estimate
        # (the decoder stages are not counted here)
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
return flops
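As a quick sanity check, the sketch below instantiates the network and runs a dummy batch through it. The configuration is my assumption of a typical Swin-T-style setup (embed_dim=96, depths=[2, 2, 2, 2], num_heads=[3, 6, 12, 24], window_size=7 on 224x224 inputs, with 9 classes as in the paper's Synapse experiments); adjust it to your own data.

model = Swin_Unet(img_size=224, patch_size=4, in_channels=3, num_classes=9,
                  embed_dim=96, depths=[2, 2, 2, 2], num_heads=[3, 6, 12, 24],
                  window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1)
x = torch.randn(2, 3, 224, 224)  # dummy batch: (B, C, H, W)
out = model(x)                   # expected shape: (2, 9, 224, 224)
print(out.shape)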