pytorch之诗词生成--2

先上代码:

# -*- coding: utf-8 -*-
# @File    : dataset.py
# @Author  : AaronJny
# @Time    : 2019/12/30
# @Desc    : 构建数据集
from collections import Counter
import math
import numpy as np
import tensorflow as tf
import settings


class Tokenizer:
    """
    分词器
    """

    def __init__(self, token_dict):
        # 词->编号的映射
        self.token_dict = token_dict
        # 编号->词的映射
        self.token_dict_rev = {value: key for key, value in self.token_dict.items()}
        # 词汇表大小
        self.vocab_size = len(self.token_dict)

    def id_to_token(self, token_id):
        """
        给定一个编号,查找词汇表中对应的词
        :param token_id: 带查找词的编号
        :return: 编号对应的词
        """
        return self.token_dict_rev[token_id]

    def token_to_id(self, token):
        """
        给定一个词,查找它在词汇表中的编号
        未找到则返回低频词[UNK]的编号
        :param token: 带查找编号的词
        :return: 词的编号
        """
        return self.token_dict.get(token, self.token_dict['[UNK]'])

    def encode(self, tokens):
        """
        给定一个字符串s,在头尾分别加上标记开始和结束的特殊字符,并将它转成对应的编号序列
        :param tokens: 待编码字符串
        :return: 编号序列
        """
        # 加上开始标记
        token_ids = [self.token_to_id('[CLS]'), ]
        # 加入字符串编号序列
        for token in tokens:
            token_ids.append(self.token_to_id(token))
        # 加上结束标记
        token_ids.append(self.token_to_id('[SEP]'))
        return token_ids

    def decode(self, token_ids):
        """
        给定一个编号序列,将它解码成字符串
        :param token_ids: 待解码的编号序列
        :return: 解码出的字符串
        """
        # 起止标记字符特殊处理
        spec_tokens = {'[CLS]', '[SEP]'}
        # 保存解码出的字符的list
        tokens = []
        for token_id in token_ids:
            token = self.id_to_token(token_id)
            if token in spec_tokens:
                continue
            tokens.append(token)
        # 拼接字符串
        return ''.join(tokens)


# 禁用词
disallowed_words = settings.DISALLOWED_WORDS
# 句子最大长度
max_len = settings.MAX_LEN
# 最小词频
min_word_frequency = settings.MIN_WORD_FREQUENCY
# mini batch 大小
batch_size = settings.BATCH_SIZE

# 加载数据集
with open(settings.DATASET_PATH, 'r', encoding='utf-8') as f:
    lines = f.readlines()
    # 将冒号统一成相同格式
    lines = [line.replace(':', ':') for line in lines]
# 数据集列表
poetry = []
# 逐行处理读取到的数据
for line in lines:
    # 有且只能有一个冒号用来分割标题
    if line.count(':') != 1:
        continue
    # 后半部分不能包含禁止词
    __, last_part = line.split(':')
    ignore_flag = False
    for dis_word in disallowed_words:
        if dis_word in last_part:
            ignore_flag = True
            break
    if ignore_flag:
        continue
    # 长度不能超过最大长度
    if len(last_part) > max_len - 2:
        continue
    poetry.append(last_part.replace('\n', ''))

# 统计词频
counter = Counter()
for line in poetry:
    counter.update(line)
# 过滤掉低频词
_tokens = [(token, count) for token, count in counter.items() if count >= min_word_frequency]
# 按词频排序
_tokens = sorted(_tokens, key=lambda x: -x[1])
# 去掉词频,只保留词列表
_tokens = [token for token, count in _tokens]

# 将特殊词和数据集中的词拼接起来
_tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + _tokens
# 创建词典 token->id映射关系
token_id_dict = dict(zip(_tokens, range(len(_tokens))))
# 使用新词典重新建立分词器
tokenizer = Tokenizer(token_id_dict)
# 混洗数据
np.random.shuffle(poetry)


