【Python-Tensorflow】tf.data.Dataset的解析与使用

参考资料

1 作用

dataset = tf.data.Dataset…()

构建和处理数据集。包括三种类型的操作。

  • 根据输入数据创建源数据集。
  • 应用数据集转换以预处理数据。
  • 遍历数据集并处理元素。

2 tf.data.Dataset的函数

2.1 from_generator()

通过生成器去创建dataset,该函数的参数用于传生成器

# 定义生成器
def gen():
  ragged_tensor = tf.ragged.constant([[1, 2], [3]])
  yield 42, ragged_tensor
# 创建数据集
dataset = tf.data.Dataset.from_generator(
     gen,
     # 定义输出形状和输出类型
     output_signature=(
     	 # 定义输出形状
         tf.TensorSpec(shape=(), dtype=tf.int32),
         # 定义输出类型
         tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32)))

print(list(dataset.take(1)))

2.2 from_tensor_slices()

对给定张量进行切片
给定的张量沿其第一维被切片。此操作将保留输入张量的结构,删除每个张量的第一维并将其用作数据集维。所有输入张量的第一个维度必须具有相同的大小。

# Slicing a 1D tensor produces scalar tensor elements.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
list(dataset.as_numpy_iterator())
# Slicing a tuple of 1D tensors produces tuple elements containing
# scalar tensors.
dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
list(dataset.as_numpy_iterator())
[(1,3,5),(2,4,6)]
# Dictionary structure is also preserved.
dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
list(dataset.as_numpy_iterator()) == [{'a': 1, 'b': 3},
                                      {'a': 2, 'b': 4}]
True

2.3 from_tensors()

创建一个Dataset包含给定张量的单个元素的。
from_tensors产生仅包含单个元素的数据集。要将输入张量切成多个元素,请from_tensor_slices改用

dataset = tf.data.Dataset.from_tensors([1, 2, 3])
list(dataset.as_numpy_iterator())
[array([1,2,3],dtype=int32)]
dataset = tf.data.Dataset.from_tensors(([1, 2, 3], 'A'))
list(dataset.as_numpy_iterator())
[(array([1,2,3],dtype=int32),b'A')]

3 dataset 的函数

3.1 apply()

apply启用自定义Dataset转换的链接,这些转换表示为采用一个Dataset参数并返回transformd的函数Dataset。

dataset = tf.data.Dataset.range(100)
def dataset_fn(ds):
  return ds.filter(lambda x: x < 5)
dataset = dataset.apply(dataset_fn)
list(dataset.as_numpy_iterator())

3.2 as_numpy_iterator()

返回一个迭代器,该迭代器将数据集的所有元素转换为numpy。

使用as_numpy_iterator检查你的数据集的内容。要查看元素的形状和类型,请直接打印数据集元素,而不要使用 as_numpy_iterator。不建议使用

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset.as_numpy_iterator():
  print(element)

建议如下用法

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset:
  print(element)

3.3 batch()

将此数据集的连续元素合并为批

dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3)
list(dataset.as_numpy_iterator())

分成三批,分别为【1 2 3】【4 5 6】【7 8】


3.4 cache()

在此数据集中缓存元素。
第一次迭代数据集时,其元素将缓存在指定的文件或内存中。随后的迭代将使用缓存的数据。

dataset = tf.data.Dataset.range(5)
dataset = dataset.map(lambda x: x**2)
dataset = dataset.cache()
# The first time reading through the data will generate the data using
# `range` and `map`.
list(dataset.as_numpy_iterator())
[0,1,4,9,16]
# Subsequent iterations read from the cache.
list(dataset.as_numpy_iterator())
[0,1,4,9,16]

缓存到文件时,缓存的数据将在运行期间保持不变。即使是第一次遍历数据,也将从缓存文件中读取。.cache()直到删除缓存文件或更改文件名,在调用之前更改输入管道才有效。

