1.重写,假设模型名叫alexnet,除了init、forward方法外,还包含若干个函数
创建一个model类,去继承alexnet
代码示例
class model(alexnet):
def __init__(self):
super(model,self).__init__()
#alexnet里的self变量继承
def forward(self,x):
x = self.func(x)
return x
2.嵌入,将alexnet模型嵌入到model模型中
from torch.nn import *
class model(nn.Module):
def __init__(self):
super(model,self).__init__()
self.head = nn.Conv2d(1,1,1)
self.alexnet = models.alexnet(pretrained=True) # func 1
for k,v in alexnet._modules.items(): # func2
key = 'self.' + k
globals()[key] = v
self.detect = nn.Linear(1000,10)
def forward(self,x):
x = self.head(x)
x = self.alexnet(x)
x = x.flatten()
x = self.detect(x)
return x
globals()[k] 可将字符串k转换为变量k