完整工程
工程目录结构 Code
import torch
import torch. optim as optim
import torch. nn as nn
from torch. utils. data import DataLoader
from torchvision. datasets import ImageFolder
import torchvision. models as models
import torchvision. transforms as transforms
import numpy as np
import copy
def build_model(num_classes=2):
    """Return a pretrained AlexNet whose last classifier layer is replaced.

    Args:
        num_classes: Number of output logits of the new final layer.

    Returns:
        The torchvision AlexNet with ``classifier[6]`` swapped for a fresh
        ``nn.Linear(4096, num_classes)``.
    """
    # NOTE(review): recent torchvision releases prefer ``weights=`` over
    # ``pretrained=``; kept as-is for compatibility -- confirm target version.
    model = models.alexnet(pretrained=True)
    model.classifier[6] = nn.Linear(in_features=4096, out_features=num_classes)
    return model


def evaluate(model, dataloader, device):
    """Compute top-1 accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode for the duration of the pass and
    restored to its previous train/eval mode afterwards.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Fraction of correctly classified samples, or ``0.0`` when the
        dataloader yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    # Gradients are never used during validation, so skip graph construction.
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            _, prediction = torch.max(model(images), 1)
            correct += (prediction == labels).sum().item()
            total += labels.size(0)
    model.train(was_training)
    return correct / total if total else 0.0


def train(model, train_dataloader, val_dataloader, optimizer, loss_fc,
          scheduler, device, epoch_nums=50, eval_interval=10):
    """Train ``model`` and return the weights with the best validation accuracy.

    Every ``eval_interval`` training batches the model is validated, the mean
    running loss and the accuracy are printed, and the weights are
    snapshotted if the accuracy improved.

    Args:
        model: Module to optimise; expected to already live on ``device``.
        train_dataloader: Iterable yielding ``(inputs, labels)`` batches.
        val_dataloader: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimizer bound to ``model``'s parameters.
        loss_fc: Callable ``(outputs, labels) -> loss`` tensor.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        device: Device every batch is moved to.
        epoch_nums: Number of passes over ``train_dataloader``.
        eval_interval: Number of training batches between validations.

    Returns:
        Tuple ``(best_model_wts, best_acc)``: a deep-copied state dict and
        the validation accuracy it achieved.
    """
    # Deep copy so the snapshot does not alias the live, still-changing
    # parameters if validation accuracy never improves.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0
    for epoch in range(epoch_nums):
        model.train()
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_dataloader):
            # Tensor.to() returns a new tensor; the result must be rebound.
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fc(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % eval_interval == eval_interval - 1:
                accuracy = evaluate(model, val_dataloader, device)
                print('[{}, {}] running loss={:.5f}, accuracy={:.5f}'.format(
                    epoch + 1, i + 1, running_loss / eval_interval, accuracy))
                running_loss = 0.0
                if accuracy > best_acc:
                    best_acc = accuracy
                    best_model_wts = copy.deepcopy(model.state_dict())
        # Step the schedule after the epoch's optimizer updates; stepping it
        # first skips the initial learning-rate value.
        scheduler.step()
    return best_model_wts, best_acc


def main():
    """Fine-tune AlexNet on ./data/train, validate on ./data/val, save best."""
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = build_model()
    model.to(device)

    transform = transforms.Compose([transforms.Resize((227, 227)),
                                    transforms.ToTensor()])
    train_dataset = ImageFolder(root='./data/train', transform=transform)
    val_dataset = ImageFolder(root='./data/val', transform=transform)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=4,
                                  num_workers=4, shuffle=True)
    val_dataloader = DataLoader(dataset=val_dataset, batch_size=4,
                                num_workers=4, shuffle=False)

    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    loss_fc = nn.CrossEntropyLoss()
    # Multiply the learning rate by 0.1 every 20 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, 20, 0.1)

    best_model_wts, _ = train(model, train_dataloader, val_dataloader,
                              optimizer, loss_fc, scheduler, device,
                              epoch_nums=50)
    print('Train finish')
    # NOTE: the ./models directory must already exist for the save to succeed.
    torch.save(best_model_wts, './models/model_50.pth')


# Guard the entry point: DataLoader worker processes (num_workers > 0) may
# re-import this module on platforms that spawn rather than fork.
if __name__ == '__main__':
    main()
https://www.jianshu.com/p/2e5a9bd5ad36