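This session's assignment:
Implement logistic regression in PyTorch
1. A basic PyTorch implementation
2. Implement logistic regression as a PyTorch class, defining the network structure with torch.nn.Module
The code is: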
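As a rough illustration of item 2, here is a minimal sketch of logistic regression written as a torch.nn.Module subclass. The class name LogisticRegression, the random stand-in batch, and the choice of SGD as the optimizer are assumptions made for this example only. The sizes (784 inputs, 10 classes, batch of 100, learning rate 0.001) match the hyper-parameters used in the code for this assignment.

```python
import torch
import torch.nn as nn


class LogisticRegression(nn.Module):
    """Multiclass logistic regression: one linear layer that outputs logits."""

    def __init__(self, input_size: int, num_classes: int):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Return raw logits; nn.CrossEntropyLoss applies log-softmax internally.
        return self.linear(x)


torch.manual_seed(0)

# Assumption: a random batch shaped like flattened 28x28 images, not real MNIST data.
images = torch.randn(100, 784)
labels = torch.randint(0, 10, (100,))

model = LogisticRegression(input_size=784, num_classes=10)
criterion = nn.CrossEntropyLoss()
# Assumption: plain SGD, chosen only to keep the sketch simple.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# One training step: forward pass, loss, backward pass, parameter update.
logits = model(images)
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

The model returns logits rather than probabilities because nn.CrossEntropyLoss combines the softmax and the negative log-likelihood in one call. To get class probabilities at inference time, torch.softmax(logits, dim=1) can be applied to the output.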
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# Hyper-parameters
input_size = 784
num_classes = 10
num_epochs = 5
batch_size = 100
learning_rate = 0.001
# MNIST dataset (images and labels)
train_dataset = torchvision.datasets.MNIST(root='data/',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)
test_dataset = torchvision.datasets.MNIST(root='data/',
                                          train=False,
                                          transform&#