dataset = tf.data.Dataset.range(5)
dataset = dataset.cache("/path/to/file")  # doctest: +SKIP
list(dataset.as_numpy_iterator())  # doctest: +SKIP
[0,1,2,3,4]
dataset = tf.data.Dataset.range(10)
dataset = dataset.cache("/path/to/file")  # Same file! # doctest: +SKIP
list(dataset.as_numpy_iterator())  # doctest: +SKIP
[0,1,2,3,4]

3.5 cardinality()

返回数据集的大小

  • 数量确定返回数字
  • 无限量,返回tf.data.INFINITE_CARDINALITY
  • 未知,返回tf.data.UNKNOWN_CARDINALITY
dataset = tf.data.Dataset.range(42)
print(dataset.cardinality().numpy())
42
dataset = dataset.repeat()
cardinality = dataset.cardinality()
print((cardinality == tf.data.INFINITE_CARDINALITY).numpy())
True
dataset = dataset.filter(lambda x: True)
cardinality = dataset.cardinality()
print((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy())
True

3.6 concatenate()

将给定数据集与此数据集连接来创建一个新的dataset

a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
b = tf.data.Dataset.range(4, 8)  # ==> [ 4, 5, 6, 7 ]
ds = a.concatenate(b)
list(ds.as_numpy_iterator())
[1,2,3,4,5,6,7]
# The input dataset and dataset to be concatenated should have the same
# nested structures and output types.
c = tf.data.Dataset.zip((a, b))
a.concatenate(c)
错误,a、c类型不同,c是tf.int64类型,a是int64类型
d = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
a.concatenate(d)
错误,a、d类型不同,a是int64类型,d是string类型

3.7 enumerate()

枚举此数据集的元素。
它类似于python的enumerate

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.enumerate(start=5)
for element in dataset.as_numpy_iterator():
  print(element)
(5,1)
(6,2)
(7,4)

3.8 filter()

根据自定义过滤函数去过滤此数据集

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.filter(lambda x: x < 3)
list(dataset.as_numpy_iterator())
[1,2]
# `tf.math.equal(x, y)` is required for equality comparison
def filter_fn(x):
  return tf.math.equal(x, 1)
dataset = dataset.filter(filter_fn)
list(dataset.as_numpy_iterator())
[1]

3.9 flat_map()

跨此数据集映射并展平结果。
使用flat_map,如果你想确保你的数据集保持不变的顺序。例如,要将批次的数据集展平为其元素的数据集:

dataset = tf.data.Dataset.from_tensor_slices(
               [[1, 2, 3], [4, 5, 6], [7, 8, 9]])
dataset = dataset.flat_map(lambda x: Dataset.from_tensor_slices(x))
list(dataset.as_numpy_iterator())
[1,2,3,4,5,6,7,8,9]

3.10 zip()

将给定的数据集压缩在一起来创建一个。
此方法的语义与zip()Python的内置函数相似,主要区别在于datasets 参数可以是Dataset对象的任意嵌套结构。

# The nested structure of the `datasets` argument determines the
# structure of elements in the resulting dataset.
a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
b = tf.data.Dataset.range(4, 7)  # ==> [ 4, 5, 6 ]
ds = tf.data.Dataset.zip((a, b))
list(ds.as_numpy_iterator())
[(1,4),(2,5),(3,6)]
ds = tf.data.Dataset.zip((b, a))
list(ds.as_numpy_iterator())
[(4,1),(5,2),(6,3)]

2.11 window()

window(size, shift=None, stride=1, drop_remainder=False)

将输入元素(嵌套)组合到窗口(嵌套)的数据集中。说白了就是按窗口大小划分数据集。
“窗口”是大小为平面元素的有限数据集size(如果没有足够的输入元素来填充窗口并drop_remainder计算为,则可能会更少 False)。
该shift参数确定窗口在每次迭代中移动的输入元素的数量。如果窗口和元素都从0开始编号,则窗口中的第一个元素k将是k * shift 输入数据集的元素。特别是,第一个窗口的第一个元素将始终是输入数据集的第一个元素。
所述stride参数确定输入元件的步幅,并且 shift参数确定窗口的移位。

dataset = tf.data.Dataset.range(7).window(2)
for window in dataset:
  print(list(window.as_numpy_iterator()))
[0,1]
[2,3]
[4,5]
[6]
dataset = tf.data.Dataset.range(7).window(3, 2, 1, True)
for window in dataset:
  print(list(window.as_numpy_iterator()))
[0,1,2]
[2,3,4]
[4,5,6]
dataset = tf.data.Dataset.range(7).window(3, 1, 2, True)
for window in dataset:
  print(list(window.as_numpy_iterator()))
[0,2,4]
[1,3,5]
[2,4,6]

请注意,将window转换应用于嵌套元素的数据集时,它将生成嵌套窗口的数据集。

nested = ([1, 2, 3, 4], [5, 6, 7, 8])
dataset = tf.data.Dataset.from_tensor_slices(nested).window(2)
for window in dataset:
  def to_numpy(ds):
    return list(ds.as_numpy_iterator())
  print(tuple(to_numpy(component) for component in window))
([1,2],[5,6])
([3,4],[7,8])
dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3, 4]})
dataset = dataset.window(2)
for window in dataset:
  def to_numpy(ds):
    return list(ds.as_numpy_iterator())
  print({'a': to_numpy(window['a'])})
  {'a':[1,2]}
  {'a':[3,4]}

