PyTorch Notes 14 | Convolutional Networks

Table of Contents

1. The LeNet Network

The script below loads Fashion-MNIST, builds LeNet, and trains it for ten epochs on the GPU.

import torch
from torch import nn, optim
import torch.utils.data as Data
import torchvision.transforms as transforms
import torchvision
import sys
import time

# 1. Load the Fashion-MNIST dataset

mnist_train = torchvision.datasets.FashionMNIST(root=".", train=True, download=True,
                                                transform=transforms.ToTensor())

mnist_test = torchvision.datasets.FashionMNIST(root=".", train=False, download=True,
                                               transform=transforms.ToTensor())
batch_size = 256
if sys.platform.startswith("win"):
    num_workers = 0
else:
    num_workers = 8
train_iter = Data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = Data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
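
# Quick sanity check (my addition, not in the original post): the first
# batch should hold 256 single-channel 28x28 images and one label each.
X, y = next(iter(train_iter))
print(X.shape, y.shape)     # torch.Size([256, 1, 28, 28]) torch.Size([256])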

# 2. Build LeNet
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.c1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.a1 = nn.Sigmoid()
        self.p1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.c2 = nn.Conv2d(6, 16, 5)
        self.a2 = nn.Sigmoid()
        self.p2 = nn.MaxPool2d(2, 2)

        # after two conv(5x5) + pool(2x2) stages: 28 -> 24 -> 12 -> 8 -> 4,
        # so the flattened feature size is 16 channels * 4 * 4
        self.l1 = nn.Linear(16*4*4, 120)
        self.a3 = nn.Sigmoid()
        self.l2 = nn.Linear(120, 84)
        self.a4 = nn.Sigmoid()
        self.l3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.c1(x)
        x = self.a1(x)
        x = self.p1(x)

        x = self.c2(x)
        x = self.a2(x)
        x = self.p2(x)

        x = x.view(x.shape[0], -1)      # flatten the feature maps
        x = self.l1(x)
        x = self.a3(x)
        x = self.l2(x)
        x = self.a4(x)
        y = self.l3(x)
        return y
net = LeNet()
net = net.cuda(0)                       # move the network to GPU 0
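
# Illustrative shape check (my addition): trace a dummy 28x28 input through
# the network. Spatial size goes 28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8 -> pool -> 4,
# which is where the 16*4*4 flatten size above comes from.
with torch.no_grad():
    dummy = torch.zeros(1, 1, 28, 28).cuda(0)
    print(net(dummy).shape)             # torch.Size([1, 10])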

# 3. Loss and optimizer
# CrossEntropyLoss applies log-softmax internally, which is why forward()
# returns raw logits without a final softmax layer.
loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.01)

# 4. Train
start = time.time()
for epoch in range(10):
    for X, y in train_iter:
        X, y = X.cuda(), y.cuda()       # move the batch to the GPU
        out = net(X)
        l = loss(out, y)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
    print("epoch: %d    loss: %f" % (epoch, l.item()))
print("time: %f" % (time.time()-start))