Pytorch实战 | 彩色图片识别

参考:
Pytorch实战 | 第P2周:彩色图片识别
代码:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms

device = torch.device( "cpu")

# 数据变换格式
transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=0.5, std=[0.5])])

data_train = torchvision.datasets.CIFAR10(root='./data/',
                                          transform=transforms,
                                          download=True,
                                          train=True)

data_test = torchvision.datasets.CIFAR10(root='./data/',
                                         transform=transforms,
                                         train=False,
                                         download=False)

data_train, _ = torch.utils.data.random_split(dataset=data_train,
                                              lengths=[1000, 49000],
                                              generator=torch.Generator().manual_seed(0))
data_test, _ = torch.utils.data.random_split(dataset=data_test,
                                             lengths=[1000, 9000],
                                             generator=torch.Generator().manual_seed(0))

batch_size = 4

data_loader_train = torch.utils.data.DataLoader(dataset=data_train,
                                                batch_size=batch_size,
                                                shuffle=True)
data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
                                               batch_size=batch_size,
                                               shuffle=True)

import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        torch.nn.Sequential()
        # 特征提取网络
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)  # 第一层卷积,卷积核大小为3*3
        self.pool1 = nn.MaxPool2d(kernel_size=2)  # 设置池化层,池化核大小为2*2
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3)  # 第二层卷积,卷积核大小为3*3
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3)  # 第二层卷积,卷积核大小为3*3
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        # 分类网络
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 10)

    # 前向传播
    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))

        x = torch.flatten(x, start_dim=1)

        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return x


from torch.autograd import Variable

model = Model()
model.to(device)
cost = torch.nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = torch.optim.Adam(model.parameters())
epochs = 5
for i in range(epochs):
    running_loss = 0.0
    running_correct = 0
    print("Epoch{}/{}".format(i + 1, epochs))
    print("-" * 10)
    for data in data_loader_train:
        X_train, Y_train = data
        X_train, Y_train = X_train.to(device), Y_train.to(device)
        # X_train, Y_train = X_train.to(device), Y_train.to(device)
        X_train, Y_train = Variable(X_train), Variable(Y_train)
        outputs = model(X_train)
        _, pred = torch.max(outputs, 1)
        optimizer.zero_grad()
        loss = cost(outputs, Y_train)
        loss.backward()
        optimizer.step()
        running_loss += loss.data
        running_correct += torch.sum(pred.data == Y_train.data)

    testing_correct = 0

    for data in data_loader_test:
        X_test, Y_test = data
        X_test, Y_test = X_test.to(device), Y_test.to(device)
        X_test, Y_test = Variable(X_test), Variable(Y_test)
        outputs = model(X_test)
        _, pred = torch.max(outputs.data, 1)
        testing_correct += torch.sum(pred.data == Y_test.data.data).data
        print("Loss is:{:.4f},Train Accuracy is:{:.4f}%, Test Accuracy is:{:.4f}".format(running_loss / len(data_train),
                                                                                         100 * running_correct / len(
                                                                                             data_train),
                                                                                         100 * testing_correct / len(
                                                                                             data_test)))

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

回家种蜜柚

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值