matchzoo中文适配笔记

中文适配需要参考并修改 MatchZoo 源代码,可以按两种方式使用:
1、下载源码,修改后作为本地项目直接运行;
2、pip install matchzoo 后,参考源码把需要中文化的部分在自己的代码中重新实现一遍,再在代码中调用 matchzoo。

一、修改源代码

参考tutorials中的init
1、首先修改获取数据的文件,添加文件matchzoo/datasets/chinese_qa/load_data.py

"""WikiQA data loader."""
def load_data(
    stage: str = 'train',
    task: str = 'ranking',
    filtered: bool = False,
    return_classes: bool = False
) -> typing.Union[matchzoo.DataPack, tuple]:
    if stage not in ('train', 'dev', 'test'):
        raise ValueError(f"{stage} is not a valid stage."
                         f"Must be one of `train`, `dev`, and `test`.")
    data_root = Path(".../LCQMC/data")
    file_path = data_root.joinpath(f'{stage}.txt')
    data_pack = _read_data(file_path)
    
    if task == 'ranking':
        task = matchzoo.tasks.Ranking()
    if task == 'classification':
        task = matchzoo.tasks.Classification()

    if isinstance(task, matchzoo.tasks.Ranking):
        return data_pack
    elif isinstance(task, matchzoo.tasks.Classification):
        data_pack.one_hot_encode_label(task.num_classes, inplace=True)
        if return_classes:
            return data_pack, [False, True]
        else:
            return data_pack
    else:
        raise ValueError(f"{task} is not a valid task."
                         f"Must be one of `Ranking` and `Classification`.")

# Kept for reference only: the data is supplied locally, so this download
# helper is never called (note `_url` is not defined in this snippet).
def _download_data():
    archive_path = keras.utils.data_utils.get_file(
        'wikiqa',
        _url,
        extract=True,
        cache_dir=matchzoo.USER_DATA_DIR,
        cache_subdir='wiki_qa',
    )
    return Path(archive_path).parent.joinpath('WikiQACorpus')
# Expected file layout: sentence1 \t sentence2 \t Label (with header row).
def _read_data(path):
    """Read a tab-separated LCQMC file into a matchzoo DataPack."""
    raw = pd.read_csv(path, sep='\t', header=0, quoting=csv.QUOTE_NONE)
    frame = pd.DataFrame(
        {
            'text_left': raw['sentence1'],
            'text_right': raw['sentence2'],
            'label': raw['Label'],
        }
    )
    return matchzoo.pack(frame)

2、加载embedding
可以直接使用matchzoo/embedding/embedding.py中的load_from_file加载预训练的词向量。只需要设置文件路径就可以。

3、设置preprocessors,这一步设置sentence1 sentence2 的最大长度,词频过滤器,过滤最高词频,最低词频,去除停用词。
原preprocessor指定的默认单元为

    def _default_units(cls) -> list:
        """Default (English) pipeline: tokenize, lowercase, strip punctuation."""
        pipeline = [
            mz.preprocessors.units.tokenize.Tokenize(),
            mz.preprocessors.units.lowercase.Lowercase(),
            mz.preprocessors.units.punc_removal.PuncRemoval(),
        ]
        return pipeline

改为

@classmethod
def _default_chinese_units(cls) -> list:
    """Return the preprocessing pipeline for Chinese text.

    Lowercasing is dropped because it is a no-op for Chinese characters;
    tokenization is delegated to the Chinese tokenizer.
    """
    return [
        mz.preprocessors.units.tokenize.ChineseTokenize(),
        mz.preprocessors.units.punc_removal.PuncRemoval(),
    ]

