class wide_resnet50_compress4(nn.Module):
"""
使用wide_resnet50_2将图片压为h/4,w/4
"""
def __init__(self):
super(wide_resnet50_compress4,self).__init__()
self.model = models.wide_resnet50_2(pretrained=True)
# 冻结模型的参数
model = self.model
for param in self.model.parameters():
param.requires_grad = False
model.layer2 = nn.Identity()#转到layer2
# 创建一个新的模型,仅包含前面的卷积层
self.model1 = nn.Sequential(
*list(model.children())[:-4] # 移除最后两个模块(卷积层)
)
def forward(self,x):
x=self.model1(x)
return x
class wide_resnet50_compress8(nn.Module):
"""
使用wide_resnet50_2将图片压为h/8,w/8
"""
def __init__(self):
super(wide_resnet50_compress8,self).__init__()
self.model = models.wide_resnet50_2(pretrained=True)
# 冻结模型的参数
model = self.model
for param in self.model.parameters():
param.requires_grad = False
model.layer3 = nn.Identity()#转到layer3
# 创建一个新的模型,仅包含前面的卷积层
self.model1 = nn.Sequential(
*list(model.children())[:-3] # 移除最后两个模块(卷积层)
)
def forward(self,x):
x=self.model1(x)
return x
class wide_resnet50_compress16(nn.Module):
"""
使用wide_resnet50_2将图片压为h/16,w/16
"""
def __init__(self):
super(wide_resnet50_compress16,self).__init__()
self.model = models.wide_resnet50_2(pretrained=True)
# 冻结模型的参数
model=self.model
for param in self.model.parameters():
param.requires_grad = False
model.layer4 = nn.Identity()#转到layer3
# 创建一个新的模型,仅包含前面的卷积层
self.model1 = nn.Sequential(
*list(model.children())[:-2] # 移除最后两个模块(卷积层)
)
def forward(self,x):
x=self.model1(x)
return x
data=torch.ones(1,3,512,512)
model=wide_resnet50_compress8()
summary(model,(3,512,512),1,'cpu')#查看网络结构chan,w,h
torch.onnx.export(model, data, 'name.onnx',verbose=True)#保存网络结构图
print(model(data).shape)
----------------------------------------------------------------
Layer (type) Output Shape Param #
=====================================================
基于pytorch自带预训练模型的网络更改,图片压缩
于 2023-05-30 15:23:37 首次发布