第六章 番外篇:webdataset

参考教程:
https://github.com/pytorch/pytorch/issues/38419
https://zhuanlan.zhihu.com/p/412772439
https://webdataset.github.io/webdataset/gettingstarted/


背景

训练数据通常是以个体的方式存储的,就像我们在第一章下载并处理成png格式后的cifar10数据,它以’xxx.png’的文件形式存放在一个一个独立的空间中。
随着数据集变得越来越大,这样的存放形式就不是那么高效和便捷。在进行模型训练时,也会因为数据的IO瓶颈拖慢训练的速度。
在使用Dataset中的数据时,我们的__getitem__(self, idx)函数会根据数据的index检索数据。在训练时,我们一般都会使用shuffle = True来完成数据的随机读取,这样索引的index也是无效的,当图片数据直接存放在系统上时,对文件的访问需要花费大量的代价。
这个问题可以使用sequential storage formats and sharding来解决。就像tensorflow中使用的TFRecord格式,它将训练集/测试集打包在一起使用,文件里存储的就是序列化的tf.Example。Pytorch是没有这种专属的数据存储格式的。

WebDataset

WebDataset提供了一种序列化存储大规模数据的方法,它将数据保存在tar包中,但是在使用时不需要对tar包进行解压。这种形式提供了高效的I/O,并且不管是在本地还是云端数据上都表现很不错。

webdataset的生成

webdataset是一个tar文件,所以你直接使用tar命令就可以进行文件的生成。

tar --sort=name -cf dataset.tar dataset/

我们也可以使用python调用webdataset的包,来进行文件的写入操作。
以下面的代码为例,下方的代码想要将现有的MNIST数据存放到’mnist.tar’文件中,因此它按照顺序将数据一个一个多写入了文件里。

dataset = torchvision.datasets.MNIST(root="./temp", download=True) # 获得MNIST数据
sink = wds.TarWriter("mnist.tar") # 使用TarWriter,准备将数据写入mnist.tar
for index, (input, output) in enumerate(dataset):
    if index%1000==0:
        print(f"{index:6d}", end="\r", flush=True, file=sys.stderr) # 每写入1000个数据,输出一些状态
    sink.write({
        "__key__": "sample%06d" % index, # 当前的数据的index
        "input.pyd": input, # 数据的input
        "output.pyd": output, # 数据的target
    })
sink.close() # 关闭当前文件。

这里的sink_write写入了是一个dict,其中’key’这一项决定了你想保存的数据的前缀名,’input.pyd’是你的input的数据的后缀,它同时也决定了你的数据存放的格式。
比如说这里使用的’pyd’,就是我们之前说过的pickle格式,它可以保证数据的完整性,以不压缩的形式存储数据,缺点是不能被其它的语言读取。
在你明确知道数据的类型的情况下,你也可以使用别的格式来存放数据,比如说对于图片,你可以使用‘ppm’,‘png’,'jpg’等格式,对于图片的标签,已知数据标签是整数的形式时,可以使用’cls’格式。

webdataset的加载

对于一个存入tar的webdataset的数据,你可以通过它的url对它进行读取,这个url可以是云端地址,也可以是本地路径。

import webdataset as wds
dataset = wds.WebDataset(url)

我们在讲数据存入tar时,writer根据我们定义的数据格式对数据进行了encode,所以我们直接读取到的数据是还没有decode的数据。
在教程中给了这样一个例子。
在这里插入图片描述
直接获取到的数据格式是bytes的格式。
你可以数据进行一些处理,webdataset提供一种链式的数据处理方法,比如上面的数据,你就可以使用下面的方法处理。

dataset = (
    wds.WebDataset(url)
    .shuffle(100)
    .decode("rgb")
    .to_tuple("jpg;png", "json")
)

这里的decode传入的’rgb’属于headler,webdataset提供了一些自带的imageheadler。帮助使用者进行数据类型转换。imagespecs = { "l8": ("numpy", "uint8", "l"), "rgb8": ("numpy", "uint8", "rgb"), "rgba8": ("numpy", "uint8", "rgba"), "l": ("numpy", "float", "l"), "rgb": ("numpy", "float", "rgb"), "rgba": ("numpy", "float", "rgba"), "torchl8": ("torch", "uint8", "l"), "torchrgb8": ("torch", "uint8", "rgb"), "torchrgba8": ("torch", "uint8", "rgba"), "torchl": ("torch", "float", "l"), "torchrgb": ("torch", "float", "rgb"), "torch": ("torch", "float", "rgb"), "torchrgba": ("torch", "float", "rgba"), "pill": ("pil", None, "l"), "pil": ("pil", None, "rgb"), "pilrgb": ("pil", None, "rgb"), "pilrgba": ("pil", None, "rgba"), }
webdataset提供了多种数据的decode方式的示例,你也可以自定义decode的方法。具体的源码可以查看https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py