同时添加文件 matchzoo/preprocessors/chinese_preprocessor.py。其内容实际上就是复制 matchzoo/preprocessors/basic_preprocessor.py,并把 __init__ 中的 self._units = self._default_units() 改为 self._units = self._default_chinese_units():

    def __init__(self, fixed_length_left: int = 30,
                 fixed_length_right: int = 30,
                 filter_mode: str = 'df',
                 filter_low_freq: float = 2,
                 filter_high_freq: float = float('inf'),
                 remove_stop_words: bool = False):
        """Set up the Chinese preprocessing pipeline.

        :param fixed_length_left: pad/truncate left texts to this length.
        :param fixed_length_right: pad/truncate right texts to this length.
        :param filter_mode: frequency-filter mode (e.g. 'df').
        :param filter_low_freq: drop terms below this frequency.
        :param filter_high_freq: drop terms above this frequency.
        :param remove_stop_words: append a stop-word removal unit.
        """
        super().__init__()
        self._fixed_length_left = fixed_length_left
        self._fixed_length_right = fixed_length_right
        # Term-frequency filter shared by both sides.
        self._filter_unit = units.FrequencyFilter(
            low=filter_low_freq,
            high=filter_high_freq,
            mode=filter_mode,
        )
        # Per-side fixed-length padding (post-padding).
        self._left_fixedlength_unit = units.FixedLength(
            fixed_length_left,
            pad_mode='post',
        )
        self._right_fixedlength_unit = units.FixedLength(
            fixed_length_right,
            pad_mode='post',
        )
        # Use the Chinese pipeline instead of the default English one.
        self._units = self._default_chinese_units()
        if remove_stop_words:
            self._units.append(units.stop_removal.StopRemoval())

同时将训练时的prepossor指定为

preprocessor_class = matchzoo.preprocessors.chinese_preprocessor.ChinesePreprocessor()
# Have matchzoo assemble the model, fitted preprocessor, data-generator
# builder and embedding matrix in one call.
prepared = matchzoo.auto.prepare(
    task=task,
    model_class=model_class,
    preprocessor=preprocessor_class,
    data_pack=train_raw,
    embedding=emb,
)
model, preprocessor, data_generator_builder, embedding_matrix = prepared

4、完整训练代码

# Full training script for the modified-source approach.
import matchzoo
task = matchzoo.tasks.Ranking()
print(task)

train_raw = matchzoo.datasets.chinese_qa.load_data(stage='train', task=task)  # chinese_qa is a new package under datasets holding the Chinese data
test_raw = matchzoo.datasets.chinese_qa.load_data(stage='test', task=task)

print(train_raw.left.head())
print(train_raw.right.head())
print(train_raw.relation.head())
print(train_raw.frame().head())  # data layout shown in figure 3 of the original post

emb = matchzoo.embedding.load_from_file(matchzoo.datasets.embeddings.EMBED_CPWS, mode='word2vec')  # load pre-trained word2vec vectors

model_class = matchzoo.models.ArcI
preprocessor_class = matchzoo.preprocessors.chinese_preprocessor.ChinesePreprocessor()
print(preprocessor_class)
# preprocessor = matchzoo.preprocessors.BasicPreprocessor()

model, preprocessor, data_generator_builder, embedding_matrix = matchzoo.auto.prepare(
    task=task,
    model_class=model_class,
    preprocessor=preprocessor_class,
    data_pack=train_raw,
    embedding=emb
)

print(model.params)   # show the model's tunable parameters
model.params['mlp_num_units'] = 3  # tune a parameter directly
print("embedding_matrix: \n", type(embedding_matrix), '\n', embedding_matrix)
train_processed = preprocessor.transform(train_raw, verbose=0)
test_processed = preprocessor.transform(test_raw, verbose=0)

vocab_unit = preprocessor.context['vocab_unit']   # this section only visualizes the preprocessing result
print('Orig Text:', train_processed.left.loc['L-0']['text_left'])
sequence = train_processed.left.loc['L-0']['text_left']
print('Transformed Indices:', sequence)
print('Transformed Indices Meaning:',
      '/'.join([vocab_unit.state['index_term'][i] for i in sequence]))

# Wrap the processed packs in generators and run a short train/eval cycle.
train_gen = data_generator_builder.build(train_processed)
test_gen = data_generator_builder.build(test_processed)
model.fit_generator(train_gen, epochs=1)
model.evaluate_generator(test_gen)

# model.save('my-model')  # save the model
# loaded_model = matchzoo.load_model('my-model')  # load the model

二、不修改源码

import matchzoo as mz
import typing
from pathlib import Path
import pandas as pd
import csv

def read_data(path):
    """Read a tab-separated file (sentence1, sentence2, Label) into a DataPack."""
    raw = pd.read_csv(path, sep='\t', header=0, quoting=csv.QUOTE_NONE)
    frame = pd.DataFrame(
        {
            'text_left': raw['sentence1'],
            'text_right': raw['sentence2'],
            'label': raw['Label'],
        }
    )
    return mz.pack(frame)

