完整的代码:Github
num_skips与skip_window之间的关系
很多人都不理解num_skips与skip_window之间的关系,skip_window这个参数限制了采样的范围,skip_window=1就是在输入单词的左右各一个单词范围内采样,skip_dow=2就是在输入单词的左右各2个单词的范围内采样,num_skips参数是在skip_window规定的范围内采样多少个,比如skip_window=2的时候总共可以采样4个(input, output)单词对,num_skips=2就表示在4个单词对中选择2个单词对作为训练数据。
generate_batch函数理解
这个函数刚开始看很迷,到后来一步一步调试才看清楚。
变量说明:
假设batch_size=8,num_skips=2,skip_window=1。
indexs
:中存的是要训练的单词的id。
buffer
:是一个长度为2×skip_window+1的滑动窗口。
假设
i
n
d
e
x
s
=
[
5234
,
3081
,
12
,
6
,
195
,
2
,
3134
,
.
.
.
]
indexs=[5234, 3081, 12, 6, 195, 2, 3134, ...]
indexs=[5234,3081,12,6,195,2,3134,...],
下面就看着代码一起来理解。
def generate_batch(batch_size, num_skips, skip_window):
'''
生成训练数据
:param batch_size: 表示每个批次大小
:param num_skips: skip的数量,就是从上下文窗口采样的数量,batch_size%num_skips == 0为true
:param skip_window: 窗口大小,单方向的,2*skip_window需要大于等于num_skips
:return:
'''
global data_index
assert batch_size % num_skips == 0
assert num_skips <= 2* skip_window
batch = np.ndarray(shape=[batch_size], dtype=np.int32)
labels = np.ndarray(shape=[batch_size, 1], dtype=np.int32)
span = 2*skip_window+1 ## 3
buffer = collections.deque(maxlen=span)
for _ in range(span):
buffer.append(indexs[data_index])
data_index = (data_index+1) % len(indexs)
## 到这一步的时候data_index=3, buffer中是[5234, 3081, 12]
for i in range(batch_size // num_skips):
target = skip_window ## input word, 是buffer的中间位置
target_to_avoid = [skip_window] ## 记录已经选择的位置列表
for j in range(num_skips):
while target in target_to_avoid:
target = random.randint(0, span-1) ## 选择一个不在target_to_avoid的单词
target_to_avoid.append(target) ## 添加到已选择列表
# 添加一个输入和标签
batch[i*num_skips+j] = buffer[skip_window] ## input word
labels[i*num_skips+j, 0] = buffer[target] ## output word
## 加入一个新的词,索引+1
buffer.append(indexs[data_index]) ## 此时buffer中是[3081, 12, 6]
data_index = (data_index+1) % len(indexs) ## data_index = 4
return batch, labels
所以看到这里可以看出,index相当于是一个纸带,buffer是一个滑动窗口在上面移动,每次移动一个单词的长度,而data_index就相当于是一个指针,指定buffer移动的下一个单词。