pytorch入门-环境安装与CNN

目录

一、安装步骤

二、CNN

1、导入库

2、导入训练集与测试集

3、定义CNN

4、损失函数、优化函数

5、训练模型

6、测试模型

7、加载模型与预测


一、安装步骤

1、pytorch下载

确保是在安装好的python环境下进行,先进入Start Locally | PyTorch官网。

根据机器的配置选择相应的信息,然后将下面的代码从控制台运行即可

 2、环境验证

在python中导入库的名称为torch

import torch
torch.randn(2,2,2)

二、CNN

初学,数据集用的是MNist手写数据集,CNN的处理步骤如下,在这里直接继承torch.nn类然后设置参数即可。

1、导入库

这是需要导入的库,如果提示no moudel的话直接install即可

import torch
import networkx as nx
import matplotlib.pyplot as plt
import torch.nn as nn
import numpy as np
import pandas as pd
import torchvision.datasets as dataset
import torchvision.transforms as transforms
import torch.utils.data as data_utils
import time

2、导入训练集与测试集

如果训练集没有下载的话,添加一个download属性然后赋值为True即可

#导入训练集测试集,如果没有下载,download设置为True
train_data=dataset.MNIST("mnist-data",train=True,transform=transforms.ToTensor())
test_data=dataset.MNIST("mnist-data",train=False,transform=transforms.ToTensor())

3、定义CNN

第一句代码需要注意一下,如果你的电脑是m1的话device的括号就填mps,一般是填cpu或者suad

device = torch.device('mps')
class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc = nn.Linear(64 * 7 * 7, 10)
    def forward(self,x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        y = self.fc(x)
        return y
cnn=CNN().to(device)

4、损失函数、优化函数

#损失函数
los=torch.nn.CrossEntropyLoss()
#优化函数
optime=torch.optim.Adam(cnn.parameters(), lr=0.01)

5、训练模型

这里在输入lmages和lab的时候注意也要加上to(device),device在前面定义过

#训练模型
start = time.time()
for epo in range(1):
    for i, (images,lab) in enumerate(train_loader):
        optime.zero_grad()
        images=images.to(device)
        lab=lab.to(device)
        out = cnn(images)
        loss=los(out,lab)
        loss.backward()
        optime.step()
        print("epo:{},i:{},loss:{}".format(epo+1,i,loss))
end = time.time()
print(end-start,"s")

6、测试与保存模型

loss = 0
total = 0
correct = 0
with torch.no_grad():
    for data, targets in test_loader:
        data = data.to(device)
        targets = targets.to(device)
        output = cnn(data)
        _,p=output.max(1)
        loss += los(output, targets)
        correct += (p == targets).sum()
        total += data.size(0)
loss = loss.item()/len(test_loader)
acc = correct.item()/total
print(loss,acc)
#保存模型
torch.save(cnn.state_dict(), 'model.pt')

7、加载模型与预测

这里可以展示看一下模型的具体详细信息

#加载模型
model = CNN().to(device='mps')
model.load_state_dict(torch.load('model.pt', map_location=torch.device('mps')))
model.eval()

 接着通过画图软件在本地制作了0-9十个手写数字图片,格式为bmp灰度图,为了方便输入,在制作的时候就将大小设置为了28*28这样输入的时候就不用再做处理。保存之后进行读取看下效果。

#读取本地图片
plt.rcParams['font.sans-serif']='Heiti TC'
plt.rcParams['axes.unicode_minus'] = False  # 负号正常显示
fig = plt.figure(figsize=(10, 4)) 
plt.title("手写的数字",fontsize=20)
path_img=[]
for i in range(10):
    img=cv2.imread(f"/Users/van/Downloads/bmp/number{i}.bmp",cv2.IMREAD_GRAYSCALE)
    path_img.append(f"/Users/van/Downloads/bmp/number{i}.bmp")
    ax = fig.add_subplot(2, 5, i + 1, xticks=[], yticks=[])
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.imshow(img)
plt.show()

接着就可以将这些图片先转化为torch.tensor类型然后输入到模型了。

#读识别开始识别
import cv2
fig = plt.figure(figsize=(10, 4)) 
plt.xlabel("识别结果",fontsize=20)
j=0
for i in path_img:
    
    img=cv2.imread(i,cv2.IMREAD_GRAYSCALE)
    ax = fig.add_subplot(2, 5, j + 1, xticks=[], yticks=[])
    plt.subplots_adjust(wspace=0, hspace=0.4)
    plt.imshow(img)
    imgtensor=torch.from_numpy(img.reshape((1,1,28,28)))
    inputs = imgtensor.to(torch.float32).to(device)
    output=model(inputs)
    _,p=output.max(1)
    j=j+1
    plt.title(f"预测结果为数字: {p.item()}")

 

 三、总结

        go on!

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值