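In the seq2seq example, the following pieces of code deserve attention:
(1) source_int and target_int are built with the dictionary get() method; source_letter_to_int['<UNK>'] is the default value returned when a key (a character) is not found in the dictionary.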
source_int = [[source_letter_to_int.get(letter, source_letter_to_int['<UNK>'])
               for letter in line] for line in source_data.split('\n')]
target_int = [[target_letter_to_int.get(letter, target_letter_to_int['<UNK>'])
               for letter in line] + [target_letter_to_int['<EOS>']]
              for line in target_data.split('\n')]
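The fallback lookup can be tried on a toy vocabulary. This is only a minimal sketch: the `vocab` dictionary and the `encode` helper are hypothetical names introduced for illustration, not part of the original code, and the real dictionaries are built from the training data.

```python
# Hypothetical toy vocabulary (assumption); the real one comes from the data.
vocab = {'<PAD>': 0, '<UNK>': 1, '<GO>': 2, '<EOS>': 3, 'a': 4, 'b': 5, 'c': 6}


def encode(line, letter_to_int, append_eos=False):
    """Map each character to its id, falling back to the <UNK> id."""
    unk = letter_to_int['<UNK>']
    ids = [letter_to_int.get(letter, unk) for letter in line]
    if append_eos:
        # Target sequences get an end-of-sequence marker appended.
        ids.append(letter_to_int['<EOS>'])
    return ids


source_ids = encode('abz', vocab)                   # 'z' is not in vocab
target_ids = encode('abz', vocab, append_eos=True)
print(source_ids)  # [4, 5, 1]
print(target_ids)  # [4, 5, 1, 3]
```

Using get() with a default avoids the KeyError that `letter_to_int[letter]` would raise for an unseen character.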
(2)next(get_batches())
(valid_targets_batch, valid_sources_batch, valid_targets_lengths, valid_sources_lengths) = next(
    get_batches(valid_target, valid_source, batch_size,
                source_letter_to_int['<PAD>'], target_letter_to_int['<PAD>']))
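- When the training data is smaller than one batch_size, next(get_batches()) raises StopIteration. Changing the code as follows suppresses the error for the moment (although the four batch variables are then never assigned), but once the data is large enough to fill a batch it loops forever: every iteration calls get_batches() again and creates a fresh generator, so next() always returns the first batch and StopIteration is never raised: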
try:
    while True:
        (valid_targets_batch, valid_sources_batch, valid_targets_lengths, valid_sources_lengths) = next(
            get_batches(valid_target, valid_source, batch_size,
                        source_letter_to_int['<PAD>'], target_letter_to_int['<PAD>']))
except StopIteration:
    pass
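- So the real problem to solve is the lack of data: collect enough to fill several batches of batch_size, rather than bending the code to fit the data.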
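The behaviour described above follows from how Python generators work. The sketch below uses `toy_batches`, a hypothetical stand-in for get_batches() that is assumed to yield only full batches (the real function is not shown here, but that assumption matches the StopIteration it raises on small data). It shows the error with too little data, why re-calling the generator function never finishes, and how next() with a default avoids the exception.

```python
def toy_batches(data, batch_size):
    """Hypothetical stand-in for get_batches(): yields full batches only."""
    for i in range(len(data) // batch_size):
        yield data[i * batch_size:(i + 1) * batch_size]


# Fewer samples than batch_size: the generator yields nothing,
# so next() raises StopIteration on the very first call.
try:
    next(toy_batches([1, 2, 3], batch_size=4))
    raised = False
except StopIteration:
    raised = True

# Calling the generator function again builds a NEW generator each time,
# so next() keeps returning the first batch. This is why the
# `while True: next(get_batches(...))` loop never terminates.
first = next(toy_batches(list(range(10)), batch_size=4))
again = next(toy_batches(list(range(10)), batch_size=4))

# Creating the generator once and iterating over it does terminate.
all_batches = list(toy_batches(list(range(10)), batch_size=4))

# next() with a default returns None instead of raising on an empty generator.
maybe_batch = next(toy_batches([1, 2, 3], batch_size=4), None)

print(raised, first, again, all_batches, maybe_batch)
```

Checking `maybe_batch is None` only makes the failure explicit; it does not replace having enough validation data for at least one batch.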