class PoetryDataGenerator:
    """
    古诗数据集生成器
    """

    def __init__(self, data, random=False):
        # 数据集
        self.data = data
        # batch size
        self.batch_size = batch_size
        # 每个epoch迭代的步数
        self.steps = int(math.floor(len(self.data) / self.batch_size))
        # 每个epoch开始时是否随机混洗
        self.random = random

    def sequence_padding(self, data, length=None, padding=None):
        """
        将给定数据填充到相同长度
        :param data: 待填充数据
        :param length: 填充后的长度,不传递此参数则使用data中的最大长度
        :param padding: 用于填充的数据,不传递此参数则使用[PAD]的对应编号
        :return: 填充后的数据
        """
        # 计算填充长度
        if length is None:
            length = max(map(len, data))
        # 计算填充数据
        if padding is None:
            padding = tokenizer.token_to_id('[PAD]')
        # 开始填充
        outputs = []
        for line in data:
            padding_length = length - len(line)
            # 不足就进行填充
            if padding_length > 0:
                outputs.append(np.concatenate([line, [padding] * padding_length]))
            # 超过就进行截断
            else:
                outputs.append(line[:length])
        return np.array(outputs)

    def __len__(self):
        return self.steps

    def __iter__(self):
        total = len(self.data)
        # 是否随机混洗
        if self.random:
            np.random.shuffle(self.data)
        # 迭代一个epoch,每次yield一个batch
        for start in range(0, total, self.batch_size):
            end = min(start + self.batch_size, total)
            batch_data = []
            # 逐一对古诗进行编码
            for single_data in self.data[start:end]:
                batch_data.append(tokenizer.encode(single_data))
            # 填充为相同长度
            batch_data = self.sequence_padding(batch_data)
            # yield x,y
            yield batch_data[:, :-1], tf.one_hot(batch_data[:, 1:], tokenizer.vocab_size)
            del batch_data

    def for_fit(self):
        """
        创建一个生成器,用于训练
        """
        # 死循环,当数据训练一个epoch之后,重新迭代数据
        while True:
            # 委托生成器
            yield from self.__iter__()

下面我们逐行分析该代码:我们首先定义一个分词器类:

class Tokenizer:
    """
    分词器
    """

    def __init__(self, token_dict):
        # 词->编号的映射
        self.token_dict = token_dict
        # 编号->词的映射
        self.token_dict_rev = {value: key for key, value in self.token_dict.items()}
        # 词汇表大小
        self.vocab_size = len(self.token_dict)

    def id_to_token(self, token_id):
        """
        给定一个编号,查找词汇表中对应的词
        :param token_id: 带查找词的编号
        :return: 编号对应的词
        """
        return self.token_dict_rev[token_id]

    def token_to_id(self, token):
        """
        给定一个词,查找它在词汇表中的编号
        未找到则返回低频词[UNK]的编号
        :param token: 带查找编号的词
        :return: 词的编号
        """
        return self.token_dict.get(token, self.token_dict['[UNK]'])

    def encode(self, tokens):
        """
        给定一个字符串s,在头尾分别加上标记开始和结束的特殊字符,并将它转成对应的编号序列
        :param tokens: 待编码字符串
        :return: 编号序列
        """
        # 加上开始标记
        token_ids = [self.token_to_id('[CLS]'), ]
        # 加入字符串编号序列
        for token in tokens:
            token_ids.append(self.token_to_id(token))
        # 加上结束标记
        token_ids.append(self.token_to_id('[SEP]'))
        return token_ids

    def decode(self, token_ids):
        """
        给定一个编号序列,将它解码成字符串
        :param token_ids: 待解码的编号序列
        :return: 解码出的字符串
        """
        # 起止标记字符特殊处理
        spec_tokens = {'[CLS]', '[SEP]'}
        # 保存解码出的字符的list
        tokens = []
        for token_id in token_ids:
            token = self.id_to_token(token_id)
            if token in spec_tokens:
                continue
            tokens.append(token)
        # 拼接字符串
        return ''.join(tokens)

看第一段:

def __init__(self, token_dict):
    # 词->编号的映射
    self.token_dict = token_dict
    # 编号->词的映射
    self.token_dict_rev = {value: key for key, value in self.token_dict.items()}
    # 词汇表大小
    self.vocab_size = len(self.token_dict)

首先我们接受一个名为token_dict的参数,将其存储为类的属性,然后创建一个名为token_dict_rev的属性,这是token_dict的反向映射,最后,计算词汇表的大小并将其存储为vocab_size属性。

看下一段:

def id_to_token(self, token_id):
    """
    给定一个编号,查找词汇表中对应的词
    :param token_id: 带查找词的编号
    :return: 编号对应的词
    """
    return self.token_dict_rev[token_id]

这段代码定义一个方法id_to_token,接受一个名为token_id的参数,然后在词汇表中查找对应的词并返回,这个方法实际上是通过token_dict_rev属性实现的反向查找。明显,该字典中的键词的编号,值是词。