3.12 unbatch()

将数据集的元素拆分为多个元素。
例如,如果数据集的元素是shape [B, a0, a1, …],其中B每个输入元素的位置可能有所不同,那么对于数据集中的每个元素,未批处理的数据集将包含Bshape的连续元素[a0, a1, …]。

elements = [ [1, 2, 3], [1, 2], [1, 2, 3, 4] ]
dataset = tf.data.Dataset.from_generator(lambda: elements, tf.int64)
dataset = dataset.unbatch()
list(dataset.as_numpy_iterator())
[1,2,3,1,2,1,2,3,4]

3.13 take()

从此数据集中Dataset最多创建一个count元素

dataset = tf.data.Dataset.range(10)
dataset = dataset.take(3)
list(dataset.as_numpy_iterator())
[0,1,2]

3.14 skip()

创建一个Dataset跳过count此数据集中的元素的。

dataset = tf.data.Dataset.range(10)
dataset = dataset.skip(7)
list(dataset.as_numpy_iterator())
python
[7,8,9]

3.15 shuffle()

shuffle(buffer_size, seed=None, reshuffle_each_iteration=None)

随机重新排列此数据集的元素。
该数据集用buffer_size元素填充缓冲区,然后从该缓冲区中随机采样元素,用新元素替换所选元素。为了实现完美的改组,需要缓冲区大小大于或等于数据集的完整大小。

例如,如果您的数据集包含10,000个元素但buffer_size设置为1,000,则shuffle最初将仅从缓冲区的前1,000个元素中选择一个随机元素。选择一个元素后,其缓冲区中的空间将被下一个(即1,001个)元素替换,并保留1,000个元素缓冲区。

reshuffle_each_iteration控制随机播放顺序对于每个时期是否应该不同。在TF 1.X中,创建历元的惯用方式是通过repeat转换:

dataset = tf.data.Dataset.range(3)
dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
dataset = dataset.repeat(2)  # doctest: +SKIP
[1,0,2,1,2,0]

3.16 shard()

shard( num_shards, index)

返回dataset指定索引开始,一定步长下的所有数据
num_shards步长,index索引

A = tf.data.Dataset.range(10)
B = A.shard(num_shards=3, index=0)
list(B.as_numpy_iterator())
[0,3,6,9]
C = A.shard(num_shards=3, index=1)
list(C.as_numpy_iterator())
[1,4,7]
D = A.shard(num_shards=3, index=2)
list(D.as_numpy_iterator())
[2,5,8]

3.17 repeat()

重复此数据集

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.repeat(3)
list(dataset.as_numpy_iterator())
[1,2,3,1,2,3,1,2,3]

3.18 reduce()

reduce( initial_state, reduce_func)

将输入数据集简化为单个元素。
转换将reduce_func依次调用输入数据集的每个元素,直到数据集用完为止,以其内部状态聚合信息。该initial_state参数用于初始状态,并返回最终状态作为结果。

tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1).numpy()
5
tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y).numpy()
10

3.19 prefetch()

创建一个Dataset从该数据集中预提取元素的。

大多数数据集输入管道应以调用结束prefetch。这允许在处理当前元素时准备以后的元素。这通常会提高延迟和吞吐量,但以使用额外的内存存储预取元素为代价。

dataset = tf.data.Dataset.range(3)
dataset = dataset.prefetch(2)
list(dataset.as_numpy_iterator())
[0,1,2]

3.20 map()

map(map_func, num_parallel_calls=None, deterministic=None)

此转换将应用于map_func此数据集的每个元素,并以与输入中出现的顺序相同的顺序返回包含转换后的元素的新数据集。map_func可用于更改值和数据集元素的结构。例如,向每个元素加1或投影元素组件的子集。

dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(lambda x: x + 1)
list(dataset.as_numpy_iterator())
[2,3,4,5,6]
dataset = tf.data.Dataset.range(3)
# `map_func` returns two `tf.Tensor` objects.
def g(x):
  return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
result = dataset.map(g)
result.element_spec
(TensorSpec(shape=(),dtype=tf.float32,name=None),
 TensorSpec(shape=(3,),dtype=tf.int32,name=None))

# `map_func` can return nested structures.
def i(x):
  return (37.0, [42, 16]), "foo"
result = dataset.map(i)
result.element_spec
(TensorSpec(shape=(),dtype=tf.float32,name=None),
 TensorSpec(shape=(2,),dtype=tf.int32,name=None),
 TensorSpec(shape=(),dtype=tf.string,name=None))

3.21 interleave()

interleave(
map_func, cycle_length=None, block_length=None, num_parallel_calls=None,
deterministic=None
)

map_func跨此数据集映射,并交织结果。
例如,您可以用来Dataset.interleave()同时处理许多输入文件:

  • cycle_length和block_length参数控制在其中的元件所产生的顺序。cycle_length控制并发处理的输入元素的数量。
  • 如果设置cycle_length为1,则此转换将一次处理一个输入元素,并将产生与相同的结果tf.data.Dataset.flat_map。
  • 一般来说,这种转换将适用map_func于cycle_length输入元件,开放迭代对返回的Dataset对象,并循环通过它们产生block_length从每个迭代连续元素,并且每个其到达一个迭代的结束时间消耗下一个输入元件。
dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
# NOTE: New lines indicate "block" boundaries.
dataset = dataset.interleave(
    lambda x: Dataset.from_tensors(x).repeat(6),
    cycle_length=2, block_length=4)
