阅读了不少关于yield的博客,仍不是很理解它的用法。
确实,在实例面前,再多的语言都是苍白的,这里参考github上一个tensorflow例子中用于产生batch的子函数来帮助理解yield:
def batch_yield(data, batch_size, vocab, tag2label, shuffle=False):
if shuffle:
random.shuffle(data)
seqs, labels = [], []
for (sent_, tag_) in data:
# sent_是一个列表(文本序列),tag也是一个列表(文本序列对应的标签)
# 使用sentence2id将文本序列映射为数值序列,为自己定义的一个文本处理函数
sent_ = sentence2id(sent_, vocab)
# 使用tag2label将tag映射为数值,为自己定义的一个文本处理函数
label_ = [tag2label[tag] for tag in tag_]
l = len(seqs)
if len(seqs) == batch_size:
yield seqs, labels
seqs, labels = [], []
seqs.append(sent_)
labels.append(label_)
# len()
if len(seqs) != 0:
yield seqs, labels
train_data = read_corpus(train_path) # 读取数据集,为自己