打印模型参数量
total_params = sum(p.numel() for p in model.parameters())
print('参数量为:')
print(f'{total_params:,} total parameters.')
添加位置为模型训练模块下,需要注意的部分是,model.parameters()中的model为训练模型如下:
def Train(model, config, train_loader, test_loader):
start_time = time.time()
model.train()
# 打印模型参数量
total_params = sum(p.numel() for p in model.parameters())
print('参数量为:')
print(f'{total_params:,} total parameters.')