接着往下看:

def token_to_id(self, token):
    """
    给定一个词,查找它在词汇表中的编号
    未找到则返回低频词[UNK]的编号
    :param token: 带查找编号的词
    :return: 词的编号
    """
    return self.token_dict.get(token, self.token_dict['[UNK]'])

这段代码与上一段的由键到值差不多,是由值找到对应的键。接受名为token作为参数,然后在词汇表中查找对应词的编号并返回。如果词不在词汇表中,则返回低频词[UNK]的编号,注意我们的token_dict字典的键是词,值是编号,我们可以通过词来找到对应的编号,而token_dict_rev的键是编号,值是词,我们可以通过编号找到对应的值。

return self.token_dict.get(token, self.token_dict['[UNK]'])这段代码中,我们使用get方法,我们尝试在self.token_dict中获取键为token的值,也就是找到对应的编号,第二个参数表示如果没找到对应的键,则返回self.token_dict中键为[UNK]的值。(第二个参数表示字典找不到对应键时返回的默认值)。这样可以确保即使词不在词表中,也能返回一个默认值,避免了出现KeyError。

继续看代码:

def encode(self, tokens):
    """
    给定一个字符串s,在头尾分别加上标记开始和结束的特殊字符,并将它转成对应的编号序列
    :param tokens: 待编码字符串
    :return: 编号序列
    """
    # 加上开始标记
    token_ids = [self.token_to_id('[CLS]'), ]
    # 加入字符串编号序列
    for token in tokens:
        token_ids.append(self.token_to_id(token))
    # 加上结束标记
    token_ids.append(self.token_to_id('[SEP]'))
    return token_ids

我们的开始标记调用了我们刚刚定义的token_to_id方法,显然,不可能出现[CLS]这个词,所以得到的是[UNK]对应的编号,显然是一个特殊的编号。
(我们看一下错误的输出,也不算错误,就是对应我们的处理词输出。)

而后遍历tokens中的每个词,将词转化为对应的编号加入到编号序列中,这样我们就可以将我们的汉字类型转化为数字,从而可以进行卷积层的处理。

随后加上结束标记的符号,显然也是对应[UNK]。最后我们返回完整的编号序列。(是一个由数字组成的列表)。

相对应的是解码:

def decode(self, token_ids):
    """
    给定一个编号序列,将它解码成字符串
    :param token_ids: 待解码的编号序列
    :return: 解码出的字符串
    """
    # 起止标记字符特殊处理
    spec_tokens = {'[CLS]', '[SEP]'}
    # 保存解码出的字符的list
    tokens = []
    for token_id in token_ids:
        token = self.id_to_token(token_id)
        if token in spec_tokens:
            continue
        tokens.append(token)
    # 拼接字符串
    return ''.join(tokens)

我们先将特殊字符,也就是开始与结束对应的字符组成一个集合。而后我们创建了一个名为tokens的空列表,用于保存由token_ids中token_id对应词。最后我们使用join方法,将tokens列表中的字符串元素链接起来,形成一个新的字符串,在这里,''表示以空字符串作为连接符,也就是将tokens中的词无缝衔接。

接下来我们定义一些参数,这些参数在setting中已经定义,这里我们直接拿来用:

isallowed_words = settings.DISALLOWED_WORDS
# 句子最大长度
max_len = settings.MAX_LEN
# 最小词频
min_word_frequency = settings.MIN_WORD_FREQUENCY
# mini batch 大小
batch_size = settings.BATCH_SIZEr

然后我们就可以开始加载数据集了:

with open(settings.DATASET_PATH, 'r', encoding='utf-8') as f:
    lines = f.readlines()
    # 将冒号统一成相同格式
    lines = [line.replace(':', ':') for line in lines]
# 数据集列表
poetry = []

通过在setting中已经定义好的路径用只读的方式加载我们的数据,解码的类型是utf-8。f是一个对象,表示被打开的文件。文件对象f会在with代码块结束的时候自动关闭。

lines=f.readlines():这段代码从打开的文件对象f中读取所有行,并将它们存储在名为lines的列表中。(因为我们的数据集很大,所以这一步很耗时间)。
而后我们对我们的诗词进行处理,将所有行中的‘:’转化为‘:’,即格式统一,但是这里其实我们都转化为“:”也是不影响的。
然后我们创建一个数据集列表,也就是空列表。

接着我们开始对每一行(也就是一首诗)进行处理:

