# -*- coding: utf-8 -*-
"""
Created on Sat Jul 18 12:27:15 2020
@author: 陨星落云
"""
#%%
from torchvision import datasets
import torch
#%% 下载数据并加载训练集
path2data = "./data"
train_data = datasets.MNIST(path2data,train=True,download=False)
#%% 抽取训练集数据与标签
x_train,y_train = train_data.data,train_data.targets
print("x_train:",x_train.shape)
print("y_train:",y_train.shape)
#%% 加载验证集
val_data = datasets.MNIST(path2data,train=False,download=False)
#%% 抽取验证集数据与标签
x_val,y_val = val_data.data,val_data.targets
print("x_val:",x_val.shape)
print("y_val:",y_val.shape)
#%% 在张量中增加一个维度
if len(x_train.shape)==3:
x_train = x_train.unsqueeze(1)
print(