"""
A simple example of a convolutional neural network running on the GPU!
"""
import torch
import torchvision
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
batch_size = 64
device = torch.device("cuda:0")
class CNN(torch.nn.Module):
    """Small LeNet-style network mapping 28x28 single-channel images to 10 logits."""

    def __init__(self):
        super(CNN, self).__init__()
        # Two convolutional stages feeding one fully-connected classifier.
        self.convl1 = torch.nn.Conv2d(1, 10, 5)
        self.convl2 = torch.nn.Conv2d(10, 20, 5)
        self.pooling = torch.nn.MaxPool2d(2)
        self.activate = torch.nn.ReLU()
        self.linear = torch.nn.Linear(320, 10)

    def forward(self, x):
        """Return raw class scores; `x` may be any layout reshapeable to Nx1x28x28."""
        batch = x.view(-1, 1, 28, 28)
        # Stage 1: conv -> pool -> ReLU   (28x28 -> 24x24 -> 12x12, 10 channels)
        features = self.activate(self.pooling(self.convl1(batch)))
        # Stage 2: conv -> pool           (12x12 -> 8x8 -> 4x4, 20 channels)
        features = self.pooling(self.convl2(features))
        # Flatten to 20*4*4 = 320 features and classify.
        return self.linear(features.view(-1, 320))
def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch, printing the mean loss every 300 mini-batches.

    Relies on the module-level ``device`` for host-to-GPU transfers.
    """
    running_loss = 0.0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Standard forward / backward / update cycle.
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        running_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report on batches 299, 599, ... (i.e. every 300th batch).
        if batch_idx % 300 == 299:
            print(epoch, ",", batch_idx // 300, ":", running_loss / 300)
            running_loss = 0.0
def test(test_loader, model):
    """Evaluate and print the classification accuracy of `model` on `test_loader`.

    Relies on the module-level ``device`` for host-to-GPU transfers.

    FIX: the original ran every forward pass with autograd enabled, building
    an unused gradient graph per batch; evaluation now runs under
    ``torch.no_grad()`` — same printed result, far less memory.
    """
    total = 0
    correct = 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            y_hat = model(x)
            # Predicted class = index of the largest logit in each row.
            _, guess = torch.max(y_hat, 1)
            correct += (guess == y).sum().item()
            total += y.size(0)
    print("ACC == ", correct / total)
if __name__ == '__main__':
    # Canonical MNIST normalization constants (the std was mistyped as 0.3018
    # in the original; the published value is 0.3081).
    transformer = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # BUG FIX: the test split must be loaded with train=False. The original
    # passed True positionally for both datasets, so "test" accuracy was
    # measured on the training set.
    train_data = datasets.MNIST('MNIST', train=True, transform=transformer, download=True)
    test_data = datasets.MNIST('MNIST', train=False, transform=transformer, download=True)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)
    # Shuffling is pointless during evaluation.
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)

    model = CNN()
    model.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    criterion.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    for epoch in range(10):
        train(train_loader, model, criterion, optimizer, epoch)
        test(test_loader, model)
# Source: "PyTorch: a simple convolutional neural network on the GPU"
# (scraped blog footer, article dated 2023-11-01 09:06:41 — commented out so the file parses)