1.加载数据集
这里我们加载的是 MNIST 数据集，数据文件我已经提前下载到本地了。
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data as Data
# import keras
# plt.ion()
# Number of samples kept from the training and test splits.
N_TRAIN = 20000
N_TEST = 1000

# np.load on an .npz archive returns a lazily-read NpzFile that keeps the
# underlying file handle open; the context manager closes it as soon as the
# arrays we need have been read into memory.
with np.load('./data/mnist.npz') as data:
    # NOTE(review): the positional lookup assumes the archive lists its
    # arrays in the order (x_test, x_train, y_train, y_test) -- print
    # data.files to confirm this for your copy of the file.
    X_test = data[data.files[0]][:N_TEST]
    X_train = data[data.files[1]][:N_TRAIN]
    y_train = data[data.files[2]][:N_TRAIN]
    y_test = data[data.files[3]][:N_TEST]

# Convert to tensors: float32 inputs and int64 labels. unsqueeze(dim=1)
# inserts a size-1 axis at position 1 (presumably the channel axis expected
# by the model -- confirm against the network definition).
X_train = torch.tensor(X_train, dtype=torch.float32).unsqueeze(dim=1)
y_train = torch.tensor(y_train, dtype=torch.long)
X_test = torch.tensor(X_test, dtype=torch.float32).unsqueeze(dim=1)
y_test = torch.tensor(y_test, dtype=torch.long)
2.模型的搭建
# Training hyper-parameters.
# Number of full passes over the training set.
EPOCH = 2
# Number of samples per mini-batch.
BATCH_SIZE = 64
# Learning rate.
LEARNING_RATE = 0.02

# Pair every training input with its label so both can be batched together.
torch_dataset = Data.TensorDataset(X_train, y_train)
loader = Data.DataLoader(
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True,