Dive Into MindSpore – vision.c_transforms.Decode For Data Processing

Development environment

  • Ubuntu 20.04
  • Python 3.8
  • MindSpore 1.7.0

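Outline

  • Start with the official documentation
  • Two failed attempts
  • Digging into the official source code
  • Verifying the findings
  • A practical example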

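1. Start with the official documentation

Decode

The only useful information in the official documentation is the rgb parameter, and even that is described only by its default value, True. The example shows that Decode can be passed to a dataset's map method, but the documentation does not state what the input or the output looks like. Judging from the example, the input would intuitively be a file, so that is where the experiments begin.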

2. Two failed attempts

2.1 Attempt one

Is the input a file name?

Test code:

from mindspore.dataset.vision.c_transforms import Decode

image_file = "/Users/kaierlong/Downloads/ms_demos/ms_gan/image/0000.jpg"
decode_op = Decode()
out = decode_op(image_file)

print("=== out: ===\n{}".format(out), flush=True)

As it turns out, intuition is unreliable. The call fails with:

Traceback (most recent call last):
  File "/Users/kaierlong/test_decode.py", line 5, in <module>
    out = decode_op(image_file)
  File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/vision/c_transforms.py", line 594, in __call__
    raise TypeError(
TypeError: Input should be an encoded image in 1-D NumPy format, got <class 'str'>.

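2.2 Attempt two

Is the input a NumPy array?

The error message from attempt one says the input should be 1-D NumPy data, so the next test passes the decoded pixels flattened to one dimension.

Test code: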

import numpy as np

from PIL import Image
from mindspore.dataset.vision.c_transforms import Decode

image_file = "/Users/kaierlong/Downloads/ms_demos/ms_gan/image/0000.jpg"
image = np.array(Image.open(image_file)).reshape(-1)
decode_op = Decode()
out = decode_op(image)

print("=== out: ===\n{}".format(out), flush=True)

The error:

Traceback (most recent call last):
  File "/Users/kaierlong/test_decode.py", line 9, in <module>
    out = decode_op(image)
  File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/vision/c_transforms.py", line 596, in __call__
    return super().__call__(img)
  File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/vision/c_transforms.py", line 72, in __call__
    return super().__call__(*input_tensor_list)
  File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/transforms/c_transforms.py", line 43, in __call__
    output_tensor_list = callable_op(tensor_row)
RuntimeError: Unexpected error. Decode: image decode failed.
Line of code : 236
File         : /Users/jenkins/agent-working-dir/workspace/Compile_CPU_X86_MacOS_PY39/mindspore/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc

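This error comes from deeper in the stack. The array passes the Python-side type check, but the C++ decoder rejects it with RuntimeError: Unexpected error. Decode: image decode failed. It is time to look at the official source code.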
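One way to see why flattened pixels cannot be decoded: an encoded image file starts with a format signature, while a pixel array is just raw channel values. The sketch below is not part of MindSpore; sniff_image_format is a hypothetical helper that only checks the well-known JPEG and PNG signatures on the first bytes.

```python
from typing import Optional

# Well-known file signatures ("magic numbers") of two common encoded formats.
JPEG_SIGNATURE = b"\xff\xd8\xff"
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def sniff_image_format(data: bytes) -> Optional[str]:
    """Return "jpeg" or "png" if data starts with that signature, else None."""
    if data.startswith(JPEG_SIGNATURE):
        return "jpeg"
    if data.startswith(PNG_SIGNATURE):
        return "png"
    return None


# The first bytes of a JPEG file carry the signature ...
encoded_head = b"\xff\xd8\xff\xe0\x00\x10JFIF"
# ... whereas flattened pixels are plain values (here the first two pixels
# printed in section 4).
flattened_pixels = bytes([0, 11, 17, 0, 2, 7])

print(sniff_image_format(encoded_head))      # jpeg
print(sniff_image_format(flattened_pixels))  # None
```

A decoder given the second kind of input has no header to parse, which is consistent with the "image decode failed" error above.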

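3. Digging into the official source code

The MindSpore source branch used here is v1.7.0.

The example in the official documentation operates on image_folder_dataset, so that is the source to dig into.

Source file locations: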

mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc

Line 91 of image_folder_op contains the code that loads the image:

RETURN_IF_NOT_OK(Tensor::CreateFromFile(folder_path_ + (pair_ptr->first), &image));

This line uses a helper method, CreateFromFile, which is declared and implemented in:

mindspore/ccsrc/minddata/dataset/core/tensor.h
mindspore/ccsrc/minddata/dataset/core/tensor.cc

Its implementation is as follows:

Status Tensor::CreateFromFile(const std::string &path, std::shared_ptr<Tensor> *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  Path file(path);
  if (file.IsDirectory()) {
    RETURN_STATUS_UNEXPECTED("Invalid file found: " + path + ", should be file, but got directory.");
  }
  std::ifstream fs;
  fs.open(path, std::ios::binary | std::ios::in);
  CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Failed to open file: " + path);
  int64_t num_bytes = fs.seekg(0, std::ios::end).tellg();
  CHECK_FAIL_RETURN_UNEXPECTED(num_bytes < kDeMaxDim, "Invalid file to allocate tensor memory, check path: " + path);
  CHECK_FAIL_RETURN_UNEXPECTED(fs.seekg(0, std::ios::beg).good(), "Failed to find size of file, check path: " + path);
  RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape{num_bytes}, DataType(DataType::DE_UINT8), out));
  int64_t written_bytes = fs.read(reinterpret_cast<char *>((*out)->GetMutableBuffer()), num_bytes).gcount();
  if (!(written_bytes == num_bytes && fs.good())) {
    fs.close();
    RETURN_STATUS_UNEXPECTED("Error in writing to tensor, check path: " + path);
  }
  fs.close();
  return Status::OK();
}

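The method reads the file's raw bytes into a 1-D uint8 tensor without decoding them, so the input to Decode is actually the binary (still encoded) content of the image file. The next section verifies this.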
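A rough Python counterpart of what CreateFromFile does can be sketched with NumPy. This is an illustration only: read_encoded_image is a hypothetical name, and the sketch skips the size and error checks of the C++ code.

```python
import numpy as np


def read_encoded_image(path: str) -> np.ndarray:
    """Read a file's raw bytes into a 1-D uint8 array, without decoding it."""
    with open(path, "rb") as fp:
        raw = fp.read()
    # frombuffer yields a read-only view of shape (num_bytes,), dtype uint8.
    return np.frombuffer(raw, dtype=np.uint8)
```

The result has one element per byte of the file, which is the kind of 1-D array the TypeError in 2.1 asks for.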

4. Verifying the findings

Test code:

from mindspore.dataset.vision.c_transforms import Decode

image_file = "/Users/kaierlong/Downloads/ms_demos/ms_gan/image/0000.jpg"
# Read the still-encoded file content as bytes.
with open(image_file, "rb") as fp:
    image_data = fp.read()
decode_op = Decode()
out = decode_op(image_data)

print("=== out: ===\n{}".format(out), flush=True)

Output:

=== out: ===
[[[  0  11  17]
  [  0   2   7]
  [ 13  11  14]
  [ 11   1   2]
  [ 24   9   6]
  [ 43  25  21]
  [ 30  16  15]
  [ 34  29  26]
  [ 38  42  43]
  [ 38  48  49]]

 [[  1  10  15]
  [  0   1   4]
  [ 27  23  24]
  [ 34  23  21]
  [ 26   7   3]
  [ 27   8   2]
  [ 29  11   7]
  [ 43  33  31]
  [ 39  41  38]
  [ 40  49  48]]

 [[  6  10  11]
  [  3   2   0]
  [ 14   4   2]
  [ 33  16   9]
  [ 58  33  26]
  [ 44  19  12]
  [ 19   0   0]
  [ 39  24  17]
  [ 48  45  40]
  [ 53  55  50]]

 [[ 15  11   8]
  [ 39  32  26]
  [ 42  25  18]
  [ 94  71  63]
  [188 158 148]
  [162 130 119]
  [ 69  41  30]
  [ 54  34  23]
  [ 91  82  75]
  [ 99  96  89]]

 [[ 13   4   0]
  [ 28  15   7]
  [ 48  25  17]
  [115  85  74]
  [198 160 147]
  [195 157 144]
  [125  91  79]
  [ 70  44  31]
  [ 75  61  52]
  [ 83  74  65]]

 [[ 52  39  31]
  [ 40  23  15]
  [ 94  67  58]
  [160 126 114]
  [164 125 110]
  [179 137 123]
  [178 140 127]
  [ 98  68  57]
  [ 88  70  60]
  [ 94  81  72]]

 [[ 38  23  18]
  [ 35  16   9]
  [ 96  67  59]
  [172 138 128]
  [170 130 118]
  [164 122 110]
  [159 121 110]
  [ 83  53  43]
  [103  84  77]
  [106  93  85]]

 [[ 31  17  14]
  [ 30  12   8]
  [ 28   0   0]
  [ 95  62  53]
  [158 120 111]
  [152 112 104]
  [135  98  90]
  [110  81  75]
  [ 99  81  77]
  [ 98  87  83]]

 [[ 13   4   5]
  [ 31  20  18]
  [ 88  67  64]
  [120  91  87]
  [123  88  84]
  [130  93  87]
  [159 128 125]
  [216 192 190]
  [180 168 168]
  [166 160 160]]

 [[ 52  46  48]
  [113 103 104]
  [179 159 158]
  [176 148 145]
  [166 132 130]
  [187 152 150]
  [206 176 174]
  [227 206 205]
  [223 212 216]
  [215 210 214]]]

The result looks plausible. To confirm that it is correct, compare it with the PIL library.

Comparison code:

import numpy as np

from PIL import Image

image_file = "/Users/kaierlong/Downloads/ms_demos/ms_gan/image/0000.jpg"
image = np.array(Image.open(image_file))

print("=== PIL out: ===\n{}".format(image), flush=True)

Comparison result:

=== PIL out: ===
[[[  0  11  17]
  [  0   2   7]
  [ 13  11  14]
  [ 11   1   2]
  [ 24   9   6]
  [ 43  25  21]
  [ 30  16  15]
  [ 34  29  26]
  [ 38  42  43]
  [ 38  48  49]]

 [[  1  10  15]
  [  0   1   4]
  [ 27  23  24]
  [ 34  23  21]
  [ 26   7   3]
  [ 27   8   2]
  [ 29  11   7]
  [ 43  33  31]
  [ 39  41  38]
  [ 40  49  48]]

 [[  6  10  11]
  [  3   2   0]
  [ 14   4   2]
  [ 33  16   9]
  [ 58  33  26]
  [ 44  19  12]
  [ 19   0   0]
  [ 39  24  17]
  [ 48  45  40]
  [ 53  55  50]]

 [[ 15  11   8]
  [ 39  32  26]
  [ 42  25  18]
  [ 94  71  63]
  [188 158 148]
  [162 130 119]
  [ 69  41  30]
  [ 54  34  23]
  [ 91  82  75]
  [ 99  96  89]]

 [[ 13   4   0]
  [ 28  15   7]
  [ 48  25  17]
  [115  85  74]
  [198 160 147]
  [195 157 144]
  [125  91  79]
  [ 70  44  31]
  [ 75  61  52]
  [ 83  74  65]]

 [[ 52  39  31]
  [ 40  23  15]
  [ 94  67  58]
  [160 126 114]
  [164 125 110]
  [179 137 123]
  [178 140 127]
  [ 98  68  57]
  [ 88  70  60]
  [ 94  81  72]]

 [[ 38  23  18]
  [ 35  16   9]
  [ 96  67  59]
  [172 138 128]
  [170 130 118]
  [164 122 110]
  [159 121 110]
  [ 83  53  43]
  [103  84  77]
  [106  93  85]]

 [[ 31  17  14]
  [ 30  12   8]
  [ 28   0   0]
  [ 95  62  53]
  [158 120 111]
  [152 112 104]
  [135  98  90]
  [110  81  75]
  [ 99  81  77]
  [ 98  87  83]]

 [[ 13   4   5]
  [ 31  20  18]
  [ 88  67  64]
  [120  91  87]
  [123  88  84]
  [130  93  87]
  [159 128 125]
  [216 192 190]
  [180 168 168]
  [166 160 160]]

 [[ 52  46  48]
  [113 103 104]
  [179 159 158]
  [176 148 145]
  [166 132 130]
  [187 152 150]
  [206 176 174]
  [227 206 205]
  [223 212 216]
  [215 210 214]]]

The two outputs are identical, which confirms the analysis of the official source code.
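Comparing two printed arrays by eye does not scale to full-size images. A small NumPy helper can do the check instead; describe_difference is a hypothetical name, and the arrays below are stand-ins rather than real decoder output.

```python
import numpy as np


def describe_difference(a: np.ndarray, b: np.ndarray) -> str:
    """Summarize whether two decoded images are identical."""
    if a.shape != b.shape:
        return "shape mismatch: {} vs {}".format(a.shape, b.shape)
    # Cast to a signed type first so subtracting uint8 values cannot wrap.
    max_diff = int(np.abs(a.astype(np.int16) - b.astype(np.int16)).max())
    if max_diff == 0:
        return "identical"
    return "max abs difference: {}".format(max_diff)


# Stand-in arrays; in practice pass the Decode output and the PIL array.
first = np.array([[[0, 11, 17], [0, 2, 7]]], dtype=np.uint8)
second = first.copy()
print(describe_difference(first, second))  # identical
second[0, 1, 2] = 10
print(describe_difference(first, second))  # max abs difference: 3
```

Reporting the maximum difference rather than a bare True/False is a deliberate choice: different JPEG decoder builds may not agree bit for bit, and the magnitude shows whether a mismatch is a rounding effect or a real error.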

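5. A practical example

The following example uses MindRecord to generate a dataset and read it back.

The author prepared 5,000 images for this example; readers can prepare their own.

5.1 Generating the data

In the data generation code below:

  • data_dir is the directory containing the image files
  • image_list_file is a text file listing the image names, one per line
  • mindrecord_dir is the directory where the MindRecord files are saved

import codecs
import os

from mindspore.mindrecord import FileWriter


def generate_dataset(data_dir, image_list_file, mindrecord_dir, num_train_shard=4, num_test_shard=2):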
    data_schema = {
        "image": {"type": "bytes"},
        "label": {"type": "int32"}
    }

    train_writer = FileWriter(
        file_name=os.path.join(mindrecord_dir, "train.mindrecord"), shard_num=num_train_shard)
    test_writer = FileWriter(
        file_name=os.path.join(mindrecord_dir, "test.mindrecord"), shard_num=num_test_shard)

    train_writer.add_schema(data_schema, "train")
    test_writer.add_schema(data_schema, "test")

    num_all_samples = 0
    num_train_samples = 0
    num_test_samples = 0

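    # Buffer a number of samples and write them in one call to speed up writing.
    # The total sample count here is small, so the speedup is not noticeable.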
    train_tmp_samples = []
    test_tmp_samples = []

    with codecs.open(image_list_file, "r", "UTF8") as image_list_fp:
        for line in image_list_fp:
            line = line.strip()
            if not line:
                continue

            # Skip images that do not exist
            image_path = os.path.join(data_dir, line)
            if not os.path.exists(image_path):
                print("image: {} not exists!".format(line), flush=True)
                continue

            # Read the raw (still encoded) image data
            image_fp = codecs.open(image_path, "rb")
            image_data = image_fp.read()
            image_fp.close()
            num_all_samples += 1
            # Fake label for illustration; a real project would use real labels
            label_data = num_all_samples % 10

            sample = {
                "image": image_data,
                "label": label_data,
            }

            # Split into training and test sets at a 4:1 ratio
            if num_all_samples % 5 == 0:
                test_tmp_samples.append(sample)
                num_test_samples += 1
                if num_test_samples % 10 == 0:
                    test_writer.write_raw_data(test_tmp_samples)
                    test_tmp_samples = []
            else:
                train_tmp_samples.append(sample)
                num_train_samples += 1
                if num_train_samples % 10 == 0:
                    train_writer.write_raw_data(train_tmp_samples)
                    train_tmp_samples = []

    if train_tmp_samples:
        train_writer.write_raw_data(train_tmp_samples)
    if test_tmp_samples:
        test_writer.write_raw_data(test_tmp_samples)
                    
    train_writer.commit()
    test_writer.commit()

    print("====== number of all samples: {} ".format(num_all_samples), flush=True)
    print("====== number of train samples: {} ".format(num_train_samples), flush=True)
    print("====== number of test samples: {} ".format(num_test_samples), flush=True)

Testing data generation

Test code:

data_dir = "/Users/kaierlong/Downloads/ms_demos/ms_gan/image/data"
image_list_file = "/Users/kaierlong/Downloads/ms_demos/ms_gan/image/list.txt"
mindrecord_dir = "/Users/kaierlong/Downloads/ms_demos/ms_gan/image/mindrecord"
generate_dataset(data_dir=data_dir, image_list_file=image_list_file, mindrecord_dir=mindrecord_dir)

The test prints:

====== number of all samples: 5000 
====== number of train samples: 4000 
====== number of test samples: 1000
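These counts follow directly from the modulo rule in generate_dataset. The pure-Python sketch below replays only that rule, with no file I/O; split_counts is a hypothetical helper.

```python
def split_counts(num_images: int) -> tuple:
    """Count train/test samples when every fifth image goes to the test set."""
    num_train = 0
    num_test = 0
    # Samples are numbered from 1, matching num_all_samples in generate_dataset.
    for index in range(1, num_images + 1):
        if index % 5 == 0:
            num_test += 1
        else:
            num_train += 1
    return num_train, num_test


print(split_counts(5000))  # (4000, 1000)
```

The same rule means the first sample routed to the test set is image 5, whose synthetic label is 5 % 10 = 5.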

Listing mindrecord_dir with the tree command shows:

mindrecord/
├── test.mindrecord0
├── test.mindrecord0.db
├── test.mindrecord1
├── test.mindrecord1.db
├── train.mindrecord0
├── train.mindrecord0.db
├── train.mindrecord1
├── train.mindrecord1.db
├── train.mindrecord2
├── train.mindrecord2.db
├── train.mindrecord3
└── train.mindrecord3.db

0 directories, 12 files
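The 12 files correspond to 4 training shards and 2 test shards, each with a data file and a .db index file. The sketch below reproduces the names seen in the listing. It assumes the suffix is simply the shard index, which is what the listing shows for fewer than 10 shards; naming for larger shard counts is not covered here. expected_shard_files is a hypothetical helper.

```python
def expected_shard_files(base_name: str, shard_num: int) -> list:
    """List data and index file names as they appear in the listing above.

    Assumption: the numeric suffix is the shard index (observed for < 10 shards).
    """
    names = []
    for shard in range(shard_num):
        names.append("{}{}".format(base_name, shard))
        names.append("{}{}.db".format(base_name, shard))
    return names


files = expected_shard_files("test.mindrecord", 2) + expected_shard_files("train.mindrecord", 4)
print(len(files))  # 12
```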

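5.2 Reading the data

The reading code is shown below; this is where Decode is used:

import os

from mindspore.dataset import MindDataset
from mindspore.dataset.vision.c_transforms import Decode


def create_dataset(mindrecord_dir, usage="train", batch_size=1, num_workers=2):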
    if usage == "train":
        data_file_name = "train.mindrecord0"
        shuffle = True
    else:
        data_file_name = "test.mindrecord0"
        shuffle = False

    dataset_path = os.path.join(mindrecord_dir, data_file_name)

    dataset = MindDataset(
        dataset_path, columns_list=["image", "label"],
        num_parallel_workers=num_workers, shuffle=shuffle, num_shards=None, shard_id=None)

    dataset = dataset.map(
        operations=Decode(), input_columns=["image"], output_columns=["image"], num_parallel_workers=num_workers)

    dataset = dataset.batch(batch_size, drop_remainder=True)

    return dataset

Testing data reading

Test code:

mindrecord_dir = "/Users/kaierlong/Downloads/ms_demos/ms_gan/image/mindrecord"
dataset = create_dataset(mindrecord_dir=mindrecord_dir, usage="test")
data_size = dataset.get_dataset_size()
print("====== data size: {} ======".format(data_size), flush=True)

sample = None
for item in dataset:
    sample = item
    break
print("====== sample: ======\n{}".format(sample), flush=True)

The test prints:

====== data size: 1000 ======
====== sample: ======
[Tensor(shape=[1, 218, 178, 3], dtype=UInt8, value=
[[[[152, 197, 203],
   [150, 197, 203],
   [150, 198, 202],
   ...
   [142, 202, 200],
   [142, 202, 200],
   [142, 202, 200]],
  [[152, 197, 203],
   [150, 197, 203],
   [151, 199, 203],
   ...
   [142, 202, 200],
   [142, 202, 200],
   [142, 202, 200]],
  [[152, 197, 203],
   [151, 198, 204],
   [151, 199, 203],
   ...
   [141, 203, 202],
   [141, 203, 202],
   [141, 203, 202]],
  ...
  [[146, 200, 202],
   [146, 200, 202],
   [146, 200, 202],
   ...
   [145, 199, 201],
   [141, 192, 195],
   [141, 192, 195]],
  [[145, 199, 201],
   [146, 200, 202],
   [146, 200, 202],
   ...
   [145, 199, 201],
   [146, 197, 200],
   [146, 197, 200]],
  [[145, 199, 201],
   [146, 200, 202],
   [146, 200, 202],
   ...
   [145, 199, 201],
   [148, 199, 202],
   [148, 199, 202]]]]), Tensor(shape=[1], dtype=Int32, value= [5])]
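The image tensor's shape [1, 218, 178, 3] reads as (batch, height, width, channels): Decode produces a height × width × channel array, and batch adds a leading axis. The NumPy sketch below uses a zero-filled stand-in array, not real MindSpore output, to show how the batch axis appears and how to drop it before comparing with a single image.

```python
import numpy as np

# Stand-in for one decoded image: height 218, width 178, 3 channels (HWC).
decoded = np.zeros((218, 178, 3), dtype=np.uint8)

# Batching with batch_size=1 stacks samples along a new leading axis.
batched = np.stack([decoded])
print(batched.shape)  # (1, 218, 178, 3)

# Indexing the batch axis recovers the single image.
single = batched[0]
print(single.shape)  # (218, 178, 3)
```

With the real sample, the image tensor can first be converted with its asnumpy() method and then indexed the same way.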

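5.3 Verifying the result

Locate the image read in 5.2 (if your file order matches, it should be the fifth image) and read it with PIL.

Test code:

Note: replace image_file in the code with your own path.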

import numpy as np
from PIL import Image

image_file = "/Users/kaierlong/Downloads/ms_demos/ms_gan/image/data/000005.jpg"
image = Image.open(image_file)
image_data = np.asarray(image)
print("====== PIL output: ======\n{}".format(image_data), flush=True)

The test prints:

====== PIL output: ======
[[[152 197 203]
  [150 197 203]
  [150 198 202]
  ...
  [142 202 200]
  [142 202 200]
  [142 202 200]]

 [[152 197 203]
  [150 197 203]
  [151 199 203]
  ...
  [142 202 200]
  [142 202 200]
  [142 202 200]]

 [[152 197 203]
  [151 198 204]
  [151 199 203]
  ...
  [141 203 202]
  [141 203 202]
  [141 203 202]]

 ...

 [[146 200 202]
  [146 200 202]
  [146 200 202]
  ...
  [145 199 201]
  [141 192 195]
  [141 192 195]]

 [[145 199 201]
  [146 200 202]
  [146 200 202]
  ...
  [145 199 201]
  [146 197 200]
  [146 197 200]]

 [[145 199 201]
  [146 200 202]
  [146 200 202]
  ...
  [145 199 201]
  [148 199 202]
  [148 199 202]]]

The results read in 5.2 and 5.3 are identical.

Summary

This article examined the Decode operator under vision.c_transforms and used a data generation and reading example to show how to use it.

This is an original article. All rights are reserved by the author; reproduction without permission is prohibited.
