本博客基于github上某位大神的pytorch入门学习代码,在他的基础上加上了更详细的中文注释以及不懂的模块使用方法。github连接:https://github.com/yunjey/pytorch-tutorial
逻辑回归模型
运行代码之前,请确定当前环境下已经安装torch、torchvision。
# 导入的包
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# 超参数设置
input_size = 784 # 输入的大小(28*28)
num_classes = 10 # 输出分类的个数
num_epochs = 5 # 迭代的次数
batch_size = 100 # 一次输入的数据量
learning_rate = 0.001 # 学习率
# 加载训练集
# 通过torchvision的这个方法从网上下载MINIST数据集
train_dataset = torchvision.datasets.MNIST(root="../../data/minist", train=True, transform=transforms.ToTensor(), download=True)
# 加载测试集
test_dataset = torchvision.datasets.MNIST(root="../../data/minist", train=False, transform=transforms.ToTensor(