list(dataset.as_numpy_iterator())
[1,1,1,1,
2,2,2,2,
1,1,
2,2,
3,3,3,3,
4,4,4,4,
3,3,
4,4,
5,5,5,5,
5,5]
# -*- coding=utf-8 -*- import matplotlib.pyplot as plt import pydicom import pydicom.uid import sys import PIL.Image as Image # from PyQt5 import QtGui import os have_numpy = True try: import numpy except ImportError: have_numpy = False raise sys_is_little_endian = (sys.byteorder == 'little') NumpySupportedTransferSyntaxes = [ pydicom.uid.ExplicitVRLittleEndian, pydicom.uid.ImplicitVRLittleEndian, pydicom.uid.DeflatedExplicitVRLittleEndian, pydicom.uid.ExplicitVRBigEndian, ] # 支持的传输语法 def supports_transfer_syntax(dicom_dataset): """ Returns ------- bool True if this pixel data handler might support this transfer syntax. False to prevent any attempt to try to use this handler to decode the given transfer syntax """ return (dicom_dataset.file_meta.TransferSyntaxUID in NumpySupportedTransferSyntaxes) def needs_to_convert_to_RGB(dicom_dataset): return False def should_change_PhotometricInterpretation_to_RGB(dicom_dataset): return False # 加载Dicom图像数据 def get_pixeldata(dicom_dataset): """If NumPy is available, return an ndarray of the Pixel Data. Raises ------ TypeError If there is no Pixel Data or not a supported data type. ImportError If NumPy isn't found NotImplementedError if the transfer syntax is not supported AttributeError if the decoded amount of data does not match the expected amount Returns ------- numpy.ndarray The contents of the Pixel Data element (7FE0,0010) as an ndarray. """ if (dicom_dataset.file_meta.TransferSyntaxUID not in NumpySupportedTransferSyntaxes): raise NotImplementedError("Pixel Data is compressed in a " "format pydicom does not yet handle. " "Cannot return array. Pydicom might " "be able to convert the pixel data " 帮
最新发布
03-29
### 使用 Pydicom 和 Numpy 解码压缩像素数据 Pydicom 是一个用于读取、修改和写入 DICOM 文件的 Python 库。然而,默认情况下,它仅支持某些标准传输语法(如 JPEG 基线)。对于不被默认支持的传输语法(例如 RLE 或其他高级编码),可以借助外部库来完成解压操作。 以下是通过 `pydicom` 结合 `numpy` 及第三方工具(如 Pillow 或 GDCM)实现解码的过程: #### 安装依赖项 为了处理非支持的传输语法,可能需要安装额外的支持包: ```bash pip install pydicom numpy pillow gdcm ``` #### 示例代码:使用 Pydicom 和 Pillow 进行解码 如果目标文件采用的是常见的图像压缩格式(如 JPEG, JPEG-LS),可以通过 Pillow 来辅助解码。 ```python import pydicom from PIL import Image import numpy as np def decode_pixel_data(dicom_path): ds = pydicom.dcmread(dicom_path) # 如果像素数据未压缩,则直接返回 NumPy 数组 if hasattr(ds.file_meta, 'TransferSyntaxUID') and \ str(ds.file_meta.TransferSyntaxUID) != "1.2.840.10008.1.2": # 非显式 Little Endian pixel_bytes = ds.PixelData image = Image.frombuffer( mode='L', size=(ds.Columns, ds.Rows), data=pixel_bytes, decoder_name="raw" ) decoded_array = np.array(image.getdata()).reshape((ds.Rows, ds.Columns)) else: # 对于已知可由 Pillow 支持的压缩格式 try: decoded_array = ds.pixel_array # 尝试自动解码 except NotImplementedError: raise ValueError(f"Unsupported transfer syntax {ds.file_meta.TransferSyntaxUID}[^1]") return decoded_array # 调用函数并保存结果到 NumPy 数组 decoded_image = decode_pixel_data('path_to_dicom_file.dcm') print(decoded_image.shape) ``` 上述代码尝试加载 DICOM 数据,并判断其是否为受支持的传输语法。如果不支持,则抛出异常提示用户该语法无法解析[^1]。 #### 示例代码:使用 GDCM 处理复杂压缩 对于更复杂的压缩算法(如 RLE 编码或其他专有格式),推荐使用 GDCM (Grassroots DICOM),它可以作为后端插件集成至 Pydicom 中。 ```python import pydicom import numpy as np try: from pydicom.pixel_data_handlers.util import apply_modality_lut except ImportError: pass def decode_with_gdcm(dicom_path): ds = pydicom.dcmread(dicom_path) if not hasattr(pydicom.config, 'have_numpy'): raise RuntimeError("Numpy is required to handle pixel arrays.") # 判断是否需要调用 GDCM 后端 if getattr(ds.file_meta, 'TransferSyntaxUID', None) in [ "1.2.840.10008.1.2.5", # RLE Lossless "1.2.840.10008.1.2.4.90" # JPEG 2000 Lossless ]: try: handler_name = "gdcm" pydicom.config.image_handlers.append(handler_name) array = ds.pixel_array # 自动触发 GDCM 的解码逻辑 modality_corrected = apply_modality_lut(array, ds)[^2] finally: pydicom.config.image_handlers.remove(handler_name) else: array = ds.pixel_array return array resulting_array = decode_with_gdcm('complex_compressed.dcm') print(resulting_array.dtype) ``` 此方法利用了 GDCM 插件扩展的能力,能够覆盖更多类型的压缩语法[^2]。 ---
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Better Bench

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值