Python dict格式的Dataset类定义

前言

今天看到了一个简单的Dataset类定义的函数,包括了打乱顺序,取一个batch,取多个batch,取一个batch的前n个数据以及迭代样本的所有batch功能。下将代码粘贴如下以备后续借鉴。

代码

import numpy as np

class Dataset(object):
    def __init__(self, data_map, deterministic=False, shuffle=True):
        # data_map是一个字典,key是数据的名字,value是数据本身,比如data_map = {"ob": ob, "ac": ac, "atarg": atarg, "vtarg": tdlamret}
        self.data_map = data_map
        # 这个参数是指是否是确定性的,如果是确定性的,那么每次调用next_batch的时候,都会返回相同的数据,否则每次都会返回不同的数据
        self.deterministic = deterministic
        # 这个参数是指是否打乱顺序,如果是True,那么每次调用next_batch的时候,都会打乱顺序
        self.enable_shuffle = shuffle
        # 这个参数是指数据的总数
        self.n = next(iter(data_map.values())).shape[0]
        # 这个参数是指当前的数据的id
        self._next_id = 0
        self.shuffle()

    def shuffle(self):
        # shuffle的意思是打乱顺序,但是这里的shuffle是打乱顺序的同时,还要保证每个key的数据是一一对应的
        if self.deterministic:
            return
        perm = np.arange(self.n)
        np.random.shuffle(perm)

        for key in self.data_map:
            # self.data_map[key][perm]是指把self.data_map[key]的数据按照perm的顺序打乱
            self.data_map[key] = self.data_map[key][perm]

        self._next_id = 0

    def next_batch(self, batch_size):
        # 这个函数的作用是返回一个batch的数据,比如batch_size=32,那么就返回32个数据
        if self._next_id >= self.n and self.enable_shuffle:
            # 如果当前的数据已经用完了,那么就重新打乱顺序
            self.shuffle()

        cur_id = self._next_id
        cur_batch_size = min(batch_size, self.n - self._next_id)
        self._next_id += cur_batch_size

        # 这里的data_map是指返回的数据,比如data_map = {"ob": ob, "ac": ac, "atarg": atarg, "vtarg": tdlamret}
        data_map = dict()
        for key in self.data_map:
            data_map[key] = self.data_map[key][cur_id:cur_id+cur_batch_size]
        return data_map

    def iterate_once(self, batch_size):
        # 这个函数的作用是返回一个batch的数据,比如batch_size=32,那么就返回32个数据
        if self.enable_shuffle: self.shuffle()

        while self._next_id <= self.n - batch_size:
            # yield的作用是返回一个值,但是不会结束函数,而是会继续执行函数
            # self.next_batch(batch_size)是指返回一个batch的数据
            yield self.next_batch(batch_size)
        self._next_id = 0

    def iterate_times(self, batch_size, times):
        if self.enable_shuffle: self.shuffle()

        for x in range(times):
            yield self.next_batch(batch_size)
        self._next_id = 0

    def subset(self, num_elements, deterministic=True):
        data_map = dict()
        for key in self.data_map:
            # self.data_map[key][:num_elements]是指取self.data_map[key]的前num_elements个数据
            data_map[key] = self.data_map[key][:num_elements]
        return Dataset(data_map, deterministic)


def iterbatches(arrays, *, num_batches=None, batch_size=None, shuffle=True, include_final_partial_batch=True):
    '''
    Iterate over batches of arrays, optionally shuffling and/or in parallel.
    迭代样本的所有batch
    iterbatches(batch_size=100, epochs=10, deterministic=False)将在10次迭代完所有的样本,每次有不同的随机顺序。
    '''
    assert (num_batches is None) != (batch_size is None), 'Provide num_batches or batch_size, but not both'
    # map的作用是对arrays中的每一个元素都执行np.asarray,然后返回一个map对象,tuple的作用是把map对象转换成元组
    arrays = tuple(map(np.asarray, arrays))
    # arrays[0].shape[0]是指取arrays[0]的第一个维度的长度,比如arrays[0]的shape是(100, 10),那么arrays[0].shape[0]就是100
    n = arrays[0].shape[0]
    # assert的作用是判断后面的表达式是否为True,如果为False,那么就会报错,
    # all的作用是判断后面的表达式是否都为True,all(a.shape[0] == n for a in arrays[1:])是指判断arrays[1:]中的每一个元素的第一个维度的长度是否都等于n
    assert all(a.shape[0] == n for a in arrays[1:])
    inds = np.arange(n)
    # np.random.shuffle(inds)是指把inds的顺序打乱
    if shuffle: np.random.shuffle(inds)
    # np.arange(0, n, batch_size)是指生成一个从0到n,步长为batch_size的数组,比如n=100,batch_size=32,那么就会生成一个从0到100,步长为32的数组
    sections = np.arange(0, n, batch_size)[1:] if num_batches is None else num_batches
    # np.array_split(inds, sections)是指把inds按照sections的顺序分割成多个数组,比如inds=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],sections=[3, 6],
    # 那么就会把inds分割成[[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
    for batch_inds in np.array_split(inds, sections):
        # include_final_partial_batch是指是否包含最后一个不完整的batch,如果为True,那么就会包含最后一个不完整的batch
        if include_final_partial_batch or len(batch_inds) == batch_size:
            yield tuple(a[batch_inds] for a in arrays)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值