from matplotlib import pyplot as pltimport torch.optim as optimfrom torch.utils.data import DataLoaderfrom modle.cnn import *#是把一个模块中所有函数都导入进来;from data.dataset import Hyperspectral_DataSet#最高准确率和DEVICE指定best_acc=0train_loss=[]train_acc=[]test_loss=[]test_acc=[]device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')def main(): global best_acc #模型加载和指定 model=Hyperspectral_CNN(num_classes=10) model=model.to(device) #损失函数 优化器 BATCH EPOH指定 criterion=nn.CrossEntropyLoss() optimizer=optim.Adam(model.parameters(),lr=0.0001) BATCH_SIZE=128 EPOCH=100 #数据集预处理 train_db=Hyperspectral_DataSet('./hyperspectral-files/PaviaU.mat', './hyperspectral-files/PaviaU_gt.mat') #数据集划分和加载 train_db,val_db= torch.utils.data.random_split(train_db,[int(len(train_db)*0.8),int(len(train_db)*0.2)]) train_loader=DataLoader(train_db,batch_size=BATCH_SIZE,shuffle=True) test_loader=DataLoader(val_db,batch_size=BATCH_SIZE) for epoch in range(EPOCH): model.train() count = 0. all = 0 total_loss = 0. for batch_idx, (input, target) in enumerate(train_loader): #all计算 all+=input.shape[0] #把数据加载到GPU input=input.to(device) target=target.to(device) #input进行modle运算 output=model(input) #loss计算 loss=criterion(output,target) #预测结果求每一行最大的列标号 output=torch.argmax(output,1) #计算一个批次预测对的数量 count+=(output==target).sum() #total_loss提取 total_loss+=loss # 先将梯度归零 optimizer.zero_grad() # 反向传播计算得到每个参数的梯度值 loss.backward() # 通过梯度下降执行一步参数更新 optimizer.step() train_loss.append(total_loss ) train_acc.append(count / all) print('train_acc:',count/all) model.eval() count = 0. all = 0 loss = 0.0 #TEST不用梯度反向传播 with torch.no_grad(): for (input, target) in test_loader: #all计算 all+=input.shape[0] #把数据加载到GPU input.to(device) target.to(device) #input进行modle运算 output=model(input) #loss计算 loss=criterion(output,target) #预测结果求每一行最大的列标号 output=torch.argmax(output,1) #计算一个批次预测对的数量 count+=(output==target).sum() pass acc=count/all print('test_acc:',acc) test_loss.append(loss) test_acc.append(acc) #计算BEST_ACC if acc > best_acc: best_acc=acc torch.save(model.state_dict(),'./hyperspectral-checkpoints/best-{}.pth'.format(best_acc)) pass fig,ax =plt.subplots(1,2) ax[0].plot(range(EPOCH), train_loss, label='train_loss') ax[0].plot(range(EPOCH), test_loss, label='test_loss') ax[0].legend() ax[1].plot(range(EPOCH), train_acc, label='train_acc') ax[1].plot(range(EPOCH), test_acc, label='test_acc') ax[1].legend() #设置图例 plt.savefig('./result.png') passif __name__ == '__main__': main()
2021-10-31
最新推荐文章于 2023-11-01 17:58:35 发布