import torchvision.models as models
from torchinfo import summary
# Dummy batch shape used for the summaries below: (batch, channels, height, width).
_INPUT_SIZE = (1, 3, 640, 640)

# Build MobileNetV3-Small with pretrained weights; progress=True shows a download progress bar.
model = models.mobilenet_v3_small(pretrained=True, progress=True)

# Layer-by-layer summary of the convolutional feature extractor only.
summary(model.features, input_size=_INPUT_SIZE)

# model.features is an nn.Sequential, which supports slicing (bare expression: shows the type in a notebook).
type(model.features)

# Slicing the Sequential isolates the first stages, exposing the corresponding down-sampled feature map.
summary(model.features[:4], input_size=_INPUT_SIZE)
MobileNetV3 模型分为三部分:features、avgpool、classifier。
结果显示:
可以定义如下主干网络类:
class MobilenetV3(nn.Module):
    """One stage of a pretrained torchvision MobileNetV3-Small feature extractor.

    The backbone's ``features`` ``nn.Sequential`` is cut into three consecutive
    stages so each can be used as a separate layer (e.g. to tap intermediate
    feature maps):

    * ``slice == 1`` -> ``features[:4]``
    * ``slice == 2`` -> ``features[4:9]``
    * ``slice == 3`` -> ``features[9:]``

    Args:
        slice: Which stage (1, 2 or 3) to keep. The name shadows the builtin
            ``slice`` but is kept so keyword callers (``slice=...``) still work.

    Raises:
        ValueError: If ``slice`` is not 1, 2 or 3.
    """

    # (start, stop) bounds into ``features`` for each stage; None means open-ended.
    _STAGE_BOUNDS = {1: (None, 4), 2: (4, 9), 3: (9, None)}

    def __init__(self, slice):
        super().__init__()
        try:
            start, stop = self._STAGE_BOUNDS[slice]
        except (KeyError, TypeError):
            # Fail fast here instead of leaving self.model as None and crashing in forward().
            raise ValueError(f"slice must be 1, 2 or 3, got {slice!r}") from None
        # Attribute name kept as ``model`` so existing state_dict keys stay compatible.
        self.model = models.mobilenet_v3_small(pretrained=True).features[start:stop]

    def forward(self, x):
        """Run ``x`` through the selected backbone stage and return its output."""
        return self.model(x)
注册(在模型解析代码中添加如下分支):
elif m is MobilenetV3:
C2 = args[0]
args = args[1:]