decoders = {
    "txt": lambda data: data.decode("utf-8"),
    "text": lambda data: data.decode("utf-8"),
    "transcript": lambda data: data.decode("utf-8"),
    "cls": lambda data: int(data),
    "cls2": lambda data: int(data),
    "index": lambda data: int(data),
    "inx": lambda data: int(data),
    "id": lambda data: int(data),
    "json": lambda data: json.loads(data),
    "jsn": lambda data: json.loads(data),
    "pyd": lambda data: pickle.loads(data),
    "pickle": lambda data: pickle.loads(data),
    "pth": lambda data: torch_loads(data),
    "ten": tenbin_loads,
    "tb": tenbin_loads,
    "mp": msgpack_loads,
    "msg": msgpack_loads,
    "npy": npy_loads,
    "npz": lambda data: np.load(io.BytesIO(data)),
    "cbor": cbor_loads,
}

如果是想要自己定义decode的方法,可以使用以下类似的方法。以下的方法中定义了my_decoder方法,这方法会判断dataset中sample的key是否为jpg,如果不是则忽略,是的话才会返回结果。要注意这里直接获得的数据类型都是bytes,你可以使用类似于**imageio.imread(io.BytesIO(value))**处理数据,将它转为图片。

def my_decoder(key, value):
        if not key.endswith(".jpg"):
            return None
        assert isinstance(value, bytes)
        return value

dataset = wds.WebDataset(url).shuffle(1000).decode(my_decoder)

示例代码

最后给出一个简单的webdataset多进程存储的方法,这里使用的dataset中返回sample是dict形式,最后以pickle的形式存放到指定数量的tar中。

import multiprocessing as mp
import webdataset as wds
import pickle
import os

def write_samples(dataset, tar_index, sample_index,save_dir):
    for t_idx, s_idx in zip(tar_index, sample_index):
        fname = os.path.join(save_dir,str(t_idx)+'.tar')
        stream = wds.TarWriter(fname)
        for idx in s_idx:
            data = dataset[idx]
            sample = {}
            sample['__key__'] = "sample%06d" % idx
            for key, value in data.items():
                sample[key +'.pyd'] = value
            stream.write(sample)
        stream.close()

def dataset2tar(dataset, save_dir,num_tars, num_workers):
    num_len = len(dataset)
    data_index = [i for i in range(num_len)]
    samples = [data_index[i::num_tars] for i in range(num_tars)]
    tar_index = list(range(num_tars))
    jobs = []
    for i in range(num_workers):
        job = mp.Process(target = write_samples,args=(dataset,tar_index[i::num_workers],samples[i::num_workers],save_dir))
        job.start()
        jobs.append(job)
   
    for job in jobs:
        job.join()
    
def pyd_decoder(key, data):
    if not key.endswith(".pyd"):
        return None
    result = pickle.loads(data)
    return result
  • 6
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
好的,您想了解有关自定义几何体的内容,我很乐意为您解答。首先,让我们了解一下什么是几何体。 在Cesium中,几何体是由一些点、线和三角形组成的图形。几何体可以在地球上显示各种形状的物体,如建筑、飞机、汽车等。Cesium提供了一些内置的几何体,如BoxGeometry、CylinderGeometry、SphereGeometry等,但是有时候我们需要展示一些特殊形状的物体,这时候就需要自定义几何体了。 下面是一个简单的自定义几何体的例子: ```javascript var geometry = new Cesium.Geometry({ attributes: { position: new Cesium.GeometryAttribute({ componentDatatype: Cesium.ComponentDatatype.DOUBLE, componentsPerAttribute: 3, values: new Float64Array([ 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0 ]) }) }, indices: new Uint16Array([ 0, 1, 2, 0, 2, 3 ]), primitiveType: Cesium.PrimitiveType.TRIANGLES }); ``` 这个例子中,我们创建了一个由四个点组成的矩形,并用这些点的索引定义了两个三角形。这个几何体可以用来在地球上显示一个矩形。 接下来,让我们逐步了解这个例子中的代码。首先是Cesium.GeometryAttribute。 Cesium.GeometryAttribute是几何体属性的容器。在这个例子中,我们定义了一个名为position的属性,它有三个分量:x、y和z。这个属性使用的数据类型是Cesium.ComponentDatatype.DOUBLE,表示每个分量有一个双精度浮点数。componentsPerAttribute表示每个属性有几个分量。在这个例子中,每个属性都有三个分量。最后,我们用一个Float64Array数组来定义这个属性的值。 接下来是indices,它定义了几何体使用哪些点来组成三角形。在这个例子中,我们定义了两个三角形,每个三角形使用三个顶点。在indices数组中,我们用顶点在attributes数组中的索引来定义每个三角形。 最后,我们定义了几何体的primitiveType,它表示几何体的类型。在这个例子中,我们使用的是三角形类型,所以primitiveType为Cesium.PrimitiveType.TRIANGLES。 希望这个例子可以帮助您更好地理解自定义几何体的实现。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值