# Save and load a model (保存、加载模型)
import torch
from torch import nn,optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets,transforms
# Preprocessing shared by both splits: ToTensor() converts images to float
# tensors in [0, 1]; Normalize((0.5,), (0.5,)) then computes (x - 0.5) / 0.5
# on the single channel, mapping pixel values into [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training split (download=True fetches it into MNIST_data/ if missing).
trainset = datasets.MNIST('MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# MNIST test split. Fix: the test loader previously wrapped `trainset`, so
# any "test" evaluation was silently measured on the training data.
testset = datasets.MNIST('MNIST_data/', download=True, train=False, transform=transform)
# Backward-compatible alias for the original misspelled module-level name.
tetsset = testset
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
class Network(nn.Module):
def __init__(self,insize,outsize):
super().__init__()
self.insize = insize
self.outsize = outsize
self.fc1 = nn.Linear(insize,256)
self.fc2 = nn.Linear(256,128)
self.fc3 = nn.Linea