# Training and testing code

from model.lstm import LSTM
from config import args
import torch
import time
import pandas as pd
from utils.pre_process import load_data, to_transpose_lstm
import torch.utils.data as Data
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix
import seaborn as sns

# Modulation signal classes: map each modulation's byte-string name to its
# integer label (list order defines the label encoding).
_MOD_NAMES = [b'QAM16', b'QAM64', b'8PSK', b'WBFM', b'BPSK',
              b'CPFSK', b'AM-DSB', b'GFSK', b'PAM4', b'QPSK', b'AM-SSB']
classes = {name: label for label, name in enumerate(_MOD_NAMES)}

# Run on CPU; swap in the commented line below for single-GPU training.
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
def train_valid(model, args, train_dataloader, valid_dataloader, criterion, optimizer,
                save_path="../MulScale_CLDNN-25.pkl"):
    """Train and validate ``model`` for ``args.epochs`` epochs.

    After every epoch the validation accuracy is computed; whenever it ties
    or beats the best seen so far, the model weights are checkpointed to
    ``save_path``.  Before returning, the best checkpoint is reloaded.

    Args:
        model: the network to train (called as ``model(inputs)``).
        args: namespace providing ``epochs``.
        train_dataloader: yields ``(inputs, target)`` training batches.
        valid_dataloader: yields ``(inputs, target)`` validation batches.
        criterion: loss function, e.g. ``nn.CrossEntropyLoss``.
        optimizer: optimizer stepping ``model``'s parameters.
        save_path: checkpoint file for the best weights (new parameter;
            defaults to the original hard-coded path, so existing callers
            are unaffected).

    Returns:
        tuple: ``(model, train_process)`` — the model with its best weights
        loaded, and a DataFrame with one row per epoch containing
        train/validation loss and accuracy.
    """
    train_loss_all, train_acc_all = [], []
    valid_loss_all, valid_acc_all = [], []
    best_acc = 0.0

    for epoch in range(args.epochs):
        print('Epoch {}/{}'.format(epoch, args.epochs - 1))
        print('-' * 50)

        train_loss, train_corrects, train_num = 0.0, 0, 0
        valid_loss, valid_corrects, valid_num = 0.0, 0, 0

        model.train()  # training phase
        for inputs, target in train_dataloader:   # renamed: `input` shadowed the builtin
            output = model(inputs)
            pre_lab = torch.argmax(output, 1)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight the batch loss by batch size so the epoch mean is exact
            # even when the last batch is smaller.
            train_loss += loss.item() * inputs.size(0)
            train_corrects += torch.sum(pre_lab == target).item()
            train_num += inputs.size(0)

        model.eval()  # validation phase — no gradients needed
        with torch.no_grad():
            for inputs, target in valid_dataloader:
                output = model(inputs)
                pre_lab = torch.argmax(output, 1)
                loss = criterion(output, target)
                valid_loss += loss.item() * inputs.size(0)
                valid_corrects += torch.sum(pre_lab == target).item()
                valid_num += inputs.size(0)

        train_loss_all.append(train_loss / train_num)
        train_acc_all.append(train_corrects / train_num)
        valid_loss_all.append(valid_loss / valid_num)
        valid_acc_all.append(valid_corrects / valid_num)

        print('{} Train Loss: {:.4f}  Train Acc: {:.4f}'.format(
            epoch, train_loss_all[-1], train_acc_all[-1]))
        print('{} Val Loss: {:.4f}  val Acc: {:.4f}'.format(
            epoch, valid_loss_all[-1], valid_acc_all[-1]))

        # Checkpoint the weights whenever validation accuracy reaches a new best.
        if valid_acc_all[-1] >= best_acc:
            best_acc = valid_acc_all[-1]
            # _use_new_zipfile_serialization=False keeps the file loadable by
            # PyTorch <= 1.5 (1.6+ defaults to the zip-based format).
            torch.save(model.state_dict(), save_path,
                       _use_new_zipfile_serialization=False)

    # Restore the parameters of the best checkpoint.
    model.load_state_dict(torch.load(save_path))

    # BUG FIX: use range(args.epochs). Each metric list holds exactly one
    # entry per epoch; the original range(args.epochs + 1) had one extra
    # element and made the DataFrame constructor raise a length mismatch.
    train_process = pd.DataFrame(
        data={"epoch": range(args.epochs),
              "train_loss_all": train_loss_all,
              "val_loss_all": valid_loss_all,
              "train_acc_all": train_acc_all,
              "val_acc_all": valid_acc_all})
    return model, train_process

