TensorFlow官网上word2vec产生batch的函数讲解。

完整的代码Github

num_skips与skip_window之间的关系

很多人都不理解num_skips与skip_window之间的关系,skip_window这个参数限制了采样的范围,skip_window=1就是在输入单词的左右各一个单词范围内采样,skip_dow=2就是在输入单词的左右各2个单词的范围内采样,num_skips参数是在skip_window规定的范围内采样多少个,比如skip_window=2的时候总共可以采样4个(input, output)单词对,num_skips=2就表示在4个单词对中选择2个单词对作为训练数据。

generate_batch函数理解

这个函数刚开始看很迷,到后来一步一步调试才看清楚。
变量说明:
假设batch_size=8,num_skips=2,skip_window=1。
indexs:中存的是要训练的单词的id。
buffer:是一个长度为2×skip_window+1的滑动窗口。

假设 i n d e x s = [ 5234 , 3081 , 12 , 6 , 195 , 2 , 3134 , . . . ] indexs=[5234, 3081, 12, 6, 195, 2, 3134, ...] indexs=[5234,3081,12,6,195,2,3134,...]
下面就看着代码一起来理解。

def generate_batch(batch_size, num_skips, skip_window):
    '''
    生成训练数据
    :param batch_size: 表示每个批次大小
    :param num_skips: skip的数量,就是从上下文窗口采样的数量,batch_size%num_skips == 0为true
    :param skip_window: 窗口大小,单方向的,2*skip_window需要大于等于num_skips
    :return:
    '''
    global data_index

    assert batch_size % num_skips == 0
    assert num_skips <= 2* skip_window

    batch = np.ndarray(shape=[batch_size], dtype=np.int32)
    labels = np.ndarray(shape=[batch_size, 1], dtype=np.int32)

    span = 2*skip_window+1     ## 3
    buffer = collections.deque(maxlen=span)
    for _ in range(span):
        buffer.append(indexs[data_index])
        data_index = (data_index+1) % len(indexs)
    ## 到这一步的时候data_index=3, buffer中是[5234, 3081, 12]

    for i in range(batch_size // num_skips):
        target = skip_window     ## input word, 是buffer的中间位置
        target_to_avoid = [skip_window]    ## 记录已经选择的位置列表
        for j in range(num_skips):
            while target in target_to_avoid:
                target = random.randint(0, span-1)     ## 选择一个不在target_to_avoid的单词
            target_to_avoid.append(target)    ## 添加到已选择列表
            # 添加一个输入和标签
            batch[i*num_skips+j] = buffer[skip_window]    ## input word
            labels[i*num_skips+j, 0] = buffer[target]     ## output word
        ## 加入一个新的词,索引+1
        buffer.append(indexs[data_index])      ## 此时buffer中是[3081, 12, 6]
        data_index = (data_index+1) % len(indexs)   ## data_index = 4

    return batch, labels

所以看到这里可以看出,index相当于是一个纸带,buffer是一个滑动窗口在上面移动,每次移动一个单词的长度,而data_index就相当于是一个指针,指定buffer移动的下一个单词。
在这里插入图片描述

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值