# Function that computes the number of parameters in a model.
def get_nb_trainable_params(model):
    """Return the number of trainable parameters in *model*.

    Counts the total number of scalar elements across all parameters whose
    ``requires_grad`` flag is set; frozen parameters are skipped.

    Args:
        model: An object exposing ``parameters()`` that yields tensors with a
            ``requires_grad`` attribute and a ``numel()`` method (presumably a
            ``torch.nn.Module`` — confirm against callers).

    Returns:
        int: The total element count of trainable parameters (0 if there are
        none).
    """
    # numel() returns a Python int and correctly reports 1 for 0-d (scalar)
    # parameters, whereas np.prod(p.size()) returns the float 1.0 for an empty
    # shape and numpy.int64 otherwise.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)