import torch
import torch.nn as nn
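
# ResBlock is referenced below but never defined in this snippet. The class
# that follows is a minimal placeholder (assumed signature: input, middle and
# output channel counts) so the example runs end to end; substitute your real
# residual block.
class ResBlock(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, mid_ch, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_ch)
        self.conv2 = nn.Conv2d(mid_ch, out_ch, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + x)  # identity shortcut; assumes in_ch == out_ch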
class OurNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1a = nn.Conv2d(3, 64, 3, padding=1, bias=False)
        self.conv2a = nn.Conv2d(64, 64, 3, padding=1, bias=False)  # in_channels must match conv1a's output
        self.conv3a = nn.Conv2d(64, 64, 3, padding=1, bias=False)
        self.b3_2 = ResBlock(256, 256, 256)
        # Modules listed here are frozen in train() below.
        self.not_training = [self.conv1a, self.b3_2]
        # Modules trained from scratch get a larger learning rate in
        # get_parameter_groups(); empty by default, add layers as needed.
        self.from_scratch_layers = []

    def forward(self, x):
        x = self.conv1a(x)
        x = self.conv2a(x)
        x = self.conv3a(x)
        return x
    # For a module's parameters you can: 1) apply custom initialization,
    # 2) toggle requires_grad (True or False), and 3) assign per-parameter
    # learning rates. Here we define which parameters are NOT updated.
    def train(self, mode=True):
        super().train(mode)  # call the parent class's train() first; this is required
        # First iteration: layer = self.conv1a (an nn.Conv2d);
        # second iteration: layer = self.b3_2 (a ResBlock).
        for layer in self.not_training:
            if isinstance(layer, nn.Conv2d):
                layer.weight.requires_grad = False
            elif isinstance(layer, nn.Module):
                # A class written as ResBlock(nn.Module) is itself an nn.Module,
                # so freeze every parameter it contains (safer than touching
                # .weight/.bias on children, which e.g. ReLU does not have).
                for p in layer.parameters():
                    p.requires_grad = False
        # Once you have a module you can freeze it or apply custom initialization.
        for layer in self.modules():  # freeze BatchNorm layers
            if isinstance(layer, nn.BatchNorm2d):
                layer.eval()  # stop updating the running statistics
                if layer.weight is not None:  # affine=False leaves weight/bias as None
                    layer.weight.requires_grad = False
                    layer.bias.requires_grad = False
        for layer in self.modules():  # custom initialization
            if isinstance(layer, nn.Conv2d):
                nn.init.xavier_uniform_(layer.weight)
                # kaiming_normal_ requires a tensor with >= 2 dimensions, so it
                # cannot be applied to the 1-D bias; zero the bias instead.
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
        # Note: this loop runs on every call to train(); in practice,
        # initialization belongs in __init__ so it happens only once.
        return
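
    # A quick sanity check (a sketch, assuming the model above): after calling
    # train(), the frozen parameters report requires_grad=False while the rest
    # stay trainable:
    #
    #   model = OurNet()
    #   model.train()
    #   for name, p in model.named_parameters():
    #       print(name, p.requires_grad)  # conv1a.weight and all b3_2.* print False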
    # Collect parameters into four groups so each can get its own learning rate:
    # (pretrained weights, pretrained biases, from-scratch weights, from-scratch biases).
    def get_parameter_groups(self):
        groups = ([], [], [], [])
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.GroupNorm)):
                if m.weight.requires_grad:
                    if m in self.from_scratch_layers:
                        groups[2].append(m.weight)
                    else:
                        groups[0].append(m.weight)
                if m.bias is not None and m.bias.requires_grad:
                    if m in self.from_scratch_layers:
                        groups[3].append(m.bias)
                    else:
                        groups[1].append(m.bias)
        return groups
# Usage: build the model, freeze/initialize via train(), then create the
# optimizer with per-group learning rates. `args` (from argparse) and
# `max_step` are assumed to be defined by the surrounding training script;
# torchutils.PolyOptimizer is a project-specific optimizer wrapper.
model = OurNet()
model.train()
param_groups = model.get_parameter_groups()
optimizer = torchutils.PolyOptimizer([
    {'params': param_groups[0], 'lr': args.lr,    'weight_decay': args.wt_dec},
    {'params': param_groups[1], 'lr': 2*args.lr,  'weight_decay': 0},
    {'params': param_groups[2], 'lr': 10*args.lr, 'weight_decay': args.wt_dec},
    {'params': param_groups[3], 'lr': 20*args.lr, 'weight_decay': 0}
], lr=args.lr, weight_decay=args.wt_dec, max_step=max_step)
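
# If torchutils is unavailable: PolyOptimizer (from the AffinityNet/psa
# codebase) is essentially SGD with polynomial learning-rate decay, so a rough
# equivalent (a sketch, not PolyOptimizer's exact behavior) is plain
# torch.optim.SGD with the same group dicts plus a LambdaLR poly schedule.
# The poly power 0.9 is a common choice, not taken from this snippet.
import torch.optim as optim
optimizer = optim.SGD([
    {'params': param_groups[0], 'lr': args.lr,    'weight_decay': args.wt_dec},
    {'params': param_groups[1], 'lr': 2*args.lr,  'weight_decay': 0},
    {'params': param_groups[2], 'lr': 10*args.lr, 'weight_decay': args.wt_dec},
    {'params': param_groups[3], 'lr': 20*args.lr, 'weight_decay': 0}
], lr=args.lr, momentum=0.9)
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: (1 - step / max_step) ** 0.9)
# Call scheduler.step() once per training step so every group's lr decays
# by the same polynomial factor while keeping its 1x/2x/10x/20x ratio.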