BERT Text Classification Fine-Tuning Demo
1. The code
import random
from collections import namedtuple

# Use a namedtuple to define a category (Category) with two fields: name and samples
Category = namedtuple('Category', ['name', 'samples'])

# Define four categories and their example texts
categories = [
    Category('Weather Forecast', ['今天北京晴转多云,气温20-25度。', '明天上海有小雨,记得带伞。']),          # weather forecast samples
    Category('Company Financial Report', ['本季度公司净利润增长20%。', '年度财务报告显示,成本控制良好。']),  # financial report samples
    Category('Company Audit Materials', ['审计发现内部控制存在漏洞。', '审计确认财务报表无重大错报。']),      # audit material samples
    Category('Product Marketing Ad', ['新口味可乐,清爽上市!', '买一送一,仅限今日。'])                       # marketing ad samples
]

def generate_data(num_samples_per_category=50):
    '''
    Generate a simulated dataset.
    Input:
        - num_samples_per_category: number of samples to generate per category (default 50)
    Output:
        - data: a list of (text, label) tuples
    '''
    data = []                                          # list that holds the generated data
    for category in categories:                        # iterate over all categories
        for _ in range(num_samples_per_category):      # generate the requested number of samples for this category
            sample = random.choice(category.samples)   # randomly pick one of the category's example texts
            data.append((sample, category.name))       # store the text together with its category name
    return data

# Generate the simulated datasets
train_data = generate_data(100)  # 100 training samples per category
test_data = generate_data(6)     # a small test set (6 samples per category) for the demo
'''
train_data looks like this (truncated preview; all samples of one category come first):
[('明天上海有小雨,记得带伞。', 'Weather Forecast'),
('明天上海有小雨,记得带伞。', 'Weather Forecast'),
('今天北京晴转多云,气温20-25度。', 'Weather Forecast'),
('今天北京晴转多云,气温20-25度。', 'Weather Forecast'),
('今天北京晴转多云,气温20-25度。', 'Weather Forecast'),
('明天上海有小雨,记得带伞。', 'Weather Forecast'),
('明天上海有小雨,记得带伞。', 'Weather Forecast'),
('明天上海有小雨,记得带伞。', 'Weather Forecast'),
('今天北京晴转多云,气温20-25度。', 'Weather Forecast'),]
'''
from transformers import BertTokenizer, BertForSequenceClassification
from torch.optim import AdamW  # transformers' own AdamW is deprecated; use the PyTorch optimizer
from torch.utils.data import DataLoader, TensorDataset
import torch

# Step 1: map category names to integer labels
label_map = {category.name: index for index, category in enumerate(categories)}
num_labels = len(categories)  # total number of classes

# Step 2: initialize the BERT tokenizer and model
# Note: the sample texts are Chinese, so 'bert-base-chinese' would tokenize them far better;
# 'bert-base-uncased' is kept here only to match the original demo.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased',
                                                      num_labels=num_labels)
# Step 3: prepare the datasets
def encode_texts(texts, labels):
    # Tokenize the texts into the input format BERT expects
    encodings = tokenizer(list(texts), truncation=True, padding=True, return_tensors='pt')
    # Convert label names into their integer indices
    label_ids = torch.tensor([label_map[label] for label in labels])
    return encodings, label_ids

def prepare_data(data):
    texts, labels = zip(*data)                          # unzip into texts and labels
    encodings, label_ids = encode_texts(texts, labels)  # encode the data
    dataset = TensorDataset(encodings['input_ids'], encodings['attention_mask'], label_ids)  # build the dataset
    return DataLoader(dataset, batch_size=8, shuffle=True)  # build the data loader

# Step 4: build the training and test data loaders
train_loader = prepare_data(train_data)
test_loader = prepare_data(test_data)
# Step 5: define the training and evaluation functions
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

def train_epoch(model, data_loader, optimizer):
    model.train()
    total_loss = 0
    for batch in data_loader:
        optimizer.zero_grad()
        input_ids, attention_mask, labels = batch
        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
    return total_loss / len(data_loader)

def evaluate(model, data_loader):
    model.eval()
    total_acc = 0
    total_count = 0
    with torch.no_grad():
        for batch in data_loader:
            input_ids, attention_mask, labels = batch
            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
            outputs = model(input_ids, attention_mask=attention_mask)
            predictions = torch.argmax(outputs.logits, dim=1)
            total_acc += (predictions == labels).sum().item()
            total_count += labels.size(0)
    return total_acc / total_count
# Step 6: train the model
optimizer = AdamW(model.parameters(), lr=2e-5)
for epoch in range(3):  # train for 3 epochs
    train_loss = train_epoch(model, train_loader, optimizer)
    acc = evaluate(model, test_loader)
    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Accuracy: {acc*100:.2f}%')
# Step 7: run predictions with the fine-tuned model
def predict(text):
    encodings = tokenizer(text, truncation=True, padding=True, return_tensors='pt')
    input_ids = encodings['input_ids'].to(device)
    attention_mask = encodings['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    predicted_class_id = torch.argmax(outputs.logits).item()
    return categories[predicted_class_id].name

# Predict a new text
new_text = ["明天的天气怎么样?"]  # note: the tokenizer is given a list
predicted_category = predict(new_text)
print(f'The predicted category for the new text is: {predicted_category}')
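The demo stops at a single prediction. If you also want to keep the fine-tuned weights for later use, a minimal sketch with the standard Hugging Face save_pretrained / from_pretrained API looks like this (the directory name finetuned_bert_demo is only an example, not part of the original demo):

# Sketch only: persist the fine-tuned model and tokenizer, then reload them later.
save_dir = 'finetuned_bert_demo'       # example output directory
model.save_pretrained(save_dir)        # writes the config and model weights
tokenizer.save_pretrained(save_dir)    # writes the vocabulary and tokenizer config

# Reload (e.g. in a separate inference script) and move the model to the desired device.
reloaded_tokenizer = BertTokenizer.from_pretrained(save_dir)
reloaded_model = BertForSequenceClassification.from_pretrained(save_dir).to(device)
reloaded_model.eval()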
2. Dependencies
requirements.txt was exported with: pip freeze > D:\AworkSpace\requirements\requirements.txt
It is installed with: pip install -r requirements.txt
I use the Tencent mirror; I find the other mirrors hard to work with, but you can test them yourself:
pip install -r requirements.txt -i https://mirrors.tencent.com/pypi/simple/
pip install torch==2.3.0 -i https://mirrors.tencent.com/pypi/simple/
pip install transformers[torch] -i https://mirrors.tencent.com/pypi/simple/
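After installation, a quick sanity check confirms that the pinned versions were actually installed and that this is the CPU-only build:

# Sanity check: print the installed versions of the two key libraries.
import torch
import transformers

print('torch:', torch.__version__)                    # expect 2.3.0+cpu with the pins below
print('transformers:', transformers.__version__)      # expect 4.41.0
print('CUDA available:', torch.cuda.is_available())   # False on the CPU-only wheels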
3. A fresh virtual environment is required, otherwise the installation will not succeed
pip install ipykernel
python -m ipykernel install --name your_env_name
python -m ipykernel install --name cpy312_env04
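If you do not have a fresh environment yet, a minimal sketch on Windows (cmd) using the built-in venv module looks like this; the path D:\envs\cpy312_env04 is only an example:
python -m venv D:\envs\cpy312_env04
D:\envs\cpy312_env04\Scripts\activate
pip install ipykernel
python -m ipykernel install --name cpy312_env04
Run the pip install commands from section 2 inside this activated environment so the packages do not mix with the base interpreter.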
Classic Jupyter: copy the kernel spec into Jupyter.
Personally I find the classic Jupyter Notebook nicer to use, mainly because after a code cell runs you can copy the cell's output; the newer interface does not let you copy it.
3.1 Installing the CPU build, because my GPU is only a GTX 1650 Ti
After the base dependencies are installed, torch and transformers still need to be installed.
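For the CPU-only wheels, the versions pinned in the list below can also be installed explicitly from the official PyTorch CPU index (a sketch; the Tencent mirror commands from section 2 may work as well):
pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cpu
pip install transformers==4.41.0 accelerate==0.30.1 -i https://mirrors.tencent.com/pypi/simple/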
4. Full dependency list (pip freeze output)
accelerate==0.30.1
aiohttp==3.9.5
aiosignal==1.3.1
annotated-types==0.7.0
anyio==4.3.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
attrs==23.2.0
Babel==2.15.0
beautifulsoup4==4.12.3
bleach==6.1.0
certifi==2024.2.2
cffi==1.16.0
chardet==5.2.0
charset-normalizer==3.3.2
click==8.1.7
colorama==0.4.6
comm==0.2.2
contourpy==1.2.1
cycler==0.12.1
datasets==2.19.1
debugpy==1.8.1
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.8
distlib==0.3.8
dnspython==2.6.1
docx==0.2.4
email_validator==2.2.0
et-xmlfile==1.1.0
executing==2.0.1
fastapi==0.111.0
fastapi-cli==0.0.4
fastjsonschema==2.19.1
filelock==3.14.0
fonttools==4.53.0
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.3.1
h11==0.14.0
httpcore==1.0.5
httptools==0.6.1
httpx==0.27.0
huggingface-hub==0.23.0
idna==3.7
ijson==3.2.3
intel-openmp==2021.4.0
ipykernel==6.29.4
ipython==8.24.0
ipywidgets==8.1.2
isoduration==20.11.0
jedi==0.17.0
jieba==0.42.1
Jinja2==3.1.4
joblib==1.4.2
json5==0.9.25
jsonpatch==1.33
jsonpointer==2.4
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter-nbextensions-configurator==0.6.3
jupyter_client==8.6.1
jupyter_contrib_core==0.4.2
jupyter_core==5.7.2
jupyter_server==2.14.0
jupyter_server_terminals==0.5.3
jupyterlab==4.2.1
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.1
jupyterlab_widgets==3.0.10
kiwisolver==1.4.5
langid==1.1.6
lxml==5.2.2
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.0
matplotlib-inline==0.1.7
mdurl==0.1.2
mistune==3.0.2
mkl==2021.4.0
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.3
nltk==3.8.1
notebook==7.2.0
notebook_shim==0.2.4
numpy==1.26.4
openpyxl==3.1.2
orjson==3.10.5
overrides==7.7.0
packaging==24.0
pandas==2.2.2
pandoc==2.3
pandocfilters==1.5.1
parso==0.7.1
pillow==10.3.0
platformdirs==4.2.1
plumbum==1.8.3
ply==3.11
prometheus_client==0.20.0
prompt-toolkit==3.0.43
psutil==5.9.8
pure-eval==0.2.2
pyarrow==16.1.0
pyarrow-hotfix==0.6
pycparser==2.22
pydantic==2.7.3
pydantic_core==2.18.4
Pygments==2.18.0
pyparsing==3.1.2
python-dateutil==2.9.0.post0
python-docx==1.1.2
python-dotenv==1.0.1
python-json-logger==2.0.7
python-multipart==0.0.9
pytz==2024.1
pywin32==306
pywinpty==2.0.13
PyYAML==6.0.1
pyzmq==26.0.3
qtconsole==5.5.2
QtPy==2.4.1
referencing==0.35.1
regex==2024.5.15
requests==2.31.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.1
rpds-py==0.18.1
safetensors==0.4.3
scikit-learn==1.4.2
scipy==1.13.0
seaborn==0.13.2
Send2Trash==1.8.3
setuptools==70.0.0
shellingham==1.5.4
six==1.16.0
sniffio==1.3.1
soupsieve==2.5
spark-ai-python==0.3.31
stack-data==0.6.3
starlette==0.37.2
sympy==1.12
tbb==2021.12.0
tenacity==8.3.0
terminado==0.18.1
threadpoolctl==3.5.0
tinycss2==1.3.0
tokenizers==0.19.1
torch==2.3.0+cpu
torchaudio==2.3.0+cpu
torchvision==0.18.0+cpu
tornado==6.4
tqdm==4.66.4
traitlets==5.14.3
transformers==4.41.0
typer==0.12.3
types-python-dateutil==2.9.0.20240316
typing_extensions==4.11.0
tzdata==2024.1
ujson==5.10.0
uri-template==1.3.0
urllib3==2.2.1
uvicorn==0.30.1
virtualenv==20.26.2
watchfiles==0.22.0
wcwidth==0.2.13
webcolors==1.13
webencodings==0.5.1
websocket-client==1.8.0
websockets==12.0
widgetsnbextension==4.0.10
wordcloud==1.9.3
xxhash==3.4.1
yarl==1.9.4