Pytorch 模型修改
修改模型层
classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(2048, 128)),
('relu1', nn.ReLU()),
('dropout1',nn.Dropout(0.5)),
('fc2', nn.Linear(128, 10)),
('output', nn.Softmax(dim=1))
]))
net.fc = classifier
添加外部输入
有时候在模型训练中,除了已有模型的输入之外,还需要输入额外的信息。
基本思路是:将原模型添加输入位置前的部分作为一个整体,同时在forward中定义好原模型不变的部分、添加的输入和后续层之间的连接关系,从而完成模型的修改。
添加输出
基本的思路是修改模型定义中forward函数的return变量。