# Optimizer selection
# Optimizer setup.
# Adam on the trainable parameters; StepLR decays the learning rate to 1/10
# of its value every 7 epochs.
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# The model's final layer already applies LogSoftmax(), so the matching loss
# is NLLLoss; nn.CrossEntropyLoss() would apply LogSoftmax a second time
# (CrossEntropyLoss == LogSoftmax + NLLLoss).
criterion = nn.NLLLoss()  # BUG FIX: was nn.NULLoss(), which does not exist
# Training module — code follows (to be continued)
# Training module
import torch
from torchvision import models
import torch.optim as optim
import torch.nn as nn
import time
import copy
def train_model(model,dataloaders,criterion,optimizer,num_class=30,is_inception=False,filename=filename):
    """Train and validate a model, checkpointing the best validation weights.

    Args:
        model: network to train; moved onto the module-level ``device``.
        dataloaders: dict with 'train' and 'valid' DataLoaders.
        criterion: loss function (expects log-probabilities, e.g. nn.NLLLoss).
        optimizer: optimizer over the trainable parameters.
        num_class: number of epochs to run. NOTE(review): despite its name
            this counts epochs, not classes; name kept for caller compatibility.
        is_inception: if True, combine the main and auxiliary Inception losses.
        filename: checkpoint path (defaults to the module-level ``filename``).

    Returns:
        (model, val_acc_history, train_acc_history, valid_losses,
        train_losses, LRs), with the best validation weights loaded into model.
    """
    since = time.time()
    best_acc = 0
    model.to(device)  # NOTE(review): relies on a module-level `device`
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]
    best_model_wts = copy.deepcopy(model.state_dict())
    for epoch in range(num_class):
        print("Epoch{}/{}".format(epoch, num_class - 1))
        print('-' * 10)
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_correct = 0
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # Reset accumulated gradients before each batch.
                optimizer.zero_grad()
                # Track gradients only during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    if is_inception and phase == 'train':
                        # Inception returns (main, aux) outputs in train mode;
                        # the auxiliary loss is weighted by 0.4.
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:
                        # BUG FIX: was `model(input)` — the builtin `input`,
                        # not the batch tensor.
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                    # Predicted class = index of the largest output.
                    _, preds = torch.max(outputs, 1)
                    # Update weights only while training.
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()  # BUG FIX: was optimizer.stap()
                running_loss += loss.item() * inputs.size(0)
                # BUG FIX: was torch.sum(preds=labels.data) — a keyword-arg
                # typo; the intended expression is the elementwise comparison.
                running_correct += torch.sum(preds == labels.data)
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_correct.double() / len(dataloaders[phase].dataset)
            time_elapsed = time.time() - since
            # BUG FIX: format string used (:.0f) instead of {:.0f}, so the
            # elapsed-time values were never substituted.
            print("Time elapsed{:.0f}m{:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
            print('{}loss:{:.4f}Acc:{:.4f}'.format(phase, epoch_loss, epoch_acc))
            # Checkpoint whenever validation accuracy improves.
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                # NOTE(review): uses the module-level `scheduler`; passing the
                # epoch index to step() is deprecated in current PyTorch but
                # kept here to preserve behavior — confirm before changing.
                scheduler.step(epoch)
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)
        print('Optimizer learning rate: {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m{:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print("Best val Acc:{:4f}".format(best_acc))
    # Restore the best validation weights before returning.
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs