初学深度学习时使用斋藤康毅的鱼书,见到其中代码涉及到函数调用时使用“**”来传递参数,有些不理解,故做以下记录,将关键代码提取如下:
以下为optimizer.py文件中对各优化算法的实现:
class Momentum:
def __init__(self, lr=0.01, momentum=0.9):
……
class Adam:
def __init__(self, lr=0.001, beta1=0.9, beta2=0.999):
……
以下为trainer.py文件中用于训练的代码,其中调用了optimizer.py文件中的函数:
class Trainer:
def __init__(self, ……, optimizer='SGD', optimizer_param={'lr':0.01}, ……):
……
optimizer_class_dict = {'sgd':SGD, 'momentum':Momentum, 'nesterov':Nesterov,
'adagrad':AdaGrad, 'rmsprpo':RMSprop, 'adam':Adam}
self.optimizer = optimizer_class_dict[optimizer.lower()](**optimizer_param)
……
如果训练时使用的是Adam算法,那么optimizer_class_dict[optimizer.lower()](**optimizer_param)就相当于是函数调用语句Adam(**optimizer_param),其中**optimizer_param会将optimizer_param字典中各关键字对应的参数传递给函数,以下结合简单代码进行理解:
def optimizer(lr=0.001, beta1=0.9, beta2=0.999):
print('lr:', lr)
print('beta1:', beta1)
print('beta2:', beta2)
optimizer_param={'lr': 1, 'beta1': 2, 'beta2': 3}
optimizer(**optimizer_param)
#打乱关键字顺序并不影响参数的传递,依旧按照参数名传递'
optimizer_param={'beta2': 3, 'beta1': 2, 'lr': 1}
optimizer(**optimizer_param)
输出结果为:
lr: 1
beta1: 2
beta2: 3
lr: 1
beta1: 2
beta2: 3
还可以用单星号“*”来进行参数传递,但单星号后面须使用list或tuple类型的变量,且参数关联方式必须使用位置实参:
optimizer_param=(1, 2, 3) #[1,2,3]
optimizer(*optimizer_param)
optimizer_param=(3, 2, 1) #[3,2,1]
optimizer(*optimizer_param)
输出结果:
lr: 1
beta1: 2
beta2: 3
lr: 3
beta1: 2
beta2: 1
如果字典搭配单星号使用,则输出如下:
lr: lr
beta1: beta1
beta2: beta2
lr: beta2
beta1: beta1
beta2: lr
如果列表或元组搭配双星号使用,则报错:
optimizer() argument after ** must be a mapping, not tuple
总结
只需牢记单星号搭配列表或元组使用,参数关联方式为位置实参;双星号搭配字典使用,参数关联方式为关键字实参。