for line in lines:
    # 有且只能有一个冒号用来分割标题
    if line.count(':') != 1:
        continue
    # 后半部分不能包含禁止词
    __, last_part = line.split(':')
    ignore_flag = False
    for dis_word in disallowed_words:
        if dis_word in last_part:
            ignore_flag = True
            break
    if ignore_flag:
        continue
    # 长度不能超过最大长度
    if len(last_part) > max_len - 2:
        continue
    poetry.append(last_part.replace('\n', ''))

这里我们首先要参考一下数据的格式:

可见我们的每首诗在:的前面部分是诗词名,后半部分是内容,如果该行不包含:则表示是数据出现错误,这时我们直接跳过该数据,使用continue。对于没有问题的数据,我们使用split方法将数据分为前半部分诗词名(当然,直接丢掉),和第二部分内容(是我们需要的精华)。

我们定义一个布尔类型的变量ignore_flag用来判断是否将这个数据忽视。我们将禁词一一取出,如果禁词在我们的数据中出现,我们将该布尔变量设置为true,也就是要去除该数据,嵌套遍历完成后,我们通过判断布尔变量值来确定是否进行下一步处理,当然没有问题的数据,我们将其保存并进行下一步处理。

我们在进行下一步处理的时候也要进行判断,显然,当我们的数据长度较长的时候,比如(长恨歌),我们也是不需要的,这属于异常数据,我们用它作为参考生成小篇幅诗词无异于读圣经来学习小学的看图写话。

剩下的部分也就是符合我们要求的数据了,对于这些数据,我们直接将他们放进我们的列表中。注意小细节,我们将换行符转化为空格。(官方解释是确保诗词文本在处理之后仍然保持连续的完整性,而不会因为换行符被分割为很多行,有利于后期对文本的处理和分析)(但是我认为这是多余的,因为对于一行数据来代表一首诗词来说,完全没必要考虑换行符的问题)。

嗯嗯...也不是完全没用。

可见,我们生成诗词的时候,如果考虑到换行符的话,我们可以拉开我们生成的诗词的距离。

继续:

counter = Counter()
for line in poetry:
    counter.update(line)

这段代码创建了一个Counter类(计数器对象),它是collections模块中的一个数据结构,用于统计可哈希对象的出现次数。然后循环迭代poetry中的每一行,其中poetry是一个包含多行诗歌的列表。在每次迭代中,counter.update(line)都会被调用,它会将line中的字符添加到计数器中,并更新它们的出现次数,update()方法接受一个可迭代对象作为参数,它会遍历该对象并更新计数器。

最终,counter对象将会包含整个数据集中每个字符出现的次数。我们将通过一个简单的案例来说明counter函数的用法:

from collections import Counter

poetry = [
    "Roses are red,",
    "Violets are blue,",
    "Sugar is sweet,",
    "And so are you."
]

counter = Counter()
for line in poetry:
    counter.update(line)

print(counter)

输出结果如下:

Counter({' ': 15, 'e': 10, 's': 7, 'a': 6, 'o': 5, 'r': 4, 'u': 4, 't': 3, 'd': 2, 'n': 2, 'y': 2, 'w': 2, 'A': 1, 'R': 1, 'V': 1, 'i': 1, 'l': 1, 'b': 1, 'g': 1, ',': 1, 'S': 1, 'I': 1, '.': 1})

输出的是一个Counter对象。

接下来我们接着对词进行处理:

_tokens = [(token, count) for token, count in counter.items() if count >= min_word_frequency]
# 按词频排序
_tokens = sorted(_tokens, key=lambda x: -x[1])
# 去掉词频,只保留词列表
_tokens = [token for token, count in _tokens]

# 将特殊词和数据集中的词拼接起来
_tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + _tokens
# 创建词典 token->id映射关系
token_id_dict = dict(zip(_tokens, range(len(_tokens))))
# 使用新词典重新建立分词器
tokenizer = Tokenizer(token_id_dict)
# 混洗数据
np.random.shuffle(poetry)

我们首先来看第一行,创建了一个列表_tokens,用来包含计数器counter中词频大于等于min_word_frequency的词和它们的出现次数,counter.item返回的是一个键值对,键是词,值是对应的频数。

接下来,我们对_tokens列表进行排序,按照词频从高到低进行降序排序,key=lambda x:-x[1]表示使用每个元素的第二个值,即词频作为进行排序的依据。

_tokens = [token for token, count in _tokens]之后我们将排序后的列表中提取词汇,生成一个只包含词汇的列表,这里丢弃了词频信息,只包含了词汇。

