文本分类半监督学习问题(八)

2021SC@SDUSC

我们现在来进去对于normal_train.py文件的分析,本文件的主要目的是进行bert模型的训练,前面已经对于代码报错的部分进行了修改,现在进行分析:

首先创建解释器:使用argparser的第一步是创建一个ArgumentParser对象,ArgumentParser对象包含将命令行解析成Python数据类型所需的全部信息。

argparse是Python标准库中推荐使用的编写命令行程序的工具,是一个用来解析命令行程序的参数的模块。其实就是我们编写在linux中常见的命令行程序中带的参数的处理程序。这些参数理论上可以分为定位参数(positional arguments)和可选参数。

prog - 程序的名称(默认: sys.argv[0],prog猜测是programma的缩写)

usage - 描述程序用途的字符串(默认值:从添加到解析器的参数生成)

description - 在参数帮助文档之后显示的文本 (默认值:无)

parser = argparse.ArgumentParser(description='PyTorch Base Models')

其次添加参数:

给一个ArgumentParser添加程序阐述信息是通过调用add_arguement()方法完成的。

name or flags - 一个命名或者一个选项字符串的列表

action - 表示该选项要执行的操作

default - 当参数未在命令行中出现时使用的值

dest - 用来指定参数的位置

type - 为参数类型,例如int

choices - 用来选择输入参数的范围。例如choice = [1, 5, 10], 表示输入参数只能为1,5 或10

help - 用来描述这个选项的作用

parser.add_argument('--epochs', default=50, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--batch-size', default=4, type=int, metavar='N',
                    help='train batchsize')
parser.add_argument('--batch-size-u', default=24, type=int, metavar='N',
                    help='train batchsize')

parser.add_argument('--lrmain', '--learning-rate-bert', default=0.00001, type=float,
                    metavar='LR', help='initial learning rate for bert')
parser.add_argument('--lrlast', '--learning-rate-model', default=0.001, type=float,
                    metavar='LR', help='initial learning rate for models')

parser.add_argument('--gpu', default='0,1,2,3', type=str,
                    help='id(s) for CUDA_VISIBLE_DEVICES')

parser.add_argument('--n-labeled', type=int, default=20,
                    help='Number of labeled data')
parser.add_argument('--val-iteration', type=int, default=200,
                    help='Number of labeled data')


parser.add_argument('--mix-option', default=False, type=bool, metavar='N',
                    help='mix option')
parser.add_argument('--train_aug', default=False, type=bool, metavar='N',
                    help='aug for training data')


parser.add_argument('--model', type=str, default='bert-base-uncased',
                    help='pretrained model')

parser.add_argument('--data-path', type=str, default='/Users/wuzehao/Desktop/科研/文本分类/MixText-master/data/yahoo_answers_csv/',
                    help='path to data folders')

解析参数

args = parser.parse_args()

接下来则是使用torch框架,调用GPU

args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
print("gpu num: ", n_gpu)

best_acc = 0

下面是进行测试集数据,训练集数据,验证集数据的加载

    train_labeled_set, train_unlabeled_set, val_set, test_set, n_labels = get_data(
        args.data_path, args.n_labeled)
    labeled_trainloader = Data.DataLoader(
        dataset=train_labeled_set, batch_size=args.batch_size, shuffle=True)
    val_loader = Data.DataLoader(
        dataset=val_set, batch_size=512, shuffle=False)
    test_loader = Data.DataLoader(
        dataset=test_set, batch_size=512, shuffle=False)

模型通过cuda进行训练:

    model = ClassificationBert(n_labels).cuda()
    model = nn.DataParallel(model)
    optimizer = AdamW(
        [
            {"params": model.module.bert.parameters(), "lr": args.lrmain},
            {"params": model.module.linear.parameters(), "lr": args.lrlast},
        ])


    criterion = nn.CrossEntropyLoss()

    test_accs = []

下面是对于损失函数的计算:
​​​​​​​

ef validate(valloader, model, criterion, epoch, mode):
    model.eval()
    with torch.no_grad():
        loss_total = 0
        total_sample = 0
        acc_total = 0
        correct = 0

        for batch_idx, (inputs, targets, length) in enumerate(valloader):
            inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            _, predicted = torch.max(outputs.data, 1)

            correct += (np.array(predicted.cpu()) ==
                        np.array(targets.cpu())).sum()
            loss_total += loss.item() * inputs.shape[0]
            total_sample += inputs.shape[0]

        acc_total = correct/total_sample
        loss_total = loss_total/total_sample

    return loss_total, acc_total

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值