Pytorch的基本使用流程总结为七个步骤,方便理解。这里省去了部分细节。
PS: 内容纯属个人理解,难免有错误,告知即改!
Pytorch 七步成诗
-
第一步 : 定义module模块 ,即定义网络
generator = GeneratorDRRN()
-
第二步: 定义优化器,并告知优化器,模型的哪些参数需要学习
optim_generator = optim.Adam(generator.parameters(), lr=opt.generatorLR)
-
第三步: 定义损失函数:
content_criterion = nn.MSELoss()
-
第四步: 模型前馈运算
output = generator(input)
-
第五步: 计算损失
generator_content_loss = content_criterion(output, target)
-
第六步:计算梯度,基于损失反向计算梯度。注意梯度是累积的。
generator.zero_grad()
generator_content_loss .backward()
-
第七步:更新参数
optim_generator.step()
补充1 简单图示
补充2 优化器
优化器
for input, target in dataset:
# 清空梯度数据
optimizer.zero_grad()
# 推理
output = model(input)
# 计算loss
loss = loss_fn(output, target)
# 计算梯度
loss.backward()
# 更新参数
optimizer.step()