而后我们将一些特殊字符,_tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + _tokens 添加到_tokens列表中,即在词汇列表的最前面。

token_id_dict = dict(zip(_tokens, range(len(_tokens))))然后我们创建一个字典,字典是从词汇到ID的映射关系,当然,前几个索引对应的是特殊词汇,后面按照词汇出现的频率一次对应索引。当然,得到的结果是一个字典。(由词汇到索引)

我们将这个字典传到我们的分词器中,会自动生成由索引到词的映射,以及得到该字典的长度(即词的个数)。

然后我们将我们的诗词的列表进行混洗。

而后我们又定义了一个古诗数据集生成器:

class PoetryDataGenerator:
    """
    古诗数据集生成器
    """

    def __init__(self, data, random=False):
        # 数据集
        self.data = data
        # batch size
        self.batch_size = batch_size
        # 每个epoch迭代的步数
        self.steps = int(math.floor(len(self.data) / self.batch_size))
        # 每个epoch开始时是否随机混洗
        self.random = random

    def sequence_padding(self, data, length=None, padding=None):
        """
        将给定数据填充到相同长度
        :param data: 待填充数据
        :param length: 填充后的长度,不传递此参数则使用data中的最大长度
        :param padding: 用于填充的数据,不传递此参数则使用[PAD]的对应编号
        :return: 填充后的数据
        """
        # 计算填充长度
        if length is None:
            length = max(map(len, data))
        # 计算填充数据
        if padding is None:
            padding = tokenizer.token_to_id('[PAD]')
        # 开始填充
        outputs = []
        for line in data:
            padding_length = length - len(line)
            # 不足就进行填充
            if padding_length > 0:
                outputs.append(np.concatenate([line, [padding] * padding_length]))
            # 超过就进行截断
            else:
                outputs.append(line[:length])
        return np.array(outputs)

    def __len__(self):
        return self.steps

    def __iter__(self):
        total = len(self.data)
        # 是否随机混洗
        if self.random:
            np.random.shuffle(self.data)
        # 迭代一个epoch,每次yield一个batch
        for start in range(0, total, self.batch_size):
            end = min(start + self.batch_size, total)
            batch_data = []
            # 逐一对古诗进行编码
            for single_data in self.data[start:end]:
                batch_data.append(tokenizer.encode(single_data))
            # 填充为相同长度
            batch_data = self.sequence_padding(batch_data)
            # yield x,y
            yield batch_data[:, :-1], tf.one_hot(batch_data[:, 1:], tokenizer.vocab_size)
            del batch_data

    def for_fit(self):
        """
        创建一个生成器,用于训练
        """
        # 死循环,当数据训练一个epoch之后,重新迭代数据
        while True:
            # 委托生成器
            yield from self.__iter__()

我们从头进行分析:

def __init__(self, data, random=False):
    # 数据集
    self.data = data
    # batch size
    self.batch_size = batch_size
    # 每个epoch迭代的步数
    self.steps = int(math.floor(len(self.data) / self.batch_size))
    # 每个epoch开始时是否随机混洗
    self.random = random

我们接受data和可选的random参数,在方法内部,我们将传入的data赋值给self.data,并确定了batch_size属性,我们之后通过数据集的长度和每个批次的长度来计算每一轮训练多少个批次(也就是步数)。
self.random=random表示每个epoch开始时是否随机混洗数据,它的值等于传入的random参数,默认为不随机混洗。

继续看代码:

def sequence_padding(self, data, length=None, padding=None):
    """
    将给定数据填充到相同长度
    :param data: 待填充数据
    :param length: 填充后的长度,不传递此参数则使用data中的最大长度
    :param padding: 用于填充的数据,不传递此参数则使用[PAD]的对应编号
    :return: 填充后的数据
    """
    # 计算填充长度
    if length is None:
        length = max(map(len, data))
    # 计算填充数据
    if padding is None:
        padding = tokenizer.token_to_id('[PAD]')
    # 开始填充
    outputs = []
    for line in data:
        padding_length = length - len(line)
        # 不足就进行填充
        if padding_length > 0:
            outputs.append(np.concatenate([line, [padding] * padding_length]))
        # 超过就进行截断
        else:
            outputs.append(line[:length])
    return np.array(outputs)

我们使用sequence_padding方法,用于将给定的数据填充到相同的长度:

