相当于做个笔记吧,太久了,一些基础的知识记不太清了。
现在我想要搭建一个预训练的Resnet18,但是最终输出分类为21的分类模型。
import torch
import torchvision
from torch import nn
class Resnet18(nn.Module):
def __init__(self, n_class=21):
super().__init__()
self.n_class = n_class
pretrained_net = torchvision.models.resnet18(pretrained=True)
self.model=nn.Sequential(*list(pretrained_net.children())[:-1])
self.linear=nn.Linear(512,n_class)
def forward(self,x):
x=self.model(x).squeeze()
output=self.linear(x)
return output
if __name__ == '__main__':
net=Resnet18(21)
x=torch.ones((5,3,256,256))
y=net(x)
print(y.shape)