# 深度学习与PyTorch入门实战教程
# P39 多分类问题代码
'''
Logistic Regression
使用sigmoid
MSE =>regression
softmax
交叉熵
熵 = 不确定性
Entropy H(p) = -∑p(x)log p(x)
Cross Entropy H(p, q) = -∑p(x)log q(x)
'''
import torch
import torch.nn.functional as F
from torchvision import datasets
from torchvision.transforms import transforms
# Training hyperparameters.
batch_size = 200
learning_rate = 0.01
epochs = 10

# Preprocessing shared by the train and test sets: convert each image to a
# tensor, then standardize it with the conventional MNIST statistics
# (mean 0.1307, std 0.3081). The std was previously written as 0.1308,
# which looks like a transposition of the standard value.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training data: downloaded on first use and reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST("../data", train=True, download=True,
                   transform=mnist_transform),
    batch_size=batch_size, shuffle=True)

# Held-out test data: the evaluation metrics do not depend on sample order,
# so shuffling is disabled.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=mnist_transform),
    batch_size=batch_size, shuffle=False)
def _new_layer(out_features, in_features):
    """Create one trainable (weight, bias) pair for a fully connected layer.

    The weight has shape (out_features, in_features) and is drawn from a
    standard normal; the bias has shape (out_features,) and starts at zero.
    Both tensors track gradients.
    """
    weight = torch.randn(out_features, in_features, requires_grad=True)
    bias = torch.zeros(out_features, requires_grad=True)
    return weight, bias


# Three layers: 784 inputs -> 200 -> 200 -> 10 outputs.
w1, b1 = _new_layer(200, 784)
w2, b2 = _new_layer(200, 200)
w3, b3 = _new_layer(10, 200)

# Initialization matters: overwrite the raw normal samples in place with
# Kaiming-normal values for every weight matrix.
for _weight in (w1, w2, w3):
    torch.nn.init.kaiming_normal_(_weight)
def forward(x):
    """Run the three-layer network on ``x`` and return the raw output scores.

    ``x`` is multiplied by the transpose of each module-level weight
    (``w1``, ``w2``, ``w3``) and offset by the matching bias. ReLU follows
    the first two layers only.
    """
    hidden = x
    for weight, bias in ((w1, b1), (w2, b2)):
        hidden = F.relu(hidden @ weight.t() + bias)
    # No activation on the last layer: the scores are returned unnormalized.
    return hidden @ w3.t() + b3
# Plain SGD over the six hand-managed parameter tensors.
optimizer = torch.optim.SGD([w1, b1, w2, b2, w3, b3], lr=learning_rate)
# CrossEntropyLoss consumes raw logits plus integer class targets; with its
# default reduction it returns the mean loss over the batch.
criteon = torch.nn.CrossEntropyLoss()

for epoch in range(epochs):
    # ---- Training pass over the whole training set ----
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28 * 28)  # flatten each image to 784 values
        logits = forward(data)
        loss = criteon(logits, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch:{} [{}/{} ({:.0f}%)]\tLoss : {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()
            ))

    # ---- Evaluation pass over the test set after each epoch ----
    test_loss = 0.0
    correct = 0
    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time.
    with torch.no_grad():
        for data, target in test_loader:
            data = data.view(-1, 28 * 28)
            logits = forward(data)
            # The criterion returns a per-batch mean, so weight it by the
            # batch size; dividing the total by the dataset size below then
            # gives a true per-sample average.
            test_loss += criteon(logits, target).item() * data.size(0)
            pred = logits.argmax(dim=1)
            # .item() keeps `correct` a plain int so it formats as a number.
            correct += pred.eq(target).sum().item()

    test_loss /= len(test_loader.dataset)
    print("\nTest set: Average loss : {: .4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)
    ))