2021-10-31

from matplotlib import  pyplot as pltimport torch.optim as optimfrom torch.utils.data import DataLoaderfrom modle.cnn import *#是把一个模块中所有函数都导入进来;from data.dataset import Hyperspectral_DataSet#最高准确率和DEVICE指定best_acc=0train_loss=[]train_acc=[]test_loss=[]test_acc=[]device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')def main():    global best_acc    #模型加载和指定    model=Hyperspectral_CNN(num_classes=10)    model=model.to(device)    #损失函数 优化器 BATCH EPOH指定    criterion=nn.CrossEntropyLoss()    optimizer=optim.Adam(model.parameters(),lr=0.0001)    BATCH_SIZE=128    EPOCH=100    #数据集预处理    train_db=Hyperspectral_DataSet('./hyperspectral-files/PaviaU.mat',                                   './hyperspectral-files/PaviaU_gt.mat')    #数据集划分和加载    train_db,val_db= torch.utils.data.random_split(train_db,[int(len(train_db)*0.8),int(len(train_db)*0.2)])    train_loader=DataLoader(train_db,batch_size=BATCH_SIZE,shuffle=True)    test_loader=DataLoader(val_db,batch_size=BATCH_SIZE)    for epoch in range(EPOCH):        model.train()        count = 0.        all = 0        total_loss = 0.        for batch_idx, (input, target) in enumerate(train_loader):            #all计算            all+=input.shape[0]            #把数据加载到GPU            input=input.to(device)            target=target.to(device)            #input进行modle运算            output=model(input)            #loss计算            loss=criterion(output,target)            #预测结果求每一行最大的列标号            output=torch.argmax(output,1)            #计算一个批次预测对的数量            count+=(output==target).sum()            #total_loss提取            total_loss+=loss            # 先将梯度归零            optimizer.zero_grad()            # 反向传播计算得到每个参数的梯度值            loss.backward()            # 通过梯度下降执行一步参数更新            optimizer.step()        train_loss.append(total_loss )        train_acc.append(count / all)        print('train_acc:',count/all)        model.eval()        count = 0.        all = 0        loss = 0.0        #TEST不用梯度反向传播        with torch.no_grad():            for (input, target) in test_loader:                #all计算                all+=input.shape[0]                #把数据加载到GPU                input.to(device)                target.to(device)                #input进行modle运算                output=model(input)                #loss计算                loss=criterion(output,target)                #预测结果求每一行最大的列标号                output=torch.argmax(output,1)                #计算一个批次预测对的数量                count+=(output==target).sum()                pass            acc=count/all        print('test_acc:',acc)        test_loss.append(loss)        test_acc.append(acc)        #计算BEST_ACC        if acc > best_acc:            best_acc=acc            torch.save(model.state_dict(),'./hyperspectral-checkpoints/best-{}.pth'.format(best_acc))            pass    fig,ax =plt.subplots(1,2)    ax[0].plot(range(EPOCH), train_loss, label='train_loss')    ax[0].plot(range(EPOCH), test_loss, label='test_loss')    ax[0].legend()    ax[1].plot(range(EPOCH), train_acc, label='train_acc')    ax[1].plot(range(EPOCH), test_acc, label='test_acc')    ax[1].legend()  #设置图例    plt.savefig('./result.png')    passif __name__ == '__main__':    main()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值