pytorch 预测手写体数字_pytorch CNN 手写数字识别

importtorchimporttorch.nn as nnimporttorch.utils.data as Dataimportglobimportosimportnumpy as npfrom PIL importImageimportdatetimefrom torchvision importtransformsimporttorch.nn.functional as F#6272=8x32x32

EPOCH= 1BATCH_SIZE= 50

classMyNet(nn.Module):def __init__(self):

super(MyNet, self).__init__()

self.con1=nn.Sequential(

nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),

nn.MaxPool2d(kernel_size=2),

nn.ReLU(),

)

self.con2=nn.Sequential(

nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),

nn.MaxPool2d(kernel_size=2),

nn.ReLU(),

)

self.fc=nn.Sequential(#线性分类器

nn.Linear(128*7*7, 128), #修改图片大小后要重新计算

nn.ReLU(),

nn.Linear(128, 10),#nn.Softmax(dim=1),

)

self.mls=nn.MSELoss()

self.opt= torch.optim.Adam(params=self.parameters(), lr=1e-3)

self.start=datetime.datetime.now()defforward(self, inputs):

out=self.con1(inputs)

out=self.con2(out)

out= out.view(out.size(0), -1) #展开成一维

out =self.fc(out)#out = F.log_softmax(out, dim=1)

returnoutdeftrain(self, x, y):

out=self.forward(x)

loss=self.mls(out, y)print('loss:', loss)

self.opt.zero_grad()

loss.backward()

self.opt.step()deftest(self, x):

out=self.forward(x)returnoutclassParseImage(object):def __init__(self):

self.transform1=transforms.Compose([

transforms.ToTensor(),#range [0, 255] -> [0.0,1.0] 归一化

]

)defget_data(self, path):#load_image()

#将图片转为矩阵,标签进行独热编码

x_data =[]

y_data=[]

img_path= glob.glob(path) #图片读取路径

for file inimg_path:

one_hot=[]

img=Image.open(file)#img = self.transform1(img)

#img = transforms.ToPILImage()(img)

data =img.getdata()

data=np.matrix(data)

data= np.reshape(data, (28, 28))#..手动归一化

data = data/255x_data.append(data)

name, ext=os.path.splitext(file)

label= name.split('-')[1]print('label', label)for i in range(10):if str(i) ==label:

one_hot.append(1)else:

one_hot.append(0)

y_data.append(one_hot)#先转为数组,在转为tensor

x_data =np.array(x_data)

y_data=np.array(y_data)

x_data=torch.from_numpy(x_data).float()#输入数据增加频道维度

x_data = torch.unsqueeze(x_data, 1)

y_data=torch.from_numpy(y_data).float()returnx_data, y_dataif __name__ == '__main__':

data=ParseImage()

train_path= 'D:/AI/MR_AIStudy/MNIST/dataset/train/*.png'test_path= 'D:/AI/MR_AIStudy/MNIST/dataset/test/*.png'x_data, y_data=data.get_data(train_path)

net=MyNet()#批训练

torch_dataset =Data.TensorDataset(x_data, y_data)

loader=Data.DataLoader(

dataset=torch_dataset,

batch_size=BATCH_SIZE,

shuffle=True,

num_workers=2,

)for epoch inrange(EPOCH):for step, (batch_x, batch_y) inenumerate(loader):print(step)

net.train(batch_x, batch_y)

torch.save(net,'net.pkl') #存储模型, 全部存储

#只测试的话加载模型即可

model = torch.load('net.pkl') #恢复模型

net =model

test_x, test_y=data.get_data(test_path)

predict=net.test(test_x)print(predict)

end=datetime.datetime.now()print('耗时:{}s'.format(end-net.start))

# 预测结果#tensor([[ 9.1531e-01, -2.5804e-02, 1.2001e-02, 8.3876e-03, -1.6330e-02,#-1.7501e-03, -1.0589e-02, 2.6951e-02, 2.1836e-02, -4.5546e-02],#[-6.4733e-02, 7.7697e-01, 2.2536e-02, 8.3758e-03, 4.2895e-02,#1.1602e-02, -3.0644e-02, 2.2412e-02, 1.1579e-01, 3.2196e-02],#[ 2.6631e-02, -5.3223e-02, 7.9808e-01, 6.0601e-03, 2.2453e-02,#-3.9522e-02, 3.4775e-02, 1.5853e-02, -6.9575e-03, 1.7208e-02],#[-1.3861e-02, -1.8332e-02, 4.9981e-02, 9.6510e-01, -1.5838e-02,#9.0347e-03, 1.9342e-02, -3.8044e-02, -5.7994e-03, 1.4480e-02],#[-2.0864e-03, -5.9021e-02, 6.5524e-02, -2.1486e-02, 1.0074e+00,#9.3356e-03, 1.0758e-02, 6.6142e-02, 1.4841e-02, 2.2529e-03],#[-8.4950e-02, -2.4841e-02, -7.7684e-02, 1.6404e-01, 4.3458e-02,#8.6580e-01, -3.5630e-02, 4.2452e-02, 7.0675e-02, 2.9663e-02],#[-5.4024e-02, -1.7111e-02, -3.7085e-03, 3.8194e-03, -3.0645e-02,#-4.4164e-02, 1.0109e+00, 4.4349e-03, 1.3218e-01, -2.2839e-02],#[-2.0932e-02, 6.4831e-03, -1.3301e-02, 2.8091e-02, -3.0815e-02,#-3.2140e-02, 5.2251e-03, 1.0215e+00, 3.2592e-02, 1.0505e-02],#[ 1.5922e-02, -3.9700e-02, 2.4425e-02, -1.7313e-04, -1.5997e-02,#-5.2336e-02, -7.7526e-04, -2.1901e-02, 9.7167e-01, 1.3339e-01],#[-1.9283e-02, 2.4373e-02, -7.5621e-02, 1.1338e-01, -5.7805e-02,#-5.2936e-03, 1.0090e-03, 2.2471e-02, -3.5736e-02, 1.1243e+00]],#grad_fn=)#耗时:0:09:59.665343s

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值