8.6,8.7,8.8 batch_flow

import random
import numpy as np
from tensorflow.python.client import device_lib
from word_sequence import WordSequence

VOCAB_SIZE_THRESHOLD_CPU= 50000

def _get_available_gpus():
    """获取当前的GPU信息"""
    local_device_protos=device_lib.list_local_devices()
    return [x.name for x in local_device_protos if x.device_type=='GPU']


def _get_embed_device(vocab_size):
    """根据输入输出的字典大小来选择,是在CPU上embedding还是在GPU上进行embedding"""
    gpus=_get_available_gpus()
    if not gpus or vocab_size > VOCAB_SIZE_THRESHOLD_CPU:
        return "/cpu:0"
    return "/gpu:0"


def transform_sentence(sentence, ws, max_len=None, add_end=False):
    """单独的句子转换"""
    encoded=ws.transform(
        sentence,
        max_len=max_len if max_len is not None else len(sentence))
    encoded_len= len(sentence)+(1 if add_end else 0)
    if encoded_len>len(encoded):
        encoded_len=len(encoded)

    return encoded, encoded_len


def batch_flow(data, ws, batch_size, raw=False, add_end=True):
    """
    从数据中随机去生成batch_size的数据,然后给转换后输出出去
    row:是否返回原始对象,如果为True,假设结果ret,那么len(ret)==Len(data)*3
    如果为false,那么len(ret)==len(data)*2
    Q=(q1,q2,q3...qn)
    A=(a1,a2,a3...an)
    len(Q)==len(A)
    batch_flow([Q,A],ws,batch_size=32)
    row=False:
    next(generator)==q_i_encoded,a_i_encoded,a_i_len
    row=True:
    next(generator)==q_i_encoded,q_i,q_i_encoded,a_i_len,a_i
    """
    #ws数量要和data数量保持一致(多个),len(data)==len(ws)
    all_data=list(zip(*data))
    if isinstance(ws,(list,tuple)):
        assert len(ws)==len(data),'ws的长度必须等于data的长度,if ws是一个list or tuple'

    if isinstance(add_end, bool):
        add_end=[add_end]*len(data)
    else:
        assert(isinstance(add_end,(list,tuple))),'add_end不是boolean,就应该是一个list(tuple) of boolean'
        assert len(add_end)==len(data), '如果add_end是list(tuple),那么add_end的长度应该和输入数据的长度一致'

    mul=2
    if raw:
        mul=3

    while True:
        data_batch=random.sample(all_data, batch_size) #在all_data数据中随机抽取生成batch_size个数据
        batches=[[] for i in range(len(data)*mul)]

        max_lens=[]
        for j in range(len(data)):
            max_len=max([
                len(x[j]) if hasattr(x[j],'__len__') else 0
                for x in data_batch
            ])+(1 if add_end[j] else 0)
            max_lens.append(max_len)
        for d in data_batch:
            for j in range(len(data)):
                if isinstance(ws,(list,tuple)):
                    w=ws[j]
                else:
                    w=ws
                #添加结束标记(结尾)
                line=d[j]
                if add_end[j] and isinstance(line,(tuple,list)):
                    line=list(line)+[WordSequence.END_TAG]
                if w is not None:
                    x,xl=transform_sentence(line,w,max_lens[j],add_end[j])
                    batches[j*mul].append(x)
                    batches[j*mul+1].append(xl)
                else:
                    batches[j*mul].append(line)
                    batches[j*mul+1].append(line)
                if raw:
                    batches[j*mul+2].append(line)
                batches=[np.asarray(x) for x in batches]
                yield batches
if __name__=='__main__':
    size=30000
    print(_get_embed_device(size))

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值