1. FLOPs/params measurement snippet added to DA-TransUNet/train.py
# Measure DA-TransUNet complexity: print the architecture, a per-layer
# summary, and FLOPs/params via thop.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use one device for model, summary, and dummy input (the original hard-coded
# .cuda() and would crash on a CPU-only machine).
net = ViT_seg(config_vit, img_size=args.img_size, num_classes=config_vit.n_classes).to(device)
print(net)

from torchsummary import summary
# torchsummary adds the batch dimension itself; input_size is (C, H, W).
summary(net, input_size=(3, 224, 224), batch_size=1, device=str(device))

from thop import profile
# thop's `inputs` must be a tuple of the model's positional args. The original
# passed a raw 5-D (1, 1, 3, 224, 224) tensor and relied on profile()
# unpacking its first dim — fragile; use the documented form instead.
dummy = torch.rand(1, 3, 224, 224, device=device)
flops, params = profile(net, inputs=(dummy,))
print('flops: ', flops, 'params: ', params)
2. FLOPs/params measurement snippet added to CMT-pytorch-master\model\Transformers\CMT\cmt.py
if __name__ == "__main__":
    # Sanity-check CmtTi: print the architecture, a per-layer summary,
    # and FLOPs/params via thop.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move the model to the chosen device. The original left the model on CPU
    # while requesting summary(device='cuda') and moving the input tensor —
    # a weight/input device mismatch whenever CUDA is available.
    model = CmtTi().to(device)
    print(model)

    from torchsummary import summary
    summary(model, input_size=(3, 160, 160), batch_size=1, device=str(device))

    from thop import profile
    # thop's `inputs` must be a tuple of positional args, not a raw tensor;
    # pass one 4-D batch instead of the old 5-D unpacking trick.
    x = torch.randn(1, 3, 160, 160, device=device)
    flops, params = profile(model, inputs=(x,))
    print('flops: ', flops, 'params: ', params)
# Duplicate of the model-construction line above, reproduced here to
# introduce the console transcript that follows.
net = ViT_seg(config_vit, img_size=args.img_size, num_classes=config_vit.n_classes).cuda()
print(net)
(base) ➜ work conda run -n base --no-capture-output --live-stream python /home/featurize/work/5DA/DA-TransUnet/DA-TransUNet/DA-TransUNet/train.py
/environment/miniconda3/lib/python3.10/site-packages/torch/nn/init.py:405: UserWarning: Initializing zero-element tensors is a no-op
warnings.warn("Initializing zero-element tensors is a no-op")
DA_Transformer(
(transformer): Transformer(
(embeddings): Embeddings(
(DAblock1): DANetHead(
(conv5a): Sequential(
(0): Conv2d(768, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(48, eps=0.001, momentum=0.95, affine=True, track_running_stats=True)
(2): ReLU()
)
(conv5c): Sequential(
(0): Conv2d(768, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(48, eps=0.001, momentum=0.95, affine=True, track_running_stats=True)
(2): ReLU()
)
(sa): PAM_Module(
(query_conv): Conv2d(48, 6, kernel_size=(1, 1), stride=(1, 1))
(key_conv): Conv2d(48, 6, kernel_size=(1, 1), stride=(1, 1))
(value_conv): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1))
(softmax): Softmax(dim=-1)
)
(sc): CAM_Module(
(softmax): Softmax(dim=-1)
)
# Snippet whose torchsummary output is pasted below; kept verbatim so the
# code matches the recorded output exactly.
net = ViT_seg(config_vit, img_size=args.img_size, num_classes=config_vit.n_classes).cuda()
from torchsummary import summary
summary(net,input_size=(3,224,224),batch_size=1,device="cuda")
/environment/miniconda3/lib/python3.10/site-packages/torch/nn/init.py:405: UserWarning: Initializing zero-element tensors is a no-op
warnings.warn("Initializing zero-element tensors is a no-op")
----------------------------------------------------------------
Layer (type) Output Shape Param
================================================================
StdConv2d-1 [1, 64, 112, 112] 9,408
GroupNorm-2 [1, 64, 112, 112] 128
ReLU-3 [1, 64, 112, 112] 0
StdConv2d-4 [1, 256, 55, 55] 16,384
GroupNorm-5 [1, 256, 55, 55] 512
StdConv2d-6 [1, 64, 55, 55] 4,096
GroupNorm-7 [1, 64, 55, 55] 128
ReLU-8 [1, 64, 55, 55] 0
Conv2d-492 [1, 9, 224, 224] 1,305
Identity-493 [1, 9, 224, 224] 0
================================================================
Total params: 106,405,265
Trainable params: 106,405,265
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 5921468054582.70
Params size (MB): 405.90
Estimated Total Size (MB): 5921468054989.18
----------------------------------------------------------------
# Snippet whose thop log is pasted below; kept verbatim to match the output.
# NOTE(review): profile(net, tensor) passes a 5-D (1, 1, 3, 224, 224) tensor
# where thop expects a tuple of positional args; profile() unpacks it along
# dim 0, yielding a single (1, 3, 224, 224) input. The documented call is
# profile(net, inputs=(x,)) with a 4-D tensor.
net = ViT_seg(config_vit, img_size=args.img_size, num_classes=config_vit.n_classes).cuda()
tensor = torch.rand(1,1, 3, 224, 224)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tensor = tensor.to(device)
from thop import profile
flops, params = profile(net, tensor)
print('flops: ', flops, 'params: ', params)
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_upsample() for <class 'torch.nn.modules.upsampling.UpsamplingBilinear2d'>.
flops: 25491496060.0 params: 94510417.0