一、github项目下载
finBERT项目源码
我的项目路径
二、模型训练及输出测试
1.运行dataset.py
2.生成data
数据来源于路透社文章的部分节选;如有需要,GitHub 项目中提供了下载地址。
3.训练模型
from pathlib import Path
import sys
sys.path.append('..')
import argparse
import shutil
import os
import logging
from textblob import TextBlob
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import *
from finbert.finbert import *
import finbert.utils as tools
from pprint import pprint
from sklearn.metrics import classification_report
# Project root: the parent of the current working directory.
project_dir = Path.cwd().parent

# Show full cell contents when printing DataFrames instead of truncating them.
# `None` is the "no limit" value since pandas 1.0; the old `-1` sentinel is
# rejected by pandas 2.x, so it is used only as a fallback for pandas < 1.0,
# whose validator accepts integers only.
try:
    pd.set_option('display.max_colwidth', None)
except ValueError:
    pd.set_option('display.max_colwidth', -1)
# %%
# Configure the root logger: timestamped records, threshold set to ERROR.
logging.basicConfig(
    level=logging.ERROR,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
)
# %% md
## Prepare the model
# %% md
# %%
# lm_path = project_dir / 'models' / 'TRC2' / 'pytorch_model.bin'
# Path of the pre-trained model that is handed to `from_pretrained`.
lm_path = project_dir / 'models' / 'classifier_model' / 'TRC2'
# Presumably the output location of the fine-tuned classifier and the
# directory holding the sentiment data -- confirm against their later use.
cl_path = project_dir / 'models' / 'classifier_model' / 'finbert-sentiment'
cl_data_path = project_dir / 'data' / 'sentiment_data'

# Best-effort removal of any existing `cl_path` tree. A missing directory is
# expected and ignored; other filesystem errors are reported but do not abort
# the script, and unrelated exceptions are no longer swallowed.
try:
    shutil.rmtree(cl_path)
except FileNotFoundError:
    pass
except OSError as err:
    logging.error('Could not remove %s: %s', cl_path, err)
bertmodel = BertForSequenceClassification.from_pretrained(lm_path, cache_dir