1.简介
1.1项目划分:
- 登录
- 注册
- 主页
- 新闻页面(点击不同的分类会显示对应的分类新闻)
- 新闻内容
- 添加新闻
1.2 技术栈:
前端:js+css+bootstrap框架搭建。
后端:python+django框架+mysql搭建。
分类模型:torch1.7.1 + bert模型(hugging face)
1.3 数据集:
数据集采用的是dbpedia数据集
通过分类模型算法可以自动将添加的新闻计算出对应的分类,并添加到相应的类别中。
2.训练模型和测试模型
2.1部分训练模型代码:
# Load the tokenizer and a sequence-classification head from a local
# pretrained checkpoint directory ("pretrain_Model_path").
tokenizer = AutoTokenizer.from_pretrained("pretrain_Model_path")
# num_labels=10 must match the number of entries in label_dic below.
model = AutoModelForSequenceClassification.from_pretrained("pretrain_Model_path", num_labels=10)
# Mapping from news-category name to integer class id (the ids are what the
# training file stores and what the model predicts).
# NOTE(review): these ten categories look like the THUCNews label set, not
# DBpedia as the surrounding text claims — confirm which dataset is used.
label_dic = {
'finance': 0,
'realty': 1,
'stocks': 2,
'education': 3,
'science': 4,
'society': 5,
'politics': 6,
'sports': 7,
'game': 8,
'entertainment': 9}
def get_train_data(file, label_dic):
    """Read tab-separated ``content\\tlabel`` training pairs from *file*.

    Args:
        file: path to a UTF-8 text file, one ``content<TAB>label_id`` per line.
        label_dic: kept for interface compatibility with existing callers; the
            file already stores integer label ids, so the mapping is not
            consulted here.

    Returns:
        A ``(content, label)`` pair of parallel lists. Each content string is
        reduced to its Chinese characters only (everything outside
        U+4E00..U+9FA5 is stripped); labels are ``int``.
    """
    content = []
    label = []
    with open(file, "r", encoding="utf-8") as f:
        # Iterate the file lazily instead of f.readlines() (no need to load
        # the whole corpus into memory at once).
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines instead of crashing
            # rsplit on the LAST tab: a tab inside the content no longer
            # raises "too many values to unpack" as plain split("\t") did.
            c, l = line.rsplit("\t", 1)
            content.append(re.sub('[^\u4e00-\u9fa5]', "", c))
            label.append(int(l.strip()))
    return content, label
# ---- Build the training set and prepare fine-tuning ----

# Read raw (text, label-id) pairs; label_dic maps category name -> id.
content, label = get_train_data('data/train.txt', label_dic=label_dic)
data = pd.DataFrame({"content": content, "label": label})
# NOTE(review): rows are NOT shuffled before the 8000-row split below, so the
# training subset is whatever order the file happens to be in — consider
# shuffling if the file is grouped by class.

# Tokenize the first 8000 rows; every sequence is padded/truncated to 40 tokens.
train_data = tokenizer(data.content.to_list()[:8000], padding="max_length", max_length=40, truncation=True,
                       return_tensors="pt")
train_label = data.label.to_list()[:8000]

batch_size = 16
train = TensorDataset(train_data["input_ids"], train_data["attention_mask"], torch.tensor(train_label))
train_sampler = RandomSampler(train)
train_dataloader = DataLoader(train, sampler=train_sampler, batch_size=batch_size)

# Define the optimizer.
from torch.optim import AdamW
optimizer = AdamW(model.parameters(), lr=1e-4)

# Define the learning-rate schedule and the number of epochs.
num_epochs = 50
from transformers import get_scheduler
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

# Fall back to CPU when no GPU is available — the previous hard-coded
# torch.device("cuda:0") raised on CPU-only machines.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
2.2 部分测试代码接口
def classify_text(text):
    """Predict the news-category id(s) for *text* with the global model.

    Args:
        text: a single string or a list of strings to classify.

    Returns:
        The predicted class id(s) as ``pred.tolist()`` — an ``int`` for a
        single input, a list of ints for a batch (also printed, preserving
        the original behaviour).
    """
    # truncation=True: inputs longer than 100 tokens were previously passed
    # through unpadded/untruncated, which can break batching.
    encoded = tokenizer(text, return_tensors="pt", padding="max_length",
                        max_length=100, truncation=True)
    model.eval()
    with torch.no_grad():
        outputs = model(encoded["input_ids"].to(device),
                        token_type_ids=None,
                        attention_mask=encoded["attention_mask"].to(device))
    logits = outputs["logits"].cpu()
    # Stay in torch end-to-end: the original np.argmax(...) followed by
    # .cpu() raises AttributeError when numpy hands back an ndarray
    # (ndarrays have no .cpu()).
    pred = logits.argmax(dim=1).numpy().squeeze()
    print(pred.tolist())
    return pred.tolist()
2.3 测试