此工程解读链接(建议按顺序阅读):
RNN代码解读之char-RNN with TensorFlow(model.py)
RNN代码解读之char-RNN with TensorFlow(train.py)
RNN代码解读之char-RNN with TensorFlow(util.py)
RNN代码解读之char-RNN with TensorFlow(sample.py)
终于到了最后,在这里我们用到了sample.py以及model.py里面的sample方法。
在采样过程中,要注意batch_size和sequence_length都是1了,我们只需要输入一个,根据这一个字符计算下一个就好了,因此在model中,某些张量的尺寸,比如说prob,就会改变,这一点在下面也有注明。
我在这里有一个问题,希望可以获得大家的指点。代码中sample设置了三种方法,其中用到了一种叫weighted_pick的方法,感觉像是在概率分布函数中随机插值取样,这里不太懂为什么要这么做,取最大不是更好吗?希望大家不吝赐教,非常感谢!
#-*-coding:utf-8-*-
from __future__ import print_function
import numpy as np
import tensorflow as tf
import argparse
import time
import os
from six.moves import cPickle
from utils import TextLoader
from model import Model