西瓜书《机器学习》课后答案——chapter5_5.10

这里使用了PyTorch深度学习库实现CNN。此网络的结构为卷积-池化-卷积-全连接-全连接。在程序的后面附上了实验结果,CNN在MNIST上的准确率可以达到99.16%。

"""
Author: Victoria
Created on 2017.9.25 13:00
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torchvision import datasets, transforms
import numpy
import matplotlib.pyplot as plt

class LeNet(nn.Module):
    def __init__(self, cuda):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(16, 64, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(4096, 84)
        self.fc2 = nn.Linear(84, 10)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax()
        self.use_cuda = cuda
        if cuda:
            self.cuda()

    def forward(self, x):
        #print x.size()
        x = self.relu(self.max_pool(self.conv1(x)))
        x = self.conv2(x)
        #print x.size()
        x = self.relu(x)
        x = x.view(-1, 4096)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x)

def train(model, train_loader, epochs):
        model.train()
        optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
        for epoch in range(epochs):
            for batch_idx, (data, target) in enumerate(train_loader):
                if model.use_cuda:
                    data, target = data.cuda(), target.cuda()
                data, target = Variable(data), Variable(target)
                optimizer.zero_grad()
                output = model(data)
                loss = F.nll_loss(output, target)
                loss.backward()
                optimizer.s
  • 1
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值