我们传入参数分别是数据,长度,填充的字符。
我们默认填充后的长度是我们数据中的最大长度,这也是我们为什么使用的是64作为最大长度,而将诗词较长的数据进行去除。(不太适合生成长恨歌那样的诗词)。
我们填充的数据编号是PAD对应的编号,即解码的时候对应的也是PAD。
之后我们进行填充,计算出每一行需要填充的长度(归一化长度后的长度减去当前的长度),如果需要进行填充,我们将原数据拼接填充内容作为填充之后的数据。将填充之后的数据放入我们的outputs列表中,否则的话(数据大于我们的最大数据,虽然理论上是不可能的,但是我们也是写一下吧,就只留下到最大长度为止的数据。)

这里值得注意的是,我们传入的是由索引组成的列表。我们得到的也是由数据列表组成的列表,我们通过np.array(outputs)将列表outputs转化为一个numpy数组,其中每个元素对应列表中子列表。便于进一步处理数据。

接下来我们通过__len__来返回步长:

def __len__(self):
    return self.steps

即每轮训练多少个批次,在这里,初始化的时候已经计算好了。

继续哈:

def __iter__(self):
    total = len(self.data)
    # 是否随机混洗
    if self.random:
        np.random.shuffle(self.data)
    # 迭代一个epoch,每次yield一个batch
    for start in range(0, total, self.batch_size):
        end = min(start + self.batch_size, total)
        batch_data = []
        # 逐一对古诗进行编码
        for single_data in self.data[start:end]:
            batch_data.append(tokenizer.encode(single_data))
        # 填充为相同长度
        batch_data = self.sequence_padding(batch_data)
        # yield x,y
        yield batch_data[:, :-1], tf.one_hot(batch_data[:, 1:], tokenizer.vocab_size)
        del batch_data

首先我们获取总样本数,也就是我们的诗词个数,如果self.random=True表示每个epoch开始时需要随机混洗数据集,因此使用np.random.shuffle随机打乱self.data。

然后使用for循环进入每个批次进行训练,(以批次大小为步长遍历数据集,每次迭代都产生一个批次的数据)。用start和end分别表示训练数据开始和结束对应的索引,这里我们要考虑当用累加计算结束位置的时候,不要超过数据的长度。

然后我们逐一对古诗进行编码,将编码得到的结果送入空列表batch_data中,这里要注意我们得到的tokenizer.encode(single_data)是一个由数字组成的列表设为A,然后送入batch_data得到的是一个由A组成的列表,对这个列表进行padding处理,将这个列表中每首诗对应的列表进行扩充。(填充到相同的长度)。

yield batch_data[:, :-1], tf.one_hot(batch_data[:, 1:], tokenizer.vocab_size)
        del batch_data

这一行代码使用yield语句生成一个批次的数据。它返回两个值:batch_data[:,:-1]输入数据,是经过填充的故事序列编码,去掉每个序列的最后一个词,它的形状是(batch_size,sequence_length-1)。(最后一个是句号哦)。
tf.one_hot(batch_data[;,1:],tokenizer.vocab_size)这段代码的目的是将目标数据进行编码,并在这个过程中去掉每个序列中的第一个词,进行独热编码。tokenizer.vocab_size是词汇表的大小,用于确定独热编码的维度。它的形状是(batch_size,sequence_length-1,tokenizer.vocab_size)。

最后del batch_data:

删除批次数据batch_data释放内存,在每次迭代后我们就不需要存储整个批次的数据,因此可以通过删除来释放内存。

为什么删除第一个和最后一个呢?因为我们的起始位置和结束都使用特殊字符进行编码。

最后我们使用:

def for_fit(self):
    """
    创建一个生成器,用于训练
    """
    # 死循环,当数据训练一个epoch之后,重新迭代数据
    while True:
        # 委托生成器
        yield from self.__iter__()

这里我们创建一个死循环,表示生成器会无限制的生成数据,这是为了在训练过程中能持续获取数据,这里使用yield from语法来委托另外一个生成器,即self.__iter__()方法生成的数据,委托生成器的作用是将self.__iter__()生成的数据直接传递给外部的迭代器,作为训练数据。

通过这种方式,当调用for_fit方法时,会得到一个生成器对象,每次迭代该生成器,会从self.__iter__()生成的数据中获取一个批次的训练数据,并将其作为生成器的输出,由于采用了死循环的设置,这个生成器会持续的生成数据,直到外部的训练过程停止或中断。

  • 28
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

丘小羽

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值