【pytorch】训练网络通用架构

  1. WHAT IS TORCH.NN REALLY?
  2. pytorch中model eval和torch no grad()的区别
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

# 定义网络结构
class Model(nn.Module):
	def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1)

    def forward(self, xb):
        xb = xb.view(-1, 1, 28, 28)
        xb = F.relu(self.conv1(xb))
        xb = F.relu(self.conv2(xb))
        xb = F.relu(self.conv3(xb))
        xb = F.avg_pool2d(xb, 4)
        return xb.view(-1, xb.size(1))

# 网络实例化
model = Model()
# loss
loss_func = F.cross_entropy
# 优化器设置
opt = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs, shuffle=False)

for epoch in range(epochs):
	# 设置为训练模式
    model.train()
    # iterate: 每次一个batch
    for xb, yb in train_dl:
    	# 各权重参数的偏导清零 grad=>0
        opt.zero_grad()
    	# 前向传播
        pred = model(xb)
        # 计算损失
        loss = loss_func(pred, yb)
		# 反向传播,计算loss关于各权重参数的偏导,更新grad
        loss.backward()
        # 优化器基于梯度下降原则,更新(学习)权重参数parameters
        opt.step()
        
	# 设置为评估(推理)模式,设置BN、dropout等模块
    model.eval()
    # 不更新梯度
    with torch.no_grad():
        valid_loss = sum(loss_func(model(xb), yb) for xb, yb in valid_dl)

    print(epoch, valid_loss / len(valid_dl))
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值