下面展示一些 `内联代码片`。
// A code block
var foo = 'bar';
import json
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
def load_model(pretrained_path, model_path, device):
    """Load the tokenizer and the sequence-classification model for inference.

    Args:
        pretrained_path: Location handed to ``AutoTokenizer.from_pretrained``.
        model_path: Location handed to
            ``AutoModelForSequenceClassification.from_pretrained``.
        device: Device the model is moved to.

    Returns:
        A ``(tokenizer, model)`` tuple. The model has been moved to ``device``
        and switched to evaluation mode.
    """
    text_tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
    classifier = AutoModelForSequenceClassification.from_pretrained(model_path)
    # Inference only: place the weights on the target device and call eval().
    classifier.to(device)
    classifier.eval()
    return text_tokenizer, classifier
def predict_text_category_batch(texts, model, tokenizer, device):
    """Predict one class index per text for a batch of texts.

    Args:
        texts: Batch of texts to classify.
        model: Callable invoked as ``model(input_ids, attention_mask=...)``
            whose result exposes a ``logits`` tensor.
        tokenizer: Callable invoked on ``texts`` that returns a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors.
        device: Device the input tensors are moved to before the forward pass.

    Returns:
        A list of predicted class indices (argmax of the logits along
        dimension 1), in the same order as ``texts``. An empty batch yields
        an empty list.
    """
    # An empty batch has nothing to tokenize; return early instead of relying
    # on the tokenizer and model accepting zero rows. A bare string is left to
    # the tokenizer, exactly as before.
    if not isinstance(texts, str) and len(texts) == 0:
        return []

    # Pad to the longest text in the batch and truncate at 512 tokens.
    inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # No gradients are needed for inference.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    # Move the result to the CPU so it can be returned as plain Python ints.
    return torch.argmax(outputs.logits, dim=1).cpu().tolist()
def classify_jsonl(input_file, output_files, model, tokenizer, device, batch_size=1024):
    """Split a JSONL file into one output file per predicted class.

    Each input line is parsed as JSON and its ``'text'`` value is classified
    with ``predict_text_category_batch``. The record is then re-serialized
    (``ensure_ascii=False``) into the output file registered for the
    predicted class. Lines that are not valid JSON, or that have no
    ``'text'`` field, are reported on stdout and skipped.

    Args:
        input_file: Path of the JSONL file to read.
        output_files: Mapping from class index to output file path. Every
            listed file is opened for writing (and truncated) up front.
        model: Classifier forwarded to ``predict_text_category_batch``.
        tokenizer: Tokenizer forwarded to ``predict_text_category_batch``.
        device: Device forwarded to ``predict_text_category_batch``.
        batch_size: Number of input lines handled per prediction call.

    Raises:
        OSError: If the input file or one of the output files cannot be opened.
        KeyError: If a predicted class has no entry in ``output_files``.
    """
    from contextlib import ExitStack

    # Split on raw newline bytes *before* decoding. str.splitlines() also
    # breaks on characters such as U+2028, U+2029 and U+0085, which may appear
    # unescaped inside a JSON string (json.dumps(..., ensure_ascii=False)
    # writes them as-is) and would cut a valid record in half.
    with open(input_file, 'rb') as f:
        lines = [raw.decode('utf-8', errors='ignore') for raw in f.read().splitlines()]

    # ExitStack closes every output file on exit, including the ones already
    # opened when a later open() fails.
    with ExitStack() as stack:
        file_handlers = {
            label: stack.enter_context(open(path, 'w', encoding='utf-8'))
            for label, path in output_files.items()
        }

        for start in tqdm(range(0, len(lines), batch_size)):
            texts = []
            entries = []
            for line in lines[start:start + batch_size]:
                try:
                    entry = json.loads(line.strip())
                except json.JSONDecodeError:
                    print(f"Skipping invalid JSON line: {line}")
                    continue
                try:
                    text = entry['text']
                except (KeyError, TypeError):
                    # Valid JSON, but not an object with a 'text' field.
                    print(f"Skipping line without a 'text' field: {line}")
                    continue
                texts.append(text)
                # Keep the parsed record so it is not parsed a second time.
                entries.append(entry)

            # Every line of this batch was skipped: nothing to classify.
            if not texts:
                continue

            predicted_classes = predict_text_category_batch(texts, model, tokenizer, device)
            for entry, predicted_class in zip(entries, predicted_classes):
                file_handlers[predicted_class].write(
                    json.dumps(entry, ensure_ascii=False) + '\n'
                )
# Configuration: use CUDA when torch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
pretrained_path = './model/base_model/tinybert_6L_zh/'  # change to your pretrained model path
model_path = './model/save_model_6/third_tinybert_6L_zh_epoch_3/'  # change to your saved model path

# Load the model and tokenizer.
tokenizer, model = load_model(
    pretrained_path=pretrained_path,
    model_path=model_path,
    device=device,
)

# File paths: the input JSONL and one output file per class index.
input_jsonl = './output_class_1_2.jsonl'
output_files = {
    0: './muti_data_1/output_class_0.jsonl',
    1: './muti_data_1/output_class_1.jsonl',
    2: './muti_data_1/output_class_2.jsonl',
    3: './muti_data_1/output_class_3.jsonl',
}

# Run the classification.
classify_jsonl(
    input_file=input_jsonl,
    output_files=output_files,
    model=model,
    tokenizer=tokenizer,
    device=device,
)