def load_data(
    stage: str = 'train',
    task: str = 'ranking',
    filtered: bool = False,
    return_classes: bool = False
) -> typing.Union[mz.DataPack, tuple]:
    """Load the local LCQMC data pack for *stage*.

    :param stage: one of 'train', 'dev', 'test'.
    :param task: 'ranking', 'classification', or a task instance.
    :param filtered: unused; kept for interface compatibility with the
        built-in WikiQA loader.
    :param return_classes: for classification, also return the label
        classes ``[False, True]``.
    :return: a :class:`mz.DataPack` (optionally with classes).
    :raises ValueError: on an unknown *stage* or *task*.
    """
    if stage not in ('train', 'dev', 'test'):
        raise ValueError(f"{stage} is not a valid stage."
                         f"Must be one of `train`, `dev`, and `test`.")

    data_root = Path("/corpus/LCQMC/data")
    file_path = data_root.joinpath(f'{stage}.txt')
    data_pack = read_data(file_path)
    # `elif` so the second comparison is skipped once `task` has been
    # rebound to a task instance by the first branch.
    if task == 'ranking':
        task = mz.tasks.Ranking()
    elif task == 'classification':
        task = mz.tasks.Classification()
    if isinstance(task, mz.tasks.Ranking):
        return data_pack
    elif isinstance(task, mz.tasks.Classification):
        data_pack.one_hot_encode_label(task.num_classes, inplace=True)
        if return_classes:
            return data_pack, [False, True]
        else:
            return data_pack
    else:
        raise ValueError(f"{task} is not a valid task."
                         f"Must be one of `Ranking` and `Classification`.")
task = mz.tasks.Ranking()
# LCQMC data loaded via the local loader defined above.
train_raw = load_data(stage='train', task=task)
test_raw = load_data(stage='test', task=task)
print(train_raw.left.head())
print(train_raw.right.head())
print(train_raw.relation.head())
print(train_raw.frame().head())
path_vec = "/word2vec/WordVector_60dimensional/wiki.zh.text.vector"
emb = mz.embedding.load_from_file(path_vec, mode='word2vec')
print(type(emb))
model_class = mz.models.ArcI
preprocessor_class = mz.preprocessors.BasicPreprocessor()
# Swap the English pipeline for a Chinese one without touching the
# matchzoo source: tokenize with the Chinese tokenizer and strip
# punctuation (lowercasing is a no-op for Chinese and is omitted).
# NOTE(review): overriding the private `_units` after __init__ discards
# any stop-word unit appended by `remove_stop_words` — confirm acceptable.
preprocessor_class._units = [
    mz.preprocessors.units.tokenize.ChineseTokenize(),
    mz.preprocessors.units.punc_removal.PuncRemoval(),
]

model, preprocessor, data_generator_builder, embedding_matrix = mz.auto.prepare(
    task=task,
    model_class=model_class,
    preprocessor=preprocessor_class,
    data_pack=train_raw,
    embedding=emb
)
print(model.params)   # show the model's tunable parameters
model.params['mlp_num_units'] = 3  # tune a parameter directly
print("embedding_matrix: \n", type(embedding_matrix), '\n', embedding_matrix)
train_processed = preprocessor.transform(train_raw, verbose=0)
test_processed = preprocessor.transform(test_raw, verbose=0)

vocab_unit = preprocessor.context['vocab_unit']   # this section only visualizes the preprocessing result
print('Orig Text:', train_processed.left.loc['L-0']['text_left'])
sequence = train_processed.left.loc['L-0']['text_left']
print('Transformed Indices:', sequence)
print('Transformed Indices Meaning:',
      '/'.join([vocab_unit.state['index_term'][i] for i in sequence]))

# Wrap the processed packs in generators and run a short train/eval cycle.
train_gen = data_generator_builder.build(train_processed)
test_gen = data_generator_builder.build(test_processed)
model.fit_generator(train_gen, epochs=1)
model.evaluate_generator(test_gen)

参考:

1、https://github.com/NTMC-Community/MatchZoo 源代码

2、https://blog.csdn.net/wkh7717/article/details/89886713?depth_1-utm_source=distribute.pc_relevant.none-task&utm_source=distribute.pc_relevant.none-task

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值