import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
import math
import copy
import numpy as np
from scipy import ndimage
from os.path import join as pjoin
import ml_collections
def get_b16_config():
"""Returns the ViT-B/16 configuration."""
config = ml_collections.ConfigDict()
config.patches = ml_collections.ConfigDict({'size': (14, 14)})
config.hidden_size = 168
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 3072
config.transformer.num_heads = 12
config.transformer.num_layers = 12
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1
config.classifier = 'seg'
config.representation_size = None
config.resnet_pretrained_path = None
config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz'
config.patch_size = 16
config.decoder_channels = (256, 128, 64, 16)
config.n_classes = 2
config.activation = 'softmax'
return config
def get_r50_b16_config():
"""Returns the Resnet50 + ViT-B/16 configuration."""
config = get_b16_config()
config.patches.grid = (14, 14)
config.resnet = ml_collections.ConfigDict()
config.resnet.num_layers = (3, 4, 9)
config.resnet.width_factor = 1
config.classifier = 'seg'
config.pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
config.decoder_channels = (256, 128, 64, 16)
config.skip_channels = [512, 256, 64, 16]
config.n_classes = 2
config.n_skip = 3
config.activation = 'softmax'
return config
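# A minimal invariant check (illustrative sketch, not part of the original
# module; the helper name _check_config is hypothetical). Attention splits
# hidden_size evenly across heads: 168 = 12 heads * head size 14.
def _check_config(cfg):
    head_dim = cfg.hidden_size // cfg.transformer.num_heads
    assert head_dim * cfg.transformer.num_heads == cfg.hidden_size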
config = get_r50_b16_config()
logger = logging.getLogger(__name__)
ATTENTION_Q = "MultiHeadDotProductAttention_1/query/"
ATTENTION_K = "MultiHeadDotProductAttention_1/key/"
ATTENTION_V = "MultiHeadDotProductAttention_1/value/"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out/"
FC_0 = "MlpBlock_3/Dense_0/"
FC_1 = "MlpBlock_3/Dense_1/"
ATTENTION_NORM = "LayerNorm_0/"
MLP_NORM = "LayerNorm_2/"
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
def np2th(weights, conv=False):
"""Possibly convert HWIO to OIHW."""
if conv:
weights = weights.transpose([3, 2, 0, 1])
return torch.from_numpy(weights)
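# np2th maps Flax conv kernels stored as HWIO to PyTorch's OIHW layout. A small
# shape check (illustrative sketch, not part of the original module):
def _check_np2th():
    w = np.zeros((3, 3, 16, 32), dtype=np.float32)  # (H, W, I, O)
    assert np2th(w, conv=True).shape == (32, 16, 3, 3)  # (O, I, H, W)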
class Embeddings(nn.Module):
"""Construct the embeddings from patch, position embeddings.
"""
    def __init__(self, config, in_channels=256):
        super(Embeddings, self).__init__()
        self.hybrid = None
        self.config = config
        # 1x1x1 conv projects the 256-channel x5 feature map to hidden_size
        self.patch_embeddings = nn.Conv3d(in_channels=in_channels,
                                          out_channels=config.hidden_size,
                                          kernel_size=1,
                                          stride=1)
        # 216 = 6*6*6 tokens: the flattened 6^3 grid of the deepest branch
        self.position_embeddings = nn.Parameter(torch.zeros(1, 216, config.hidden_size))
self.dropout = nn.Dropout(config.transformer["dropout_rate"])
def forward(self, x):
        x = self.patch_embeddings(x)  # (B, hidden, D, H, W)
        x = x.flatten(2)  # (B, hidden, N) with N = D*H*W
        x = x.transpose(-1, -2)  # (B, N, hidden)
        embeddings = x + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
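# Shape sketch for Embeddings (illustrative; the _demo name is hypothetical):
# a (B, 256, 6, 6, 6) feature volume becomes (B, 216, hidden_size) tokens,
# matching the 216-entry position table above.
def _demo_embeddings():
    emb = Embeddings(config)
    tokens = emb(torch.randn(2, 256, 6, 6, 6))
    assert tokens.shape == (2, 216, config.hidden_size)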
class SelfAttention(nn.Module):
def __init__(self, channels, size):
super(SelfAttention, self).__init__()
self.channels = channels
self.size = size
        self.embedding = Embeddings(config)  # NOTE: constructed but never used in forward
        # batch_first=True so inputs/outputs are (B, N, C); requires torch >= 1.9
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
self.ln = nn.LayerNorm([channels])
self.ff_self = nn.Sequential(
nn.LayerNorm([channels]),
nn.Linear(channels, channels),
nn.GELU(),
nn.Linear(channels, channels),
)
    def forward(self, x):
        # flatten the volume to a token sequence: (B, C, D, H, W) -> (B, N, C)
        x = x.view(-1, self.channels, x.shape[2] * x.shape[3] * x.shape[4]).swapaxes(1, 2)
x_ln = self.ln(x)
attention_value, _ = self.mha(x_ln, x_ln, x_ln)
attention_value = attention_value + x
attention_value = self.ff_self(attention_value) + attention_value
return attention_value.swapaxes(2, 1).reshape(-1, self.channels, self.size, self.size, self.size)
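# SelfAttention preserves the input shape: tokens go through a pre-LN
# multi-head attention block plus a feed-forward residual, then are reshaped
# back into a volume. Smoke-test sketch (illustrative; channels must be
# divisible by the 4 heads and the volume must be size^3):
def _demo_self_attention():
    sa = SelfAttention(32, 8)
    x = torch.randn(1, 32, 8, 8, 8)
    assert sa(x).shape == x.shape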
class Attention(nn.Module):
def __init__(self, config, vis):
super(Attention, self).__init__()
self.vis = vis
self.num_attention_heads = config.transformer["num_heads"]
self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.out = nn.Linear(config.hidden_size, config.hidden_size)
self.attn_dropout = nn.Dropout(config.transformer["attention_dropout_rate"])
self.proj_dropout = nn.Dropout(config.transformer["attention_dropout_rate"])
self.softmax = nn.Softmax(dim=-1)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
weights = attention_probs if self.vis else None
attention_probs = self.attn_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
attention_output = self.out(context_layer)
attention_output = self.proj_dropout(attention_output)
return attention_output, weights
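# Attention is standard multi-head scaled dot-product attention over a
# (B, N, hidden) sequence; with vis=True it also returns the per-head maps.
# Shape sketch (illustrative, not part of the original module):
def _demo_attention():
    attn = Attention(config, vis=True)
    out, w = attn(torch.randn(1, 216, config.hidden_size))
    assert out.shape == (1, 216, config.hidden_size)
    assert w.shape == (1, config.transformer.num_heads, 216, 216)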
class Mlp(nn.Module):
def __init__(self, config):
super(Mlp, self).__init__()
self.fc1 = nn.Linear(config.hidden_size, config.transformer["mlp_dim"])
self.fc2 = nn.Linear(config.transformer["mlp_dim"], config.hidden_size)
self.act_fn = ACT2FN["gelu"]
self.dropout = nn.Dropout(config.transformer["dropout_rate"])
self._init_weights()
def _init_weights(self):
nn.init.xavier_uniform_(self.fc1.weight)
nn.init.xavier_uniform_(self.fc2.weight)
nn.init.normal_(self.fc1.bias, std=1e-6)
nn.init.normal_(self.fc2.bias, std=1e-6)
def forward(self, x):
x = self.fc1(x)
x = self.act_fn(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class Block(nn.Module):
def __init__(self, config, vis):
super(Block, self).__init__()
self.hidden_size = config.hidden_size
self.attention_norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
self.ffn_norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
self.ffn = Mlp(config)
self.attn = Attention(config, vis)
def forward(self, x):
h = x
x = self.attention_norm(x)
x, weights = self.attn(x)
x = x + h
h = x
x = self.ffn_norm(x)
x = self.ffn(x)
x = x + h
return x, weights
def load_from(self, weights, n_block):
ROOT = f"Transformer/encoderblock_{n_block}/"
with torch.no_grad():
query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()
query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)
self.attn.query.weight.copy_(query_weight)
self.attn.key.weight.copy_(key_weight)
self.attn.value.weight.copy_(value_weight)
self.attn.out.weight.copy_(out_weight)
self.attn.query.bias.copy_(query_bias)
self.attn.key.bias.copy_(key_bias)
self.attn.value.bias.copy_(value_bias)
self.attn.out.bias.copy_(out_bias)
mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()
self.ffn.fc1.weight.copy_(mlp_weight_0)
self.ffn.fc2.weight.copy_(mlp_weight_1)
self.ffn.fc1.bias.copy_(mlp_bias_0)
self.ffn.fc2.bias.copy_(mlp_bias_1)
self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
class Encoder(nn.Module):
def __init__(self, config, vis):
super(Encoder, self).__init__()
self.vis = vis
self.layer = nn.ModuleList()
self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
for _ in range(config.transformer["num_layers"]):
layer = Block(config, vis)
self.layer.append(copy.deepcopy(layer))
def forward(self, hidden_states):
attn_weights = []
for layer_block in self.layer:
hidden_states, weights = layer_block(hidden_states)
if self.vis:
attn_weights.append(weights)
encoded = self.encoder_norm(hidden_states)
return encoded, attn_weights
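# The Encoder stacks num_layers pre-norm Blocks and applies a final LayerNorm,
# preserving the token shape. Sketch (illustrative helper):
def _demo_encoder():
    enc = Encoder(config, vis=True)
    encoded, attn = enc(torch.randn(1, 216, config.hidden_size))
    assert encoded.shape == (1, 216, config.hidden_size)
    assert len(attn) == config.transformer.num_layers  # one map per layer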
class Conv3dReLU(nn.Sequential):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
padding=0,
stride=1,
use_batchnorm=True,
):
conv = nn.Conv3d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=not (use_batchnorm),
)
relu = nn.ReLU(inplace=True)
bn = nn.BatchNorm3d(out_channels)
super(Conv3dReLU, self).__init__(conv, bn, relu)
class Detransformer(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
head_channels = 256
self.conv_more = Conv3dReLU(
config.hidden_size,
head_channels,
kernel_size=3,
padding=1,
use_batchnorm=True,
)
def forward(self, hidden_states):
B, n_patch, hidden = hidden_states.size()
        slices = 6  # assumes the tokens came from an (h, w, 6) grid, here 6x6x6
        h, w = int(np.sqrt(n_patch / slices)), int(np.sqrt(n_patch / slices))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w, slices)
x = self.conv_more(x)
return x
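# Detransformer inverts the tokenization: (B, 216, hidden) tokens are reshaped
# to a (B, hidden, 6, 6, 6) volume (216 = 6*6*6) and projected to 256 channels
# by conv_more. Shape sketch (illustrative helper):
def _demo_detransformer():
    det = Detransformer(config)
    x = det(torch.randn(1, 216, config.hidden_size))
    assert x.shape == (1, 256, 6, 6, 6)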
BN_MOMENTUM = 0.1
'''
BasicBlock: conv -> bn -> relu -> conv -> bn -> add residual -> relu
'''
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm3d(planes, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)  # second conv always has stride 1, as in the standard ResNet BasicBlock
self.bn2 = nn.BatchNorm3d(planes, momentum=BN_MOMENTUM)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
'''
Bottleneck: 1x1 conv -> bn -> relu -> 3x3 conv -> bn -> relu -> 1x1 conv -> bn -> add residual -> relu
'''
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm3d(planes, momentum=BN_MOMENTUM)
self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm3d(planes, momentum=BN_MOMENTUM)
self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=1,
bias=False)
self.bn3 = nn.BatchNorm3d(planes * self.expansion,
momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
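# Bottleneck expands planes by expansion=4, so the first block of a stage needs
# a 1x1 projection on the residual path, exactly as layer1 does below. Sketch
# (illustrative helper):
def _demo_bottleneck():
    proj = nn.Sequential(nn.Conv3d(32, 128, kernel_size=1, bias=False),
                         nn.BatchNorm3d(128, momentum=BN_MOMENTUM))
    blk = Bottleneck(32, 32, downsample=proj)
    assert blk(torch.randn(1, 32, 8, 8, 8)).shape == (1, 128, 8, 8, 8)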
class StageModule(nn.Module):
def __init__(self, input_branches, output_branches, c):
"""
构建对应stage,即用来融合不同尺度的实现
:param input_branches: 输入的分支数,每个分支对应一种尺度
:param output_branches: 输出的分支数
:param c: 输入的第一个分支通道数
"""
super().__init__()
self.input_branches = input_branches
self.output_branches = output_branches
        self.branches = nn.ModuleList()  # the per-branch block stacks
        for i in range(self.input_branches):  # each branch first runs through its own BasicBlocks
            w = c * (2 ** i)  # channels of branch i; doubled at every scale
branch = nn.Sequential(
BasicBlock(w, w),
BasicBlock(w, w),
BasicBlock(w, w),
BasicBlock(w, w)
)
            self.branches.append(branch)  # blocks for this branch are ready
        self.fuse_layers = nn.ModuleList()  # fuses the outputs of all branches
for i in range(self.output_branches):
self.fuse_layers.append(nn.ModuleList())
for j in range(self.input_branches):
                if i == j:
                    # same branch in and out: pass through unchanged
                    self.fuse_layers[-1].append(nn.Identity())
                elif i < j:
                    # input branch j is more downsampled than output branch i:
                    # adjust channels with a 1x1 conv, then upsample so the maps can be summed
self.fuse_layers[-1].append(
nn.Sequential(
nn.Conv3d(c * (2 ** j), c * (2 ** i), kernel_size=1, stride=1, bias=False),
nn.BatchNorm3d(c * (2 ** i), momentum=BN_MOMENTUM),
nn.Upsample(scale_factor=2.0 ** (j - i), mode='trilinear', align_corners=True)
)
)
                else:  # i > j
                    # input branch j is less downsampled than output branch i:
                    # adjust channels and downsample so the maps can be summed.
                    # Each 2x downsampling is one stride-2 3x3 conv: 4x needs two,
                    # 8x needs three, i.e. i - j convs in total.
                    ops = []
                    # the first i - j - 1 convs only downsample, keeping channels
                    for k in range(i - j - 1):
ops.append(
nn.Sequential(
nn.Conv3d(c * (2 ** j), c * (2 ** j), kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm3d(c * (2 ** j), momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)
)
)
                    # the last conv both adjusts channels and downsamples
ops.append(
nn.Sequential(
nn.Conv3d(c * (2 ** j), c * (2 ** i), kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm3d(c * (2 ** i), momentum=BN_MOMENTUM)
)
)
self.fuse_layers[-1].append(nn.Sequential(*ops))
self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        # run each branch through its own block stack
        x = [branch(xi) for branch, xi in zip(self.branches, x)]
        # then fuse information across scales
        x_fused = []
        for i in range(len(self.fuse_layers)):
            # output branch i sums the projections (identity, upsample, or
            # downsample) of every input branch j, then applies ReLU
            x_fused.append(
                self.relu(
                    sum([self.fuse_layers[i][j](x[j]) for j in range(len(self.branches))])
                )
            )
        return x_fused
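# StageModule preserves every branch's shape: each branch runs its four
# BasicBlocks, then each output branch sums identity / upsampled / downsampled
# projections of all inputs. Sketch with small volumes (illustrative helper):
def _demo_stage_module():
    sm = StageModule(input_branches=2, output_branches=2, c=32)
    ys = sm([torch.randn(1, 32, 16, 16, 16), torch.randn(1, 64, 8, 8, 8)])
    assert [tuple(y.shape) for y in ys] == [(1, 32, 16, 16, 16), (1, 64, 8, 8, 8)]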
class ConvBlock(nn.Module):
def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
super(ConvBlock, self).__init__()
ops = []
for i in range(n_stages):
if i==0:
input_channel = n_filters_in
else:
input_channel = n_filters_out
ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
if normalization == 'batchnorm':
ops.append(nn.BatchNorm3d(n_filters_out))
elif normalization == 'groupnorm':
ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
elif normalization == 'instancenorm':
ops.append(nn.InstanceNorm3d(n_filters_out))
elif normalization != 'none':
assert False
ops.append(nn.ReLU(inplace=True))
self.conv = nn.Sequential(*ops)
def forward(self, x):
x = self.conv(x)
return x
class UpsamplingDeconvBlock(nn.Module):
def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
super(UpsamplingDeconvBlock, self).__init__()
ops = []
        # kernel_size == stride with padding 0: multiplies each spatial dim by `stride`
        ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            assert False
ops.append(nn.ReLU(inplace=True))
self.conv = nn.Sequential(*ops)
def forward(self, x):
x = self.conv(x)
return x
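# With kernel_size == stride and padding 0, ConvTranspose3d multiplies each
# spatial dim by exactly `stride`, which is how the decoder climbs
# 6 -> 12 -> 24 -> 48 -> 96. Sketch (illustrative helper):
def _demo_deconv_block():
    up = UpsamplingDeconvBlock(256, 128, normalization='none')
    assert up(torch.randn(1, 256, 6, 6, 6)).shape == (1, 128, 12, 12, 12)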
class Upsampling(nn.Module):
def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
super(Upsampling, self).__init__()
ops = []
ops.append(nn.Upsample(scale_factor=stride, mode='trilinear',align_corners=False))
ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1))
if normalization == 'batchnorm':
ops.append(nn.BatchNorm3d(n_filters_out))
elif normalization == 'groupnorm':
ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
elif normalization == 'instancenorm':
ops.append(nn.InstanceNorm3d(n_filters_out))
elif normalization != 'none':
assert False
ops.append(nn.ReLU(inplace=True))
self.conv = nn.Sequential(*ops)
def forward(self, x):
x = self.conv(x)
return x
class Transformer_HighResolutionNet(nn.Module):
def __init__(self, base_channel: int = 32, output_channels: int = 6, n_filters=16, normalization='none', device="cuda", num_classes=6, has_dropout=False, vis=False):
super().__init__()
        '''
        Stem: strided convolutions reduce the input volume (conv1 stride 1,
        conv2 stride 2, i.e. a single 2x downsampling here); layer1 then only
        adjusts the channel count at that resolution. transition1 splits into
        two branches: branch 1 keeps the stem resolution with base_channel
        channels, branch 2 halves the resolution and doubles the channels.
        '''
self.conv1 = nn.Conv3d(1, 32, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm3d(32, momentum=BN_MOMENTUM)
self.conv2 = nn.Conv3d(32, 32, kernel_size=3, stride=2, padding=1, bias=False)
self.bn2 = nn.BatchNorm3d(32, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
        self.sa1 = SelfAttention(128, 24)  # unused in forward (the call below is commented out)
self.device = device
# Stage1
downsample = nn.Sequential(
nn.Conv3d(32, 128, kernel_size=1, stride=1, bias=False),
nn.BatchNorm3d(128, momentum=BN_MOMENTUM)
)
        '''
        layer1: a stack of Bottleneck blocks at a single scale
        '''
self.layer1 = nn.Sequential(
            Bottleneck(32, 32, downsample=downsample),  # ResNet bottleneck: c channels in, 4c out
Bottleneck(128, 32),
Bottleneck(128, 32),
Bottleneck(128, 32)
)
        self.embedding = Embeddings(config)  # unused; TransformerLayer uses self.embeddings defined below
        self.sa2 = SelfAttention(32, 96)  # unused in forward
        self.transition1 = nn.ModuleList([  # two branches: stem resolution and 1/2 of it
nn.Sequential(
nn.Conv3d(128, base_channel, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm3d(base_channel, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)
),
nn.Sequential(
                nn.Sequential(  # extra Sequential wrapper so parameter names match the original project's pretrained weights
nn.Conv3d(128, base_channel * 2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm3d(base_channel * 2, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)
)
)
])
# Stage2
self.stage2 = nn.Sequential(
StageModule(input_branches=2, output_branches=2, c=base_channel)
)
        # transition2: pass the two Stage2 outputs through unchanged; downsample the second to create a third branch
self.transition2 = nn.ModuleList([
nn.Identity(), # None, - Used in place of "None" because it is callable
nn.Identity(), # None, - Used in place of "None" because it is callable
nn.Sequential(
nn.Sequential(
nn.Conv3d(base_channel * 2, base_channel * 4, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm3d(base_channel * 4, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)
)
)
])
# Stage3
self.stage3 = nn.Sequential(
StageModule(input_branches=3, output_branches=3, c=base_channel),
StageModule(input_branches=3, output_branches=3, c=base_channel),
StageModule(input_branches=3, output_branches=3, c=base_channel),
StageModule(input_branches=3, output_branches=3, c=base_channel)
)
# transition3
self.transition3 = nn.ModuleList([
nn.Identity(), # None, - Used in place of "None" because it is callable
nn.Identity(), # None, - Used in place of "None" because it is callable
nn.Identity(), # None, - Used in place of "None" because it is callable
nn.Sequential(
nn.Sequential(
nn.Conv3d(base_channel * 4, base_channel * 8, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm3d(base_channel * 8, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)
)
)
])
# Stage4
        # Note: the last StageModule only outputs the highest-resolution feature map
self.stage4 = nn.Sequential(
StageModule(input_branches=4, output_branches=4, c=base_channel),
StageModule(input_branches=4, output_branches=4, c=base_channel),
StageModule(input_branches=4, output_branches=1, c=base_channel)
)
self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
# Final layer
self.final_layer = nn.Conv3d(base_channel*2, output_channels, kernel_size=1, stride=1)
self.embeddings = Embeddings(config)
self.transformer = Encoder(config, vis)
self.detransformer = Detransformer(config)
self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)
self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)
self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)
self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters * 2, normalization=normalization)
self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
self.out_conv = nn.Conv3d(n_filters, num_classes, 1, padding=0)
self.out_conv2 = nn.Conv3d(n_filters, num_classes, 1, padding=0)
self.tanh = nn.Tanh()
self.dropout = nn.Dropout3d(p=0.5, inplace=False)
def TransformerLayer(self, features):
x5 = features[3]
embedding_output = self.embeddings(x5)
transformer_output, attn_weights = self.transformer(embedding_output)
detransformer_output = self.detransformer(transformer_output)
features[3] = detransformer_output
return features
def decoder(self, features):
        # features: [0]=(B,32,48^3), [1]=(B,64,24^3), [2]=(B,128,12^3), [3]=(B,256,6^3)
        x2 = features[0]  # 48^3
        x3 = features[1]  # 24^3
        x4 = features[2]  # 12^3
        x5 = features[3]  # 6^3
x5_up = self.block_five_up(x5)
x5_up = x5_up + x4 #12,12,12,
x6 = self.block_six(x5_up)
x6_up = self.block_six_up(x6)
x6_up = x6_up + x3 # 24,24,24
x7 = self.block_seven(x6_up)
x7_up = self.block_seven_up(x7)
x7_up = x7_up + x2 #48,48,48
# x8 = self.block_eight(x7_up)
# x8_up = self.block_eight_up(x8)
# x8_up = x8_up + x1 # 96,96,96
# x9 = self.block_nine(x7_up)
# if self.has_dropout:
# x9 = self.dropout(x9)
# out = self.out_conv(x9)
#out_tanh = self.tanh(out)
#out_seg = self.out_conv2(x9)
x8 = self.block_eight(x7_up)
x8_up = self.block_eight_up(x8)
return x8, x8_up
def load_from(self, weights):
with torch.no_grad():
res_weight = weights
self.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
self.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
self.transformer.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
self.transformer.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))
posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
posemb_new = self.embeddings.position_embeddings
if posemb.size() == posemb_new.size():
self.embeddings.position_embeddings.copy_(posemb)
elif posemb.size()[1]-1 == posemb_new.size()[1]:
posemb = posemb[:, 1:]
self.embeddings.position_embeddings.copy_(posemb)
else:
logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
ntok_new = posemb_new.size(1)
if self.classifier == "seg":
_, posemb_grid = posemb[:, :1], posemb[0, 1:]
gs_old = int(np.sqrt(len(posemb_grid)))
gs_new = int(np.sqrt(ntok_new))
print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
zoom = (gs_new / gs_old, gs_new / gs_old, 1)
posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np
posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
posemb = posemb_grid
self.embeddings.position_embeddings.copy_(np2th(posemb))
for bname, block in self.transformer.named_children():
for uname, unit in block.named_children():
unit.load_from(weights, n_block=uname)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x0 = x
residual = x
# x0 ,residual tensor 2,32,96,96,96
# print(residual.shape)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x) # stem
# x = self.sa1(x)
# x0 is a tensor # (2,32,48,48,48)
x = self.layer1(x)
# (2,32,48,48,48)
# (2,128,48,48,48)
        x = [trans(x) for trans in self.transition1]  # x becomes a list: one tensor per branch
x1 = x
# x1 is a list [0, 1] -> {[2,32,48,48,48], [2,64,24,24,24]}
        x = self.stage2(x)  # feed the previous stage's branch list
x = [
self.transition2[0](x[0]),
self.transition2[1](x[1]),
self.transition2[2](x[-1])
] # New branch derives from the "upper" branch only
x2 = x
# x2 is a list [0, 1, 2] -> {[2,32,48,48,48], [2,64,24,24,24], [2,128,12,12,12]}
x = self.stage3(x)
x = [
self.transition3[0](x[0]),
self.transition3[1](x[1]),
self.transition3[2](x[2]),
self.transition3[3](x[-1]),
] # New branch derives from the "upper" branch only
x3 = x
# x3 is a list [0, 1, 2, 3] -> {[2,32,48,48,48], [2,64,24,24,24], [2,128,12,12,12], [2,256,6,6,6]}
trans_features = self.TransformerLayer(x3) # list 0 1 2 3 32x48 64x24 128x12 256x6
x_hrnet = self.stage4(x) # list [0]= tensor 32,48,48,48
        # x_t = self.stage4(trans_features)  # 4 input branches, 1 output branch; stage4 output is 1/2 size and would need upsampling before concat with the stem  # list [0] = tensor 32,48,48,48
x_trans, x_trans_seg = self.decoder(trans_features) # tensor 32, 48,48,48
        x_fuse = x_hrnet[0] + x_trans  # fuse the HRNet branch with the transformer-decoder branch
        x = self.up(x_fuse)  # (B,32,48,48,48) -> (B,32,96,96,96)
        residual = x_trans_seg + residual  # full-resolution skip from the stem
        x = self.final_layer(torch.cat((x, residual), dim=1))
# print('x shape', x.shape)
return x
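# End-to-end shape sketch (illustrative helper; slow on CPU): the position
# table in Embeddings is hard-coded to 216 = 6^3 tokens, so the input must be
# 96^3 for the deepest branch to land on a 6x6x6 grid.
def _smoke_test():
    net = Transformer_HighResolutionNet()
    out = net(torch.randn(1, 1, 96, 96, 96))
    assert out.shape == (1, 6, 96, 96, 96)  # output_channels = 6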
from torchsummary import summary
if __name__ == '__main__':
    network = Transformer_HighResolutionNet()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = network.to(device)
    summary(net, (1, 96, 96, 96), device=device.type)