data_util

# -*- coding: utf-8 -*-

import re
import os
import jieba
import json
import numpy as np
from collections import defaultdict

re_han = re.compile(r"([\u4E00-\u9FD5]+)", re.U)
PAD = "<PAD>"  # 用于数据填充的
UNKNOWN = "<UNKNOWN>"  # 用于训练数据中不存在的数据


def split_sentence(sentence):
    """
    对给定的文本进行分词处理
    :param sentence:
    :return:
    """
    for word in jieba.lcut(sentence):
        # NOTE: if the token is Chinese, yield it as a whole word; otherwise split it into single characters
        if re_han.match(word):
            yield word
        else:
            for ch in word:
                yield ch
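
# Illustrative example (a sketch, not from the original code): assuming jieba
# segments "word2vec模型" into ["word2vec", "模型"], the non-Chinese token is
# split into single characters while the Chinese token is kept whole:
#   list(split_sentence("word2vec模型"))
#   # -> ['w', 'o', 'r', 'd', '2', 'v', 'e', 'c', '模型']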


def convert_sentence_to_words(in_file, out_file, encoding='utf-8-sig'):
    """
    对输入文件的数据进行分词处理,结果保存到输出文件中
    :param in_file:
    :param out_file:
    :param encoding:
    :return:
    """
    with open(in_file, 'r', encoding=encoding) as reader:
        with open(out_file, 'w', encoding=encoding) as writer:
            for sentence in reader:
                # 1. Strip leading/trailing whitespace
                sentence = sentence.strip()
                # 2. Skip empty lines
                if len(sentence) == 0:
                    continue
                # 3. Convert the sentence into a sequence of tokens
                words = split_sentence(sentence)
                # 4. Write out the result
                result = " ".join(words)
                writer.write("%s\n" % result)


def build_dictionary(in_file, out_file, encoding='utf-8-sig', min_count=5):
    """
    基于分词好的数据,进行字典的构建
    :param in_file:
    :param out_file:
    :param encoding:
    :param min_count:最少出现次数,如果出现次数小于该值,那么直接不作为单词列表
    :return:
    """
    # 1. 读取单词数据
    words = defaultdict(int)
    with open(in_file, 'r', encoding=encoding) as reader:
        for sentence in reader:
            # 1. Strip leading/trailing whitespace
            sentence = sentence.strip()
            # 2. Split the line into words, skip empty tokens, and count occurrences
            for word in sentence.split(" "):
                if len(word) > 0:
                    words[word] += 1
    # 2. Filter by min_count and sort
    words = sorted(map(lambda t: t[0], filter(lambda t: t[1] >= min_count, words.items())))

    # 3. Prepend the special tokens
    words = [PAD, UNKNOWN] + words

    # 4. Save the result to disk
    dir_name = os.path.dirname(out_file)
    if dir_name and not os.path.exists(dir_name):
        os.makedirs(dir_name)
    with open(out_file, 'w', encoding=encoding) as writer:
        json.dump(words, writer, indent=4, ensure_ascii=False)
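
# Example of the resulting dictionary file (a sketch; the two words are illustrative
# and assumed to occur at least min_count times in the segmented corpus):
#   ["<PAD>", "<UNKNOWN>", "学习", "深度"]
# i.e. the sorted word list with the two special tokens prepended; a word's index
# in this list is later used as its id by DataManager.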


def build_record(words, window=4, structure='cbow', allow_padding=True):
    """
    基于给定的参数获取训练数据,并以数组的形式返回
    :param words:
    :param window:
    :param structure:
    :param allow_padding:
    :return:
    """
    n_words = len(words)
    start, end = (0, n_words) if allow_padding else (window // 2, n_words - window + window // 2)
    for idx in range(start, end):
        # Treat the word at position idx as the center word and collect its context words
        center_word = words[idx]  # center word
        surrounding_words = words[max(0, idx - window // 2):idx]  # context words before the center
        surrounding_words = [PAD] * (window // 2 - len(surrounding_words)) + surrounding_words  # pad on the left if there are not enough preceding words
        surrounding_words = surrounding_words + words[idx + 1:idx + window - len(surrounding_words) + 1]  # context words after the center
        surrounding_words = surrounding_words + [PAD] * (window - len(surrounding_words))  # pad on the right if needed
        if structure == 'cbow':
            yield surrounding_words + [center_word]
        else:
            yield [center_word] + surrounding_words
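
# Worked example (a sketch, not part of the original file): with
# words = ['a', 'b', 'c', 'd', 'e'], window=4 and structure='cbow',
# allow_padding=True yields one record per position, e.g. for the center word 'a':
#   ['<PAD>', '<PAD>', 'b', 'c', 'a']   # [context..., center]
# whereas allow_padding=False keeps only positions with a full context, here 'c':
#   ['a', 'b', 'd', 'e', 'c']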


def convert_words_to_record(in_file, out_file, encoding='utf-8-sig', window=4, structure='cbow', allow_padding=True):
    """
    将原始数据进行转换,并输出到磁盘
    :param in_file:
    :param out_file:
    :param encoding:
    :param window:
    :param structure:
    :param allow_padding:构建数据的时候是否允许填充
    :return:
    """
    with open(in_file, 'r', encoding=encoding) as reader:
        with open(out_file, 'w', encoding=encoding) as writer:
            for sentence in reader:
                # 1. Strip whitespace and split the line into words (dropping empty tokens)
                words = [word for word in sentence.strip().split(" ") if word]
                # 2. Filter the data
                if len(words) == 0:
                    continue
                if not allow_padding and len(words) <= window:
                    # If padding is not allowed, sentences shorter than window + 1 words are skipped
                    continue
                # 3. Build the records and write them out
                for record in build_record(words, window, structure, allow_padding):
                    writer.write("%s\n" % " ".join(record))
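
# Typical preprocessing pipeline (a sketch; ../data/corpus.txt and ../data/corpus.words
# are hypothetical file names, not taken from the original code):
#   convert_sentence_to_words("../data/corpus.txt", "../data/corpus.words")
#   build_dictionary("../data/corpus.words", "../data/dictionary.json", min_count=5)
#   convert_words_to_record("../data/corpus.words", "../data/train.cbow.data",
#                           window=4, structure='cbow', allow_padding=True)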


class DataManager(object):
    def __init__(self, data_path, dictionary_path, window=4, structure='cbow',
                 batch_size=8, encoding='utf-8-sig', shuffle=True):
        """
        基于给定的参数构建数据
        :param data_path:  数据路径
        :param dictionary_path:  字典路径
        :param window:窗口大小,那么数据数量实际为window+1
        :param structure: 数据的结构
        :param batch_size:  批次大小
        :param encoding: 数据文件编码格式
        :param shuffle: 是否混淆数据
        """
        self.structure = structure
        self.window = window
        self.batch_size = batch_size
        self.shuffle = shuffle
        # I. Build the dictionary
        # 1. Load the word list from disk
        with open(dictionary_path, 'r', encoding=encoding) as reader:
            words = json.load(reader)
        # 2. Build the mapping between words and ids
        self.word_size = len(words)  # total number of words
        self.word_to_id = dict(zip(words, range(self.word_size)))  # dict: word -> id
        self.id_to_word = words  # list: id -> word

        # II. Load the data
        X = []
        Y = []
        unknown_id = self.word_to_id[UNKNOWN]
        with open(data_path, 'r', encoding=encoding) as reader:
            for line in reader:
                # a. Split the line into words
                sample_words = line.strip().split(" ")
                if len(sample_words) != self.window + 1:
                    continue
                    # raise Exception("Malformed record, please check! Line: {}".format(line))
                # b. Convert the words to ids
                sample_word_ids = [self.word_to_id.get(word, unknown_id) for word in sample_words]
                # c. Split out x and y according to the structure
                if self.structure == 'cbow':
                    # the context words predict the center word
                    x = sample_word_ids[:-1]
                    y = sample_word_ids[-1:]
                else:
                    # the center word predicts the context words (skip-gram layout)
                    x = sample_word_ids[0:1]
                    y = sample_word_ids[1:]
                # d. Append x and y to the final data set
                X.append(x)
                Y.append(y)
        self.X = np.asarray(X)  # X, [total_sample, window] or [total_sample, 1]
        self.Y = np.asarray(Y)  # Y, [total_sample, 1] or [total_sample, window]
        self.total_samples = len(self.X)  # total number of samples
        self.total_batch = int(np.ceil(self.total_samples / self.batch_size))  # total number of batches (rounded up)

    def __iter__(self):
        # 1. Generate the traversal order over all samples
        if self.shuffle:
            total_index = np.random.permutation(self.total_samples)
        else:
            total_index = np.arange(self.total_samples)

        # 2. Fetch the data batch by batch
        for batch_index in range(self.total_batch):
            # a. Get the sample indices of the current batch
            start = batch_index * self.batch_size
            end = start + self.batch_size
            index = total_index[start:end]
            # b. Fetch the corresponding batch data
            batch_x = self.X[index]
            batch_y = self.Y[index]
            # c. Yield the batch
            yield batch_x, batch_y

        # 3. The generator ends naturally after the last batch; raising StopIteration
        #    inside a generator would cause a RuntimeError under PEP 479.

    def __len__(self):
        return self.total_batch


if __name__ == '__main__':
    datamanager = DataManager(
        data_path="../data/train.cbow.data",
        dictionary_path="../data/dictionary.json",
        structure='cbow',
        batch_size=8,
        encoding='utf-8-sig',
        shuffle=True
    )
    print(len(datamanager))
    for batch_x, batch_y in datamanager:
        print(batch_x)
        print(batch_y)
        print("nihao")
        break
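
    # Expected output (a sketch, assuming ../data/train.cbow.data was built with
    # window=4): len(datamanager) prints the number of batches, and each full batch
    # has batch_x of shape (8, 4) and batch_y of shape (8, 1) for the CBOW structure.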
