伪标签技术

流程
  1. 就是先在有lable的train data上进行训练,得到一个模型。
  2. 然后用第一步中训练得到的模型进行inference test data对应的标签类别。
  3. 然后对test data进行筛选,选出超过概率阈值的样本与原来的train data进行拼接组合成新的的训练数据,然后再重新进行训练出最终的模型。
代码(只需要对原代码中infernce部分进行修改就可以了,其他部分不动)
class TestDataset(Dataset):
    def __init__(self, df, tokenizer):
        self.title = df['title'].values
        self.assignee = df['assignee'].values
        self.abstract = df['abstract'].values
        self.tokenizer = tokenizer
        self.sep_token = tokenizer.sep_token

    def __len__(self):
        return len(self.title)

    def __getitem__(self, item):
        title = self.title[item]
        assignee = self.assignee[item]
        abstract = self.abstract[item]
        input_text = title + self.sep_token + assignee + self.sep_token + abstract
        inputs = self.tokenizer(input_text, truncation=True, max_length=400, padding='max_length')
        return torch.as_tensor(inputs['input_ids'], dtype=torch.long), \
               torch.as_tensor(inputs['attention_mask'], dtype=torch.long)

def infer(test_loader, model, device):
    model.to(device)
    model.eval()
    total_logits = []
    for step, batch in tqdm(enumerate(test_loader)):
        mask = batch[1].to(device)
        input_ids = batch[0].to(device)
        with torch.no_grad():
            output = model(input_ids=input_ids, attention_mask=mask)
        logits = F.softmax(output.logits, dim=-1)
        total_logits.append(logits.to('cpu').numpy())
    total_logits = np.concatenate(total_logits)
    return total_logits

res = []
for fold in range(5):
    saved_path = CFG.OUTPUT_DIR + "{}_best{}.pth".format(CFG.model_path.replace('/', '_'),fold)
    model.load_state_dict(torch.load(saved_path)['model'])
    test_dataset = TestDataset(test, tokenizer)
    test_dataloader = DataLoader(test_dataset,
                                batch_size=CFG.batch_size * 2,
                                shuffle=False,
                                num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
    total_logits = infer(test_dataloader, model, CFG.device)
    res.append(total_logits)

total_logits = np.mean(res, axis=1)      #(样本数,num_lables)
pl_prob = np.max(total_logits, axis=-1)  #(样本数,1)  保留每个类别的最大概率
pl = np.argmax(total_logits, axis=-1)    #(样本数,1)  最大值下标,也就是label_id
test['label'] = pl
test['prob'] = pl_prob
test = test[test['prob'] >0.99]      #筛选出置信度大于0.99的test样本,之后加入到train data中进行训练
test.to_csv('add_data.csv', index=None)  
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值