Using BERT for Text Classification

The full classification script is shown below.

import json
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification

def load_model(pretrained_path, model_path, device):
    # The tokenizer comes from the pretrained base model; the classifier
    # weights come from the fine-tuned checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.to(device)
    model.eval()
    return tokenizer, model

def predict_text_category_batch(texts, model, tokenizer, device):
    # Tokenize the whole batch at once, padding to the longest sequence
    # and truncating anything beyond 512 tokens.
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    # One predicted class index per input text.
    predicted_classes = torch.argmax(outputs.logits, dim=1).cpu().tolist()
    return predicted_classes

def classify_jsonl(input_file, output_files, model, tokenizer, device, batch_size=1024):
    # Read the whole file once; ignore undecodable bytes rather than aborting.
    with open(input_file, 'rb') as f:
        lines = f.read().decode('utf-8', errors='ignore').splitlines()
    num_lines = len(lines)

    # One output file handle per class index.
    file_handlers = {i: open(output_file, 'w', encoding='utf-8') for i, output_file in output_files.items()}

    try:
        for i in tqdm(range(0, num_lines, batch_size)):
            batch_lines = lines[i:i+batch_size]
            texts = []
            valid_entries = []
            for line in batch_lines:
                try:
                    entry = json.loads(line.strip())
                    texts.append(entry['text'])
                    valid_entries.append(entry)
                except (json.JSONDecodeError, KeyError):
                    print(f"Skipping invalid JSON line: {line}")
                    continue

            # Skip batches where every line was invalid.
            if not texts:
                continue

            predicted_classes = predict_text_category_batch(texts, model, tokenizer, device)

            # Route each entry to the output file of its predicted class.
            for entry, predicted_class in zip(valid_entries, predicted_classes):
                file_handlers[predicted_class].write(json.dumps(entry, ensure_ascii=False) + '\n')

    finally:
        for fh in file_handlers.values():
            fh.close()

# Configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pretrained_path = './model/base_model/tinybert_6L_zh/'  # update to your pretrained base model path
model_path = './model/save_model_6/third_tinybert_6L_zh_epoch_3/'  # update to your fine-tuned model path

# Load the model and tokenizer
tokenizer, model = load_model(pretrained_path, model_path, device)

# File paths
input_jsonl = './output_class_1_2.jsonl'
output_files = {
    0: './muti_data_1/output_class_0.jsonl',
    1: './muti_data_1/output_class_1.jsonl',
    2: './muti_data_1/output_class_2.jsonl',
    3: './muti_data_1/output_class_3.jsonl'
}

# Run classification
classify_jsonl(input_jsonl, output_files, model, tokenizer, device)
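
For reference, each line of the input JSONL is expected to be a standalone JSON object with a text field; any other keys in the object are passed through to the output file unchanged. A minimal sanity check, assuming the model and tokenizer have been loaded as above (the sample sentences below are made up for illustration):

# Each input line looks like: {"text": "..."}
sample_texts = ["今天天气不错", "这是一条测试文本"]  # hypothetical example sentences
print(predict_text_category_batch(sample_texts, model, tokenizer, device))
# prints one class index (0-3) per input text, e.g. [1, 2]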