Argument passing
*args: an indefinite number of positional arguments, collected into a tuple
**kwargs: an indefinite number of keyword arguments, collected into a dict
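A minimal sketch of the two collection forms (the function and argument names here are illustrative):

def demo(required, *args, **kwargs):
    # args collects extra positional arguments into a tuple,
    # kwargs collects extra keyword arguments into a dict
    print(required, args, kwargs)

demo(1, 2, 3, a=4, b=5)  # prints: 1 (2, 3) {'a': 4, 'b': 5}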
A real-world example:
def create_model(opt, step=0, **opt_kwargs):
    # local_config, logger, and find_model_using_name are defined elsewhere in the codebase
    if local_config is not None:
        opt['path']['pretrain_model_G'] = os.path.join(local_config.checkpoint_path,
                                                       os.path.basename(opt['path']['results_root'] + '.pth'))
    # merge any extra keyword arguments into the option dict
    for k, v in opt_kwargs.items():
        opt[k] = v
    model = opt['model']
    model = model.split("-")[0]
    M = find_model_using_name(model)
    m = M(opt, step)
    logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
    return m
functools.partial
From the official documentation, the signature is
functools.partial(func, /, *args, **keywords)
and it is roughly equivalent to:
def partial(func, /, *args, **keywords):
    def newfunc(*fargs, **fkeywords):
        newkeywords = {**keywords, **fkeywords}
        return func(*args, *fargs, **newkeywords)
    newfunc.func = func
    newfunc.args = args
    newfunc.keywords = keywords
    return newfunc
Example from an RRDB-based network, binding default block arguments once:
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
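The same idea with a toy function (names here are illustrative, not from the codebase above):

import functools

def make_block(in_ch, out_ch, nf, gc):
    return (in_ch, out_ch, nf, gc)

# bind nf and gc once; the remaining arguments are supplied per call
block_f = functools.partial(make_block, nf=64, gc=32)
print(block_f(3, 3))                                 # (3, 3, 64, 32)
print(block_f.func, block_f.args, block_f.keywords)  # the attributes exposed in the source above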
import importlib

def find_model_using_name(model_name):
    # Given the option --model [modelname],
    # the file "models/modelname_model.py" will be imported.
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)

    # In that file, the class named ModelNameModel will be
    # instantiated. It has to be a subclass of torch.nn.Module,
    # and the lookup is case-insensitive.
    model = None
    target_model_name = model_name.replace('_', '') + 'Model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower():
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of torch.nn.Module "
              "with a class name that matches %s." % (model_filename, target_model_name))
        exit(0)
    return model
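How this resolves in practice (a hypothetical layout; the file and class names are illustrative):

# models/sr_model.py would define:
#     class SRModel(torch.nn.Module): ...
# so a config with opt['model'] = 'sr-xxx' resolves like this:
model_name = 'sr-xxx'.split('-')[0]    # 'sr'
M = find_model_using_name(model_name)  # imports models.sr_model, returns the SRModel class
m = M(opt, step=0)                     # instantiate it with the option dict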
Freezing part of a model's parameters in PyTorch
1. Set requires_grad to False on the layers to freeze
2. Filter the frozen parameters out of the optimizer
# Layer1pre is assumed to be the state dict of the pretrained layers to freeze
for k, v in model2.named_parameters():
    if k in Layer1pre.keys():
        v.requires_grad = False

# pass only the parameters that still require gradients to the optimizer
params = filter(lambda p: p.requires_grad, model2.parameters())
criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model2.parameters(), lr=learning_rate)
optimizer = torch.optim.Adam(params, lr=learning_rate)
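A quick way to verify which parameters actually remain trainable (a sketch, with model2 as above):

for name, p in model2.named_parameters():
    print(name, p.requires_grad)
n_trainable = sum(p.numel() for p in model2.parameters() if p.requires_grad)
print('trainable parameters:', n_trainable)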
Unfreezing after a freeze: set requires_grad back to True and rebuild the optimizer
for k, v in model.named_parameters():
    v.requires_grad = True  # re-open the previously frozen layers
optimizer = optim.Adam(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
Or:
if epoch == 50:
    for parameter in deeplabv3.classifier.parameters():
        if parameter.requires_grad == False:
            parameter.requires_grad = True
            # register the newly unfrozen parameter with the existing optimizer
            optimizer.add_param_group({'params': parameter})
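If the whole classifier is unfrozen at once, the parameters can also be registered as a single group, which keeps them together in the optimizer (a sketch, with deeplabv3 and optimizer as above; assumes these parameters were not already in the optimizer):

for parameter in deeplabv3.classifier.parameters():
    parameter.requires_grad = True
optimizer.add_param_group({'params': deeplabv3.classifier.parameters(),
                           'lr': 1e-4})  # the lr value here is illustrative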