NLP Project 2: Chinese Text Classification
Chinese text classification with Huggingface
1. Define a Torch Dataset
import torch
from datasets import load_from_disk
class Dataset(torch.utils.data.Dataset):
    def __init__(self, split):
        # Load the local copy of ChnSentiCorp and pick one split ('train'/'validation'/'test')
        self.datasets = load_from_disk('../data/ChnSentiCorp')
        self.dataset = self.datasets[split]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        text = self.dataset[i]['text']
        label = self.dataset[i]['label']
        return text, label
dataset = Dataset('train')
dataset
<__main__.Dataset at 0x2cf4f74fb08>
dataset.dataset
Dataset({
    features: ['text', 'label'],
    num_rows: 9600
})
len(dataset)
9600
dataset[0]
('选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般',
1)
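The DatasetDict on disk also carries 'validation' and 'test' splits alongside 'train'; a minimal sketch to inspect them (assuming the same local path as above):

from datasets import load_from_disk

datasets = load_from_disk('../data/ChnSentiCorp')
print(datasets)                   # DatasetDict listing its splits and row counts
print(datasets['validation'][0])  # one {'text': ..., 'label': ...} record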
2. Load the vocabulary and tokenizer from Huggingface
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
tokenizer
BertTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})
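As a quick sanity check of the vocabulary, here is a minimal sketch of a single-sentence round trip (the example sentence is made up):

ids = tokenizer.encode('酒店位置很好', truncation=True, max_length=12, padding='max_length')
print(ids)                    # [CLS]/[SEP]/[PAD] ids plus one id per character
print(tokenizer.decode(ids))  # '[CLS] 酒 店 位 置 很 好 [SEP] [PAD] ...'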
3. Override collate_fn to read data in batches
def collate_fn(data):
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]
    # Batch-encode: truncate/pad every sample to model_max_length, return torch tensors
    data = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sents,
                                       truncation=True,
                                       padding='max_length',
                                       return_tensors='pt',
                                       return_length=True)  # was 'True' (string); must be a bool
    input_ids = data['input_ids']            # token ids into the Chinese vocabulary
    attention_mask = data['attention_mask']  # 1 for real tokens, 0 for padding
    token_type_ids = data['token_type_ids']
    labels = torch.LongTensor(labels)
    return input_ids, attention_mask, token_type_ids, labels
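Before wiring collate_fn into a DataLoader, it can be exercised on a hand-built batch (a minimal sketch; both samples are made up):

batch = [('房间很干净,服务也不错', 1), ('设施陈旧,性价比低', 0)]
input_ids, attention_mask, token_type_ids, labels = collate_fn(batch)
print(input_ids.shape)  # torch.Size([2, 512]) because of padding='max_length'
print(labels)           # tensor([1, 0])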
4. Data loader
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    break

print(len(loader))  # batch_size=16, so 9600 samples -> 600 batches
600
attention_mask
tensor([[1, 1, 1, ..., 0, 0, 0],
        [1, 1, 1, ..., 0, 0, 0],
        [1, 1, 1, ..., 0, 0, 0],
        ...,
        [1, 1, 1, ..., 0, 0, 0],
        [1, 1, 1, ..., 0, 0, 0],
        [1, 1, 1, ..., 0, 0, 0]])
print(input_ids.shape, attention_mask.shape, token_type_ids.shape, labels) # model_max_length=512
torch.Size([16, 512]) torch.Size([16, 512]) torch.Size([16, 512]) tensor([0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0])
5. Load the pretrained model
from transformers import BertModel
pretrained = BertModel.from_pretrained('bert-base-chinese')
6. Freeze the BERT parameters
for param in pretrained.parameters():
    param.requires_grad_(False)
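To confirm the freeze took effect, count trainable parameters (a minimal sketch):

total = sum(p.numel() for p in pretrained.parameters())
trainable = sum(p.numel() for p in pretrained.parameters() if p.requires_grad)
print(total, trainable)  # trainable should be 0 after freezing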
7. Model trial run
out = pretrained(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
out.last_hidden_state[:, 0].shape
torch.Size([16, 768])
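out.last_hidden_state[:, 0] is the hidden state at the [CLS] position. BertModel also exposes out.pooler_output, the same [CLS] state passed through BERT's pretrained Linear+Tanh pooler; either vector can feed a classification head. A minimal sketch of the two candidates:

print(out.last_hidden_state[:, 0].shape)  # raw [CLS] hidden state: torch.Size([16, 768])
print(out.pooler_output.shape)            # pooled [CLS] output:    torch.Size([16, 768])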
8. Define the downstream-task model
class Model(torch.nn.Module):
    def __init__(self, pretrained):
        super().__init__()
        self.fc = torch.nn.Linear(768, 2)  # classification head on the [CLS] vector
        self.pretrained = pretrained

    def forward(self, input_ids, attention_mask, token_type_ids):
        with torch.no_grad():  # BERT is frozen; no need to build its graph
            out = self.pretrained(input_ids=input_ids,
                                  attention_mask=attention_mask,
                                  token_type_ids=token_type_ids)
        out = self.fc(out.last_hidden_state[:, 0])
        # Return raw logits: CrossEntropyLoss applies log-softmax internally,
        # so an extra softmax here would only weaken the gradients
        return out
model = Model(pretrained)
model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids).shape
torch.Size([16, 2])
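Only the linear head is newly initialized and trainable; it is tiny next to the frozen encoder (a minimal sketch):

head_params = sum(p.numel() for p in model.fc.parameters())
print(head_params)  # 768 * 2 weights + 2 biases = 1538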
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device
device(type='cuda', index=0)
Note that a plain torch.nn.Module has no .device attribute; the model is moved to the GPU explicitly with model.to(device) in the training step below.
9. Training
from torch.optim import AdamW  # transformers.AdamW is deprecated; use the torch optimizer

model = Model(pretrained)  # rebuild the downstream model with a fresh head
model.train()
model.to(device)

# Build the optimizer after the model, so it holds this model's parameters
optimizer = AdamW(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()

for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    token_type_ids = token_type_ids.to(device)
    labels = labels.to(device)

    out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    loss = criterion(out, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % 5 == 0:
        out = out.argmax(dim=1)
        accuracy = (out == labels).sum().item() / len(labels)
        print(i, loss.item(), accuracy)

    if i == 600:  # safety cap; the loader itself yields exactly 600 batches
        break
0 0.6827474236488342 0.5625
...
590 0.7014723420143127 0.4375
595 0.7267574071884155 0.375
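Once training finishes, the classifier can be checkpointed for later reuse (a minimal sketch; the file name is a made-up example):

torch.save(model.state_dict(), 'chnsenticorp_cls.pt')  # hypothetical file name
# later: model.load_state_dict(torch.load('chnsenticorp_cls.pt'))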
10. Testing
def test():
    model.eval()
    correct = 0
    total = 0
    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        if i == 5:  # evaluate only 5 batches (160 samples) to keep the check quick
            break
        print(i)

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            out = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)
        out = out.argmax(dim=1)
        correct += (out == labels).sum().item()
        total += len(labels)

    print(correct / total)
test()
0
1
2
3
4
0.58125
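To classify a new sentence end to end (a minimal sketch; the example text is made up):

model.eval()
enc = tokenizer('位置不错,房间也干净', return_tensors='pt',
                truncation=True, padding='max_length')
enc = {k: v.to(device) for k, v in enc.items()}
with torch.no_grad():
    logits = model(input_ids=enc['input_ids'],
                   attention_mask=enc['attention_mask'],
                   token_type_ids=enc['token_type_ids'])
print(logits.argmax(dim=1).item())  # 1 = positive, 0 = negative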