# First, download the MNIST dataset, extract it, and place the files in the current directory (./).
import numpy as np
import struct
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
#数据读入部分
def readfile(tgt):
    """Load one MNIST split from the IDX files in the current directory.

    Reads ``./{tgt}-images.idx3-ubyte`` and ``./{tgt}-labels.idx1-ubyte``.
    The number of samples and the image size are taken from each file's
    header instead of a hard-coded table, so any file prefix works, not
    only ``'train'`` and ``'t10k'``.

    Args:
        tgt: File-name prefix of the split, e.g. ``'train'`` or ``'t10k'``.

    Returns:
        A tuple ``(im, label)``. ``im`` is a list with one 2-D integer array
        of shape ``(rows, cols)`` per image; ``label`` is a list of ints.

    Raises:
        OSError: If one of the two files cannot be opened.
        struct.error: If a file is shorter than its header.
        ValueError: If a magic number is wrong, or a file holds fewer bytes
            than its header declares.
    """
    image_magic = 2051  # IDX: unsigned-byte data, 3 dimensions
    label_magic = 2049  # IDX: unsigned-byte data, 1 dimension

    def get_image(buf1):
        """Decode an idx3-ubyte buffer into a list of 2-D arrays."""
        magic, count, rows, cols = struct.unpack_from('>IIII', buf1, 0)
        if magic != image_magic:
            raise ValueError(
                f'{tgt}-images.idx3-ubyte: unexpected magic number {magic}'
            )
        pixels = np.frombuffer(
            buf1,
            dtype=np.uint8,
            count=count * rows * cols,
            offset=struct.calcsize('>IIII'),
        )
        # Widen to the default integer dtype: this matches what building the
        # arrays from Python ints produces and avoids uint8 wrap-around in
        # any arithmetic done on the images later.
        images = pixels.reshape(count, rows, cols).astype(int)
        return list(images)

    def get_label(buf2):
        """Decode an idx1-ubyte buffer into a list of Python ints."""
        magic, count = struct.unpack_from('>II', buf2, 0)
        if magic != label_magic:
            raise ValueError(
                f'{tgt}-labels.idx1-ubyte: unexpected magic number {magic}'
            )
        labels = np.frombuffer(
            buf2,
            dtype=np.uint8,
            count=count,
            offset=struct.calcsize('>II'),
        )
        return labels.tolist()

    with open(f'./{tgt}-images.idx3-ubyte', 'rb') as f1:
        buf1 = f1.read()
    im = get_image(buf1)
    with open(f'./{tgt}-labels.idx1-ubyte', 'rb') as f2:
        buf2 = f2.read()
    label = get_label(buf2)
    return im, label
# File-name prefixes of the two dataset splits passed to readfile().
# (Plain ASCII quotes: typographic quotes are not valid string delimiters.)
train = 'train'
test = 't10k'
# Each call returns an (images, labels) pair.
a = readfile(train)
b = readfile(test)
# Stack the per-image arrays into one array, convert it to a float tensor and
# insert a size-1 dimension at index 1 (the single input channel of conv1).
X_train = a[0]
X_train = np.stack(X_train)
X_train = torch.tensor(X_train).float()
X_train.unsqueeze_(1)
y = a[1]
# Seed torch's RNG before the network is built so runs are repeatable.
torch.manual_seed(0)
#定义网络,并实现sklearn风格接口
class Net(nn.Module):
def __init__(self, *args, **kwargs):
    """Create the layers of the network.

    The layers are a 3x3 single-channel convolution, 2x2 max pooling,
    a flatten step, a 169 -> 10 linear layer and a 10 -> 10 linear layer.
    Extra positional and keyword arguments are accepted but not used.
    """
    super().__init__()
    # Creation order is kept as-is: parameter initialisation draws from the
    # global RNG, so reordering would change the seeded initial weights.
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 3))
    self.maxpoll1 = nn.MaxPool2d(kernel_size=(2, 2))
    self.flatten = nn.Flatten()
    self.fc1 = nn.Linear(in_features=169, out_features=10)
    self.fc2 = nn.Linear(in_features=10, out_features=10)
def forward(self, x, *args, **kwargs):
    """Run the network on a batch and return softmax scores.

    Args:
        x: Input tensor for ``self.conv1``. After convolution, pooling and
            flattening each sample must have 169 features, which is what a
            single-channel 28x28 image yields.
        *args: Ignored.
        **kwargs: Ignored.

    Returns:
        A 2-D tensor with 10 columns; every row is normalised by softmax
        and therefore sums to 1.
    """
    x = self.conv1(x)
    x = self.maxpoll1(x)
    x = self.flatten(x)
    x = self.fc1(x)
    # The same fc2 layer is applied five times in a row, with no activation
    # in between.
    for _ in range(5):
        x = self.fc2(x)
    # Pass dim explicitly: the implicit choice is deprecated and warns on
    # every call. For the 2-D tensor produced above it resolves to dim=1, so
    # the result is unchanged.
    # NOTE(review): fit() passes this output to nn.CrossEntropyLoss, which
    # normally expects unnormalised logits - confirm the softmax is intended.
    x = F.softmax(x, dim=1)
    return x
def fit(self, X, y, epochs=10,batchsize = 10,need_categorize=True):
if need_categorize:
y = self.categorize(y).long()
self.criterion = nn.CrossEntropyLoss()
self.optimizer = torch.optim.Adam(self.parameters())
lss = []
self.train()
for i in tqdm(range(epochs)):
ls = 0
start = 0
end = batchsize
cnt = 0
while(1):
if start >= X.shape[0] - 1:
break
if end >= X.shape[0] - 1:
end = X.shape[0] - 1
y_pred = self.forward(X[start:end,:,:,:])
# print(y[start:end,:].argmax(1))
self.loss = self.criterion(y_pred, y[start:end,:].argmax(1))