参考package的使用方法:https://pytorch-cn.readthedocs.io/zh/latest/
神经网络可以通过 torch.nn 包来构建。一个典型的神经网络训练过程包括以下:
- 定义一个包含可训练参数的神经网络
- 迭代整个输入
- 通过神经网络处理输入
- 计算损失loss
- 反向传播梯度到神经网络的参数
- 更新网络的参数,典型的更新方法是:weight = weight - learning_rate *gradient
例子:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# 1 input image channel, 6 output channels, 5x5 square convolution
# kernel
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
# an affine operation: y = Wx + b
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def