# 莫烦pytorch学习记录--cnn (Morvan PyTorch study notes: CNN)
#cnn.py
import os
import time
import numpy as np
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils import data
import torch.utils.data as Data
from torch.autograd import Variable
import matplotlib.pyplot as plt
import pdb
# ---- Hyper-parameters -------------------------------------------------------
EPOCH = 1               # number of full passes over the training set
BATCH_SIZE = 50         # samples per optimisation step
LR = 0.001              # learning rate for the Adam optimizer
DOWNLOAD_MNIST = False  # set to True if MNIST is not yet present under DATA_ROOT
DATA_ROOT = '../catVsDog/data'  # directory torchvision reads MNIST from / downloads it to

# ---- Training data ----------------------------------------------------------
train_data = torchvision.datasets.MNIST(
    root=DATA_ROOT,
    train=True,
    transform=torchvision.transforms.ToTensor(),  # pixel values 0-255 ==> 0-1
    download=DOWNLOAD_MNIST,
)
# Optional inspection of the raw training set (uncomment to use):
# print(train_data.data.size())
# print(train_data.targets.size())
# plt.imshow(train_data.data[0].numpy(), cmap='gray')
# plt.title('%i' % train_data.targets[0])
# plt.show()

# Worker processes (num_workers=2) are only started once this loader is iterated.
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# ---- Held-out evaluation subset ---------------------------------------------
# No transform is given here, so the images are read from `.data` with their
# raw 0-255 pixel values and scaled manually below.
test_data = torchvision.datasets.MNIST(root=DATA_ROOT, train=False)
# plt.imshow(test_data.data[0].numpy(), cmap='gray')
# plt.title('%i' % test_data.targets[0])
# plt.show()

# First 2000 test images as floats in [0, 1] with a channel axis inserted at
# dim 1: (2000, H, W) -> (2000, 1, H, W). Slicing happens before the float
# conversion so the rest of the test set is never converted.
# `Variable(..., volatile=True)` was dropped: volatile has had no effect since
# PyTorch 0.4 -- evaluation code must use `torch.no_grad()` instead.
test_x = torch.unsqueeze(test_data.data[:2000], dim=1).type(torch.FloatTensor) / 255.
test_y = test_data.targets[:2000]
class CNN(nn.Module):
    """Two convolutional stages followed by a single linear classifier head.

    The head is sized for 32 feature maps of 7x7, which is what the two 2x
    max-poolings leave of a 28x28 single-channel image, so inputs are expected
    to be shaped (batch, 1, 28, 28). No softmax is applied: ``forward``
    returns (batch, 10) unnormalised class scores.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: (1, 28, 28) -> 5x5 conv, padding 2 keeps 28x28 -> (16, 28, 28)
        #          -> 2x2 max-pool -> (16, 14, 14)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Stage 2: (16, 14, 14) -> (32, 14, 14) -> 2x2 max-pool -> (32, 7, 7)
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Linear head over the flattened 32*7*7 feature vector.
        self.out = nn.Linear(in_features=32 * 7 * 7, out_features=10)

    def forward(self, x):
        """Map a batch of images to per-class scores of shape (batch, 10)."""
        features = self.conv2(self.conv1(x))            # (batch, 32, 7, 7)
        flat = features.view(features.size(0), -1)      # flatten to (batch, 32*7*7)
        return self.out(flat)
cnn = CNN()
# print(cnn)  # uncomment to inspect the layer stack
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
# CNN.forward returns unnormalised scores; CrossEntropyLoss applies the
# log-softmax itself, so no softmax layer is needed in the model.
loss_func = nn.CrossEntropyLoss()


def _predict(model, x):
    """Return the arg-max class index for every sample in ``x``.

    Runs under ``torch.no_grad()`` so evaluation-only forward passes do not
    record an autograd graph (the old ``volatile=True`` flag no longer does
    this).
    """
    with torch.no_grad():
        return torch.max(model(x), 1)[1]


def _accuracy(model, x, y):
    """Return the fraction of samples in ``x`` whose predicted class equals ``y``.

    ``.item()`` turns the match count into a Python int before dividing, so the
    result is a true float ratio rather than an integer tensor division.
    """
    return (_predict(model, x) == y).sum().item() / y.size(0)


# Iterating train_loader starts num_workers worker processes. Where
# multiprocessing uses the "spawn" start method (e.g. Windows), each worker
# re-imports this script, so the training/evaluation code must only run when
# the file is executed directly.
if __name__ == '__main__':
    for epoch in range(EPOCH):
        for step, (b_x, b_y) in enumerate(train_loader):
            # Tensors are used directly; Variable wrappers are no-ops since PyTorch 0.4.
            output = cnn(b_x)
            loss = loss_func(output, b_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Every 50 optimisation steps, report loss and held-out accuracy.
            if step % 50 == 0:
                accuracy = _accuracy(cnn, test_x, test_y)
                print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(), '|test accuracy: %.2f' % accuracy)

    # Show predictions for the first ten held-out images next to their labels.
    pred_y = _predict(cnn, test_x[:10]).numpy()
    print(pred_y, 'prediction number')
    print(test_y[:10].numpy(), 'real number')