run.py是整个项目的入口,它包含两部分,一是使用argparse工具,配置相关参数;二是整个项目的流程框架,各个模块/函数的调用。
目录
1. 参数配置
#声明argparse对象 可附加说明
parser = argparse.ArgumentParser(description='Chinese Text Classification')
#添加参数
#模型是必须设置的参数(required=True) 类型是字符串
parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')
#embedding随机初始化或使用预训练词或字向量 默认使用预训练
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
#基于词还是基于字 默认基于字
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')
#解析参数
args = parser.parse_args()
2. 项目流程
if __name__ == '__main__':
dataset = 'THUCNews' # 数据集
# 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:random
embedding = 'embedding_SougouNews.npz' #默认使用搜狗预训练字向量
if args.embedding == 'random': #如果embedding参数设置为random
embedding = 'random'
#获取选择的模型名字
model_name = args.model # 'TextRCNN' # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer
#导入数据预处理与加载函数
if model_name == 'FastText': #如果所选模型名字为FastText 由于增加了bi-gram tri-gram特征 会有不同的行为
from utils_fasttext import build_dataset, build_iterator, get_time_dif
embedding = 'random' #此时embedding需要设置为随机初始化
else: #其他模型统一处理
from utils import build_dataset, build_iterator, get_time_dif
x = import_module('models.' + model_name) #根据所选模型名字在models包下 获取相应模块(.py)
config = x.Config(dataset, embedding) #每一个模块(.py)中都有一个模型定义类 和与该模型相关的配置类(定义该模型的超参数) 初始化配置类的对象
#设置随机种子 确保每次运行的条件(模型参数初始化、数据集的切分或打乱等)是一样的
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed_all(1)
torch.backends.cudnn.deterministic = True
start_time = time.time()
print("Loading data...")
#数据预处理
vocab, train_data, dev_data, test_data = build_dataset(config, args.word) #构建词典、训练集、验证集、测试集
#构建训练集、验证集、测试集迭代器/生成器(节约内存、避免溢出)
train_iter = build_iterator(train_data, config)
dev_iter = build_iterator(dev_data, config)
test_iter = build_iterator(test_data, config)
time_dif = get_time_dif(start_time)
print("Time usage:", time_dif)
# 构造模型对象
config.n_vocab = len(vocab) #词典大小可能不确定,在运行时赋值
model = x.Model(config).to(config.device) #构建模型对象 并to_device
if model_name != 'Transformer': #如果不是Transformer模型 则使用自定义的参数初始化方式
init_network(model) #也可以采用之前达观杯中的做法 把自定义模型参数的函数 放在模型的定义类中 在__init__中执行
print(model.parameters)
#训练、验证和测试
train(config, model, train_iter, dev_iter, test_iter)
- 自定义参数初始化方式
# 权重初始化,默认xavier
def init_network(model, method='xavier', exclude='embedding', seed=123):
for name, w in model.named_parameters():
if exclude not in name: #排除embedding层
#权重按指定方式初始化 默认xavier
if 'weight' in name:
if method == 'xavier':
nn.init.xavier_normal_(w)
elif method == 'kaiming':
nn.init.kaiming_normal_(w)
else:
nn.init.normal_(w)
#偏置初始化为常数0
elif 'bias' in name:
nn.init.constant_(w, 0)
else:
pass