🟢 这一期主要是利用管道来实现微博内容情感分析+保存数据库的操作
🟢 情感分析使用BERT实现
1 利用BERT来做情感分析
我们需要通过代码实现了一个完整的BERT情感分析流程,包括数据预处理、模型定义、数据迭代以及情感分类的推断。
1.1 BERT介绍
BERT(Bidirectional Encoder Representations from Transformers)通过预训练大量文本数据,能够捕捉语言中的上下文语义关系,尤其擅长处理长文本中的复杂依赖。它通过双向编码器同时从前后文学习句子含义,能更准确地理解微博中的情感倾向。微博中的文本往往短小、口语化且包含多种情感表达,BERT在处理此类细微情感差异时表现出色,因此广泛应用于微博内容的情感分析任务。
1.2 搭建BERT模型
import re
from scrapy.utils.project import get_project_settings
from weiboScrapy.pytorch_pretrained.modeling import BertModel
from weiboScrapy.pytorch_pretrained.tokenization import BertTokenizer
import torch
import torch.nn as nn
import numpy as np
class Config(object):
"""配置参数"""
def __init__(self):
settings = get_project_settings()
self.model_name = 'bert'
self.class_list = ['中性', '积极', '消极'] # 类别名单
self.save_path = settings.get('BERT_SAVE_PATH')
# self.save_path = './saved_dict/bert.ckpt' # 模型训练结果
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备
# 兼容apple silicon
self.device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f"device:{self.device}")
self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练
self.num_classes = len(self.class_list) # 类别数
self.num_epochs = 3 # epoch数
self.batch_size = 128 # mini-batch大小
self.pad_size = 32 # 每句话处理成的长度(短填长切)
self.learning_rate = 5e-5 # 学习率
# self.bert_path = './bert_pretrain'
self.bert_path = settings.get('BERT_PRETRAIN_PATH')
self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
self.hidden_size = 768
class Model(nn.Module):
def __init__(self, config):
super(Model, self).__init__()
self.bert = BertModel.from_pretrained(config.bert_path)
for param in self.bert.parameters():
param.requires_grad = True
self.fc = nn.Linear(config.hidden_size, config.num_classes)
def forward(self, x):
context = x[0] # 输入的句子
mask = x[2] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0]
_, pooled = self.bert(context, attention_mask=mask, output_all_encoded_layers=False)
out = self.fc(pooled)
return out
PAD, CLS = '[PAD]', '[CLS]' # padding符号, bert中综合信息符号
def clean(text):
# text = re.sub(r"(回复)?(//)?\s*@\S*?\s*(:| |$)", " ", text) # 去除正文中的@和回复/转发中的用户名
# text = re.sub(r"\[\S+\]", "", text) # 去除表情符号
# text = re.sub(r"#\S+#", "", text) # 保留话题内容
URL_REGEX = re.compile(
r'(?i)\b((?:https?://|www\d{0,3}[.]|[a-z0-9.\-]+[.][a-z]{2,4}/)(?:[^\s()<>]+|\(([^\s()<>]+|(\([^\s()<>]+\)))*\))+(?:\(([^\s()<>]+|(\([^\s()<>]+\)))*\)|[^\s`!()\[\]{};:\'".,<>?«»“”‘’]))',
re.IGNORECASE)
text = re.sub(URL_REGEX, "", text) # 去除网址
text = text.replace("转发微博", "") # 去除无意义的词语
text = re.sub(r"\s+", " ", text) # 合并正文中过多的空格
return text.strip()
def load_dataset(data, config):
pad_size = config.pad_size
contents = []
for line in data:
lin = clean(line)
token = config.tokenizer.tokenize(lin) # 分词
token = [CLS] + token # 句首加入CLS
seq_len = len(token)
mask = []
token_ids = config.tokenizer.convert_tokens_to_ids(token)
if pad_size:
if len(token) < pad_size:
mask = [1] * len(token_ids) + [0] * (pad_size - len(token))
token_ids += ([0] * (pad_size - len(token)))
else:
mask = [1] * pad_size
token_ids = token_ids[:pad_size]
seq_len = pad_size
contents.append((token_ids, int(0), seq_len, mask))
return contents
class DatasetIterater(object):
def __init__(self, batches, batch_size, device):
self.batch_size = batch_size
self.batches = batches # data
self.n_batches = len(batches) // batch_size
self.residue = False # 记录batch数量是否为整数
if len(batches) % self.n_batches != 0:
self.residue = True
self.index = 0
self.device = device
def _to_tensor(self, datas):
x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
# pad前的长度(超过pad_size的设为pad_size)
seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
mask = torch.LongTensor([_[3] for _ in datas]).to(self.device)
return (x, seq_len, mask), y
def __next__(self): # 返回下一个迭代器对象,必须控制结束条件
if self.residue and self.index == self.n_batches:
batches = self.batches[self.index * self.batch_size: len(self.batches)]
self.index += 1
batches = self._to_tensor(batches)
return batches
elif self.index >= self.n_batches:
self.index = 0
raise StopIteration
else:
batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
self.index += 1
batches = self._to_tensor(batches)
return batches
def __iter__(self): # 返回一个特殊的迭代器对象,这个迭代器对象实现了 __next__() 方法并通过 StopIteration 异常标识迭代的完成。
return self
def __len__(self):
if self.residue:
return self.n_batches + 1
else:
return self.n_batches
def build_iterator(dataset, config):
iter = DatasetIterater(dataset, 1, config.device)
return iter
def match_label(pred, config):
label_list = config.class_list
return label_list[pred]
def final_predict(config, model, data_iter):
map_location = lambda storage, loc: storage
model.load_state_dict(torch.load(config.save_path, map_location=map_location))
model.eval()
predict_all = np.array([])
with torch.no_grad():
for texts, _ in data_iter:
outputs = model(texts)
pred = torch.max(outputs.data, 1)[1].cpu().numpy()
pred_label = [match_label(i, config) for i in pred]
predict_all = np.append(predict_all, pred_label)
return predict_all
- Config类:配置了模型的参数,包括类别(中性、积极、消极)、设备选择(CUDA或MPS加速)、训练的超参数(batch size、学习率、训练轮次等),并且加载了BERT预训练模型和分词器。它还定义了模型保存路径和一些与训练有关的设置。
- Model类:这是情感分析的核心模型。模型基于BERT预训练模型,并在其基础上添加了一个全连接层,用于分类(即将文本分类为中性、积极或消极情感)。模型的输入是微博文本,输出是情感类别的预测。
- clean函数:用于清洗微博文本,移除一些不必要的字符、表情符号、链接等,以便更好地输入到模型中进行情感分析。
- load_dataset函数:将微博数据转换成模型可以处理的格式。它对文本进行分词,将分词后的token映射成ID,并对其进行padding以保证输入的长度一致。
- DatasetIterater类:用于数据的批量迭代。它负责将数据划分成小批量,并将其转换成张量,供模型训练或预测使用。
- build_iterator函数:创建一个数据迭代器,用于模型的批量训练或推断。
- match_label函数:将模型输出的预测索引映射为情感类别标签(中性、积极、消极)。
- final_predict函数:这是最终用于推断的函数。它加载预训练好的模型,然后对输入的微博文本进行预测,返回预测的情感类别。
1.3在pipelines中新增一个管道
这个管道可以放在保存MYSQL的管道前,可以这样设置
ITEM_PIPELINES = {
# 'weiboScrapy.pipelines.DuplicatesPipeline': 300,
'weiboScrapy.pipelines.SentimentPipeline': 301,
'weiboScrapy.pipelines.MySQLPipeline': 302,
}```
```python
# 处理情感分析的管道
class SentimentPipeline(object):
def __init__(self):
self.config = Config()
self.model = Model(self.config).to(self.config.device)
def process_item(self, item, spider):
if 'weibo' in item:
GOLD = "\033[38;5;214m" # 使用色号214表示金色
RESET = "\033[0m" # 重置颜色
text = [item['weibo']['text']]
test_data = load_dataset(text, self.config)
test_iter = build_iterator(test_data, self.config)
result = final_predict(self.config, self.model, test_iter)
for i, j in enumerate(result):
# print('text:{}'.format(text[i]))
# print('label:{}'.format(j))
item['weibo']['label'] = j
print(f"{GOLD}情感分析微博内容:{item['weibo']['text']},结果:{j}{RESET}")
return item
2 编写items
这里给出完整的items.py的代码,为了接受情感分析结果,增加了情感分析的标签字段
# 微博内容
class WeiboItem(scrapy.Item):
# define the fields for your item here like:
id = scrapy.Field()
mid = scrapy.Field()
bid = scrapy.Field()
user_id = scrapy.Field()
screen_name = scrapy.Field()
text = scrapy.Field()
article_url = scrapy.Field()
location = scrapy.Field()
at_users = scrapy.Field()
topics = scrapy.Field()
reposts_count = scrapy.Field()
comments_count = scrapy.Field()
attitudes_count = scrapy.Field()
created_at = scrapy.Field()
source = scrapy.Field()
pics = scrapy.Field()
video_url = scrapy.Field()
retweet_id = scrapy.Field()
ip = scrapy.Field()
user_authentication = scrapy.Field()
keywords = scrapy.Field() # 关键词
label = scrapy.Field() # 情感分析结果
3 保存MySQL
编写保存MySQL的管道代码
# MySQL 处理管道
class MySQLPipeline:
def open_spider(self, spider):
# 获取数据库配置
settings = get_project_settings()
self.connection = pymysql.connect(
host=settings['MYSQL_HOST'],
user=settings['MYSQL_USER'],
password=settings['MYSQL_PASSWORD'],
database=settings['MYSQL_DATABASE'],
charset='utf8',
use_unicode=True,
)
self.cursor = self.connection.cursor()
self.connection.commit()
def close_spider(self, spider):
# 关闭数据库连接
self.cursor.close()
self.connection.close()
def process_item(self, item, spider):
print('MySQL 管道...')
if 'weibo' in item:
data = dict(item['weibo'])
keys = ', '.join(data.keys())
values = ', '.join(['%s'] * len(data))
sql = """INSERT INTO {table}({keys}) VALUES ({values}) ON
DUPLICATE KEY UPDATE""".format(table='tb_weibo',
keys=keys,
values=values)
update = ','.join([" {key} = {key}".format(key=key) for key in data])
sql += update
CYAN = '\033[96m' # 青色
RED = '\033[91m' # 红色
RESET = '\033[0m' # 重置为默认颜色
try:
self.cursor.execute(sql, tuple(data.values()))
self.connection.commit()
print(f"{CYAN}插入数据库成功{RESET}")
except Exception as E:
print('**********'+sql)
print("Error:", E)
print(f"{RED}插入数据库失败{RESET}")
self.connection.rollback()
self.cursor.execute(sql, values)
self.connection.commit() # 提交事务
return item # 必须返回 item
4 运行效果
我们用 ‘黑神话’, ‘iphone’ 这两个作为关键词,爬取范围是2024年9月21日到2024年10月8日的微博内容。
scrapy爬取效果如下:
用金色文字的日志打印出了情感分析结果:
存储数据库的结果: