matchzoo笔记

一、数据结构
1、DataPack
MatchZoo中数据是以DataPack这种数据结构进行组织,必须包含5个字段,即id_left, text_left, id_left, text_right, label。就这种形式,虽然不知道id_left和id_right在干嘛
补充:有点懂了,意思是包含了三个dataframe,[id_left,text_left] [id_right,text_rigth] [label]

这种形式,虽然不知道id_left和id_right在干嘛
2、Dataset
继承自torch用来对DataPack进行类型转化,
(1)参数中有mode,可以选择point和pair这两种模式(我猜意思是返回点样本还是对样本,是不是和task有关?)
补充:
Point-Wise 方式
将排序问题转化为多分类问题或者回归问题。分为五个等级,Perfect(完全相关),Excellent(非常相关),Good(很相关),Fair(一般),Bad(不相关),等于说就是划分了阈值,阈值在这个区间都是一类
Pair-Wise 方式
根据给定查询和文档的关系,得到D1 D2等的相关度,按照两两文件之间的大小关系排序

(2)这个类中使用self.index_pool保存对应的索引列表,对于point模式,即为原始的索引列表;而对于pair模式,其中的一个元素的数量为num_neg+1,表示一组元素。(啥子意思哟)
(3)最终返回的Dataset类型是经过采样之后的,可以通过ds.data_pack.frame()进行查看,可以发现是采样之后的结果。(SO?WHO?)

3、Embedding
有fit和transform操作。
transform操作好像就是把读到的文件转换datapack数据结构

4、DataLoader
(1)这里是实现数据迭代的核心类,最终每一次返回的数据都是dp.unpack()返回的形式,也就是说x是dict型的,y是二维列表型的。(?请问dp.unpack我们见过吗?x ,y的形式为啥是这样?真是打扰了)
补充:unpack()函数是用来解压文件

x, y = train_processed.unpack()
test_x, test_y = test_processed.unpack()
data_generator = mz.DataGenerator(train_processed, batch_size=32)
model.fit_generator(data_generator, epochs=5, use_multiprocessing=True, workers=4)

(2)实现采样的类在sampler中,可以发现,这里定义的几个采样类SequentialSampler,SortedSampler以及RandomSampler都是对dataset中的index_pool进行操作;同时BatchSampler每次采样的也是index_pool中的一个元素,因此对于pair模式下,每次采样的都是多个样本,符合loss计算的要求。(What are you talking about?)
(3)在内部使用torch.data.DataLoader的时候,指定了collate_fn参数,这个参数只有当每一个batch_x是dict的时候使用,用于对内部数据进行特殊处理。(我只是一个没有感情的复制机器罢了)

二、Tasks
rank 和 classification的label类型不一样

三、LOSS
(1)matchzoo内部自定义的只有两种,而且都是用于排序的,这两种损失函数都定义在mz.losses这个包中,使用时需要指定num_neg这个参数。
(2)对于分类任务,直接使用torch中常规的损失函数即可,比如torch.nn.CrossEntropy()

四、关于Trainer中的参数
model:表示上面定义的model
optimizer:表示pytorch中的优化器
trainloader:DataLoader类型,表示训练数据加载实例
validloader:DataLoader类型,表示验证数据加载实例
device:可以是list型,表示数据并行训练
start_epoch:默认起始epoch的值
epochs:表示需要运行的epochs,默认为10
validate_interval:int型,表示每各多少个steps显示一次结果
clip_norm:int型或者float型,表示梯度裁剪的norm值(?)
patience:int型,这里并不表示多少个steps,而是表示经过多少次evaluate,使用patience * validate_interval则表示经过多少个steps
key:用于比较的metric,也就是如何用于决定是否early-stopping的标准(科科)
checkpoint:Path类型,表示保存的模型的checkpoint(好像就是模型参数存放的地方)
save_dir:Path类型,表示模型保存的文件夹
save_all:bool型,如果是True,则保存Trainer的实例,否则只保存模型
verbose
7. 训练模型的流程
使用模块中预定义好的数据,或者加载自己的数据,得到DataPack类型的数据;
使用预处理器preprocessor对数据进行处理,仍然得到DataPack类型的数据;这里的处理包含分词,将单词数值化,去除停止词等等;
通过mz.datasets.embedding封装自定义的embedding_matrix,获取上一步预处理时得到的词表,构建最终的embedding_matrix;
PS:这里可以加载自己的词向量,使用load_from_file()函数。

# 这里可以加载自己的词向量,使用load_from_file函数
glove_embedding = mz.datasets.embeddings.load_glove_embedding(dimension=300)
term_index = preprocessor.context['vocab_unit'].state['term_index']
embedding_matrix = glove_embedding.build_matrix(term_index)
l2_norm = np.sqrt((embedding_matrix * embedding_matrix).sum(axis=1))
embedding_matrix = embedding_matrix / l2_norm[:, np.newaxis]

使用mz.dataloader.Dataset对上面的数据进行封装,指定mode, num_dup, num_neg参数;
使用mz.dataloader.DataLoader对上面的dataset进行封装,指定batch_size, stage, resample, sort, shuffle等参数
定义模型并制定模型参数,使用model.build()定义好模型的各个模块;
定义优化器和模型的训练类,使用mz.trainers.Trainer,指定model, optimizer, trainloader, validloader, epochs等参数;
使用trainer.run()进行训练。
https://zhuanlan.zhihu.com/p/94085483
https://github.com/NTMC-Community/MatchZoo-py
还没看完https://blog.csdn.net/qq_34182808/article/details/103027879

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值