# Train the model and save it.
import torch
import torch. nn as nn
import torch. optim as optim
from torchvision import datasets, transforms, models
from torch. utils. data import Dataset
import sys
# Augmentation pipeline applied to every image loaded with `transform`:
# a random 224x224 crop, a random rotation of up to 20 degrees, a
# horizontal flip with probability 0.5, then conversion to a tensor.
transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=224),
        transforms.RandomRotation(degrees=20),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor(),
    ]
)
# Dataset root directory; the "train" and "test" sub-folders beneath it are
# read with ImageFolder (one sub-directory per class).
root = "image"

# Evaluation should be deterministic: the random crop / rotation / flip used
# for training would make test metrics vary between runs and would score the
# model on distorted images. The test split therefore gets a fixed resize and
# centre crop to the same 224x224 input size used for training.
test_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)

train_dataset = datasets.ImageFolder(root + "/train", transform)
test_dataset = datasets.ImageFolder(root + "/test", test_transform)

# Mini-batches of 8 images. Shuffling of the test loader is kept as in the
# original script; it does not affect aggregate metrics.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=True)

# Class names and the class-name -> integer-label mapping, both taken from
# the training split.
classes = train_dataset.classes
classes_index = train_dataset.class_to_idx
print