pytorch
convtranspose2d的outc和inc是交换了位置的,需要注意一下。这样自己在保存的时候,可能需要用permute交换两个的维度。直接看具体的例子:
import torch
import torch.nn as nn
conv = nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, stride=1)
transpose = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=1)
print("conv weight shape:{}".format(conv.weight.shape))
print("transpose2d weight shape:{}".format(transpose.weight.shape))
log
conv weight shape:torch.Size([16, 32, 3, 3])
transpose2d weight shape:torch.Size([32, 16, 3, 3])