if __name__ == '__main__':
    # Load the dataset splits: modulation names, SNR values, per-sample
    # labels, and the train/val/test tensors plus their index lists.
    (mods, snrs, lbl), (X_train, Y_train), (X_val, Y_val), (X_test, Y_test), \
        (train_idx, val_idx, test_idx) = load_data()

    # BUG FIX: shape[1:] gives the per-sample shape; shape[1] is a plain int
    # and list(int) raises TypeError.
    in_shp = list(X_train.shape[1:])
    classes = mods

    # BUG FIX: pair each split's inputs with its own labels. The original
    # code built TensorDataset(X_train, X_test) and (Y_train, Y_test),
    # mixing inputs from one split with inputs/labels of another.
    train_data = Data.TensorDataset(X_train, Y_train)
    valid_data = Data.TensorDataset(X_val, Y_val)
    train_dataloader = Data.DataLoader(dataset=train_data,
                                       batch_size=args.batch_size,
                                       shuffle=True,     # reshuffle every epoch
                                       num_workers=1)
    valid_dataloader = Data.DataLoader(dataset=valid_data,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=1)

    model = LSTM()
    # BUG FIX: was "mdoel = model.to(device)" (typo) — the moved model was
    # bound to a dead name.
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

    model, train_process = train_valid(model, args, train_dataloader,
                                       valid_dataloader, criterion, optimizer)

    # NOTE(review): the original plotted train_process.lr_list here, but
    # train_valid never records a learning-rate column, so that figure
    # raised AttributeError — it has been removed.

    # Visualize the loss and accuracy curves of the training process.
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(train_process.epoch, train_process.train_loss_all,
             "ro-", label="Train loss")
    plt.plot(train_process.epoch, train_process.val_loss_all,
             "bs-", label="Val loss")
    plt.legend()
    plt.xlabel("epoch")
    plt.ylabel("Loss")
    plt.subplot(1, 2, 2)
    plt.plot(train_process.epoch, train_process.train_acc_all,
             "ro-", label="Train acc")
    plt.plot(train_process.epoch, train_process.val_acc_all,
             "bs-", label="Val acc")
    plt.xlabel("epoch")
    plt.ylabel("acc")
    plt.legend()
    plt.show()

    # (The redundant second load_data() call was removed — the test split
    # from the load above is reused directly.)
    print("X_test.shape:", X_test.shape)
    print('-' * 50)

    # Test-set accuracy as a function of signal-to-noise ratio.
    acc = {}
    model.eval()
    # Hoisted out of the loop: the SNR of every test sample.
    # NOTE(review): assumes lbl[i] is the SNR value itself — confirm against
    # load_data (some RadioML loaders store (modulation, snr) pairs).
    test_SNRs = np.array([lbl[i] for i in test_idx])
    for snr in snrs:
        mask = np.where(test_SNRs == snr)
        test_X_i = X_test[mask]
        test_Y_i = Y_test[mask]
        output = model(test_X_i)
        pre_lab = torch.argmax(output, 1).cpu()  # index of the max logit
        acc[snr] = accuracy_score(test_Y_i, pre_lab)
        print(acc[snr])
    plt.figure(3)
    plt.plot(snrs, [acc[s] for s in snrs])
    plt.xlabel("Signal to Noise Ratio")
    plt.ylabel("Classification Accuracy")
    plt.title("LSTM Classification Accuracy on RadioML 2016.10 Alpha")
    plt.yticks(np.linspace(0, 1, 11))

    # Overall test-set accuracy (the dead "acc_total = {}" init was removed).
    output = model(X_test)
    pre_lab = torch.argmax(output, 1).cpu()
    acc_total = accuracy_score(Y_test, pre_lab)
    print("在测试集上的预测精度为:", acc_total)

    # Compute and visualize the confusion matrix.
    plt.figure(4)
    conf_mat = confusion_matrix(Y_test, pre_lab)
    df_cm = pd.DataFrame(conf_mat, index=classes,
                         columns=classes)
    heatmap = sns.heatmap(df_cm, annot=True, fmt="d", cmap="YlGnBu")
    heatmap.yaxis.set_ticklabels(
        heatmap.yaxis.get_ticklabels(), rotation=0, ha='right')
    heatmap.xaxis.set_ticklabels(
        heatmap.xaxis.get_ticklabels(), rotation=45, ha='right')
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()






# (removed: CSDN blog-page footer and payment-widget text that was
# accidentally pasted into the source file — not part of the program)