Dive Into MindSpore – vision.c_transforms.Decode For Data Processing
MindSpore Made Easy · In-Depth Series: Data Processing with vision.c_transforms.Decode
Development environment
- Ubuntu 20.04
- Python 3.8
- MindSpore 1.7.0
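What this article covers
- Start with the official documentation
- Two rounds of trial and error
- Dig into the official source code
- Verify the findings
- Walk through a practical example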
1. Start with the Official Documentation
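The only useful information the official documentation gives is the rgb parameter, and all it says about it is that its default value is True. The example only shows that Decode can be used inside a dataset's map method; the documentation does not say what the input should look like or what the output will be. Judging from the example, my intuition was that the input should be a file, so I started experimenting from there.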
2. Two Rounds of Trial and Error
2.1 Attempt 1
Is the input a file name?
Test code:
```python
from mindspore.dataset.vision.c_transforms import Decode

image_file = "/Users/kaierlong/Downloads/ms_demos/ms_gan/image/0000.jpg"
decode_op = Decode()
out = decode_op(image_file)
print("=== out: ===\n{}".format(out), flush=True)
```
As it turns out, intuition was not reliable. The call fails with:
Traceback (most recent call last):
File "/Users/kaierlong/test_decode.py", line 5, in <module>
out = decode_op(image_file)
File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/vision/c_transforms.py", line 594, in __call__
raise TypeError(
TypeError: Input should be an encoded image in 1-D NumPy format, got <class 'str'>.
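The message suggests that Decode validates its argument in Python before the C++ kernel is ever called. The sketch below is a hypothetical, simplified re-creation of such a check, written only to make the error concrete; check_decode_input is an illustrative name and this is not MindSpore's actual code. It assumes that a bytes object is viewed as a 1-D uint8 array and that anything other than a 1-D, non-string NumPy array is rejected with the message seen above.

```python
import numpy as np


def check_decode_input(img):
    """Hypothetical, simplified input check; NOT MindSpore's actual implementation."""
    if isinstance(img, bytes):
        # Assumption: raw bytes are accepted and viewed as a 1-D uint8 array.
        return np.frombuffer(img, dtype=np.uint8)
    if not isinstance(img, np.ndarray) or img.ndim != 1 or img.dtype.type is np.str_:
        raise TypeError(
            "Input should be an encoded image in 1-D NumPy format, got {}.".format(type(img)))
    return img


# A file name is a str, so it is rejected, just like in attempt 1.
try:
    check_decode_input("0000.jpg")
except TypeError as err:
    print(err)  # Input should be an encoded image in 1-D NumPy format, got <class 'str'>.

# A flattened pixel array has the right type and shape, so it gets past this check.
pixels = np.zeros((2, 2, 3), dtype=np.uint8).reshape(-1)
print(check_decode_input(pixels).shape)  # (12,)
```

The last case mirrors the next attempt: getting past a type check like this says nothing about whether the array actually holds a valid encoded image.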
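2.2 Attempt 2
Is the input a NumPy array?
The error message from attempt 1 says the input should be a 1-D NumPy array, so this attempt follows that hint: open the image with PIL, convert it to a NumPy array and flatten it to 1-D.
Test code: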
```python
import numpy as np
from PIL import Image
from mindspore.dataset.vision.c_transforms import Decode

image_file = "/Users/kaierlong/Downloads/ms_demos/ms_gan/image/0000.jpg"
# Load the image with PIL and flatten its pixel array to 1-D
image = np.array(Image.open(image_file)).reshape(-1)
decode_op = Decode()
out = decode_op(image)
print("=== out: ===\n{}".format(out), flush=True)
```
This time the error is:
Traceback (most recent call last):
File "/Users/kaierlong/test_decode.py", line 9, in <module>
out = decode_op(image)
File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/vision/c_transforms.py", line 596, in __call__
return super().__call__(img)
File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/vision/c_transforms.py", line 72, in __call__
return super().__call__(*input_tensor_list)
File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/transforms/c_transforms.py", line 43, in __call__
output_tensor_list = callable_op(tensor_row)
RuntimeError: Unexpected error. Decode: image decode failed.
Line of code : 236
File : /Users/jenkins/agent-working-dir/workspace/Compile_CPU_X86_MacOS_PY39/mindspore/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
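This error comes from deeper in the stack: RuntimeError: Unexpected error. Decode: image decode failed. The traceback shows that the input got past the Python type check (line 596, return super().__call__(img)) and failed inside the C++ decoding code, so it is time to study the official source code.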
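Before reading the source, a quick check already hints at the cause: np.array(Image.open(...)) holds pixel values that PIL has already decoded, and flattening them with reshape(-1) does not turn them back into a compressed image stream. A JPEG file begins with the start-of-image marker bytes FF D8, and a PNG file with the 8-byte signature 89 50 4E 47 0D 0A 1A 0A. The helper sniff_image_format below is a hypothetical illustration (not part of MindSpore) that checks only these two signatures.

```python
import numpy as np

JPEG_SOI = b"\xff\xd8"  # JPEG start-of-image marker
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"  # 8-byte PNG file signature


def sniff_image_format(data):
    """Return 'jpeg', 'png' or None from the leading bytes of a supposedly encoded image."""
    if isinstance(data, np.ndarray):
        # For a 1-D uint8 array this yields the original byte sequence.
        data = data.tobytes()
    if data.startswith(JPEG_SOI):
        return "jpeg"
    if data.startswith(PNG_SIGNATURE):
        return "png"
    return None


# The first two pixels of 0000.jpg as printed later in this article, flattened as in attempt 2.
flattened_pixels = np.array([0, 11, 17, 0, 2, 7], dtype=np.uint8)
print(sniff_image_format(flattened_pixels))     # None
print(sniff_image_format(b"\xff\xd8\xff\xe0"))  # jpeg
```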
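3. Digging into the Official Source Code
The MindSpore source code examined here is from the v1.7.0 branch.
The example in the official documentation applies Decode to an image_folder_dataset, so I decided to dig into that dataset's source code.
Source file locations: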
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc
At line 91 of image_folder_op, I found the code that loads each image:
```cpp
RETURN_IF_NOT_OK(Tensor::CreateFromFile(folder_path_ + (pair_ptr->first), &image));
```
This line calls a helper method, CreateFromFile, which is declared and implemented in:
mindspore/ccsrc/minddata/dataset/core/tensor.h
mindspore/ccsrc/minddata/dataset/core/tensor.cc
Its implementation is as follows:
```cpp
Status Tensor::CreateFromFile(const std::string &path, std::shared_ptr<Tensor> *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  Path file(path);
  if (file.IsDirectory()) {
    RETURN_STATUS_UNEXPECTED("Invalid file found: " + path + ", should be file, but got directory.");
  }
  std::ifstream fs;
  fs.open(path, std::ios::binary | std::ios::in);
  CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Failed to open file: " + path);
  int64_t num_bytes = fs.seekg(0, std::ios::end).tellg();
  CHECK_FAIL_RETURN_UNEXPECTED(num_bytes < kDeMaxDim, "Invalid file to allocate tensor memory, check path: " + path);
  CHECK_FAIL_RETURN_UNEXPECTED(fs.seekg(0, std::ios::beg).good(), "Failed to find size of file, check path: " + path);
  RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape{num_bytes}, DataType(DataType::DE_UINT8), out));
  int64_t written_bytes = fs.read(reinterpret_cast<char *>((*out)->GetMutableBuffer()), num_bytes).gcount();
  if (!(written_bytes == num_bytes && fs.good())) {
    fs.close();
    RETURN_STATUS_UNEXPECTED("Error in writing to tensor, check path: " + path);
  }
  fs.close();
  return Status::OK();
}
```
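The key steps are CreateEmpty, which allocates a 1-D DE_UINT8 tensor whose length equals the file size, and fs.read, which copies the file's bytes into that tensor unchanged. Nothing is decoded here. In other words, what Decode receives in the pipeline is the raw, still-encoded binary content of the image file. The next section verifies this.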
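For readers more at home in Python, the effect of CreateFromFile can be approximated with NumPy: read the file as raw bytes into a 1-D uint8 array whose length equals the file size, without decoding anything. read_file_as_uint8 is a hypothetical helper name, and this is only a rough analogue (it keeps the directory check but skips the size limit and read checks of the C++ code).

```python
import os

import numpy as np


def read_file_as_uint8(path):
    """Rough Python analogue of Tensor::CreateFromFile: raw file bytes -> 1-D uint8 array."""
    if os.path.isdir(path):
        raise ValueError("Invalid file found: {}, should be file, but got directory.".format(path))
    # np.fromfile copies the bytes as-is; nothing is decoded here.
    # The resulting length equals the file size, like TensorShape{num_bytes} in the C++ code.
    return np.fromfile(path, dtype=np.uint8)
```

Called on 0000.jpg, this would give a 1-D array as long as the file, and such an array should be accepted by Decode() in the same way as the bytes object used in the next section.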
4. Verifying the Findings
Test code:
```python
import codecs
from mindspore.dataset.vision.c_transforms import Decode

image_file = "/Users/kaierlong/Downloads/ms_demos/ms_gan/image/0000.jpg"
# Read the file's raw (encoded) bytes without decoding them
fp = codecs.open(image_file, "rb")
image_data = fp.read()
fp.close()
decode_op = Decode()
out = decode_op(image_data)
print("=== out: ===\n{}".format(out), flush=True)
```
Output:
=== out: ===
[[[ 0 11 17]
[ 0 2 7]
[ 13 11 14]
[ 11 1 2]
[ 24 9 6]
[ 43 25 21]
[ 30 16 15]
[ 34 29 26]
[ 38 42 43]
[ 38 48 49]]
[[ 1 10 15]
[ 0 1 4]
[ 27 23 24]
[ 34 23 21]
[ 26 7 3]
[ 27 8 2]
[ 29 11 7]
[ 43 33 31]
[ 39 41 38]
[ 40 49 48]]
[[ 6 10 11]
[ 3 2 0]
[ 14 4 2]
[ 33 16 9]
[ 58 33 26]
[ 44 19 12]
[ 19 0 0]
[ 39 24 17]
[ 48 45 40]
[ 53 55 50]]
[[ 15 11 8]
[ 39 32 26]
[ 42 25 18]
[ 94 71 63]
[188 158 148]
[162 130 119]
[ 69 41 30]
[ 54 34 23]
[ 91 82 75]
[ 99 96 89]]
[[ 13 4 0]
[ 28 15 7]
[ 48 25 17]
[115 85 74]
[198 160 147]
[195 157 144]
[125 91 79]
[ 70 44 31]
[ 75 61 52]
[ 83 74 65]]
[[ 52 39 31]
[ 40 23 15]
[ 94 67 58]
[160 126 114]
[164 125 110]
[179 137 123]
[178 140 127]
[ 98 68 57]
[ 88 70 60]
[ 94 81 72]]
[[ 38 23 18]
[ 35 16 9]
[ 96 67 59]
[172 138 128]
[170 130 118]
[164 122 110]
[159 121 110]
[ 83 53 43]
[103 84 77]
[106 93 85]]
[[ 31 17 14]
[ 30 12 8]
[ 28 0 0]
[ 95 62 53]
[158 120 111]
[152 112 104]
[135 98 90]
[110 81 75]
[ 99 81 77]
[ 98 87 83]]
[[ 13 4 5]
[ 31 20 18]
[ 88 67 64]
[120 91 87]
[123 88 84]
[130 93 87]
[159 128 125]
[216 192 190]
[180 168 168]
[166 160 160]]
[[ 52 46 48]
[113 103 104]
[179 159 158]
[176 148 145]
[166 132 130]
[187 152 150]
[206 176 174]
[227 206 205]
[223 212 216]
[215 210 214]]]
The result looks reasonable. To confirm that it is correct, compare it against PIL.
Comparison code:
```python
import numpy as np
from PIL import Image

image_file = "/Users/kaierlong/Downloads/ms_demos/ms_gan/image/0000.jpg"
image = np.array(Image.open(image_file))
print("=== PIL out: ===\n{}".format(image), flush=True)
```
Comparison output:
=== PIL out: ===
[[[ 0 11 17]
[ 0 2 7]
[ 13 11 14]
[ 11 1 2]
[ 24 9 6]
[ 43 25 21]
[ 30 16 15]
[ 34 29 26]
[ 38 42 43]
[ 38 48 49]]
[[ 1 10 15]
[ 0 1 4]
[ 27 23 24]
[ 34 23 21]
[ 26 7 3]
[ 27 8 2]
[ 29 11 7]
[ 43 33 31]
[ 39 41 38]
[ 40 49 48]]
[[ 6 10 11]
[ 3 2 0]
[ 14 4 2]
[ 33 16 9]
[ 58 33 26]
[ 44 19 12]
[ 19 0 0]
[ 39 24 17]
[ 48 45 40]
[ 53 55 50]]
[[ 15 11 8]
[ 39 32 26]
[ 42 25 18]
[ 94 71 63]
[188 158 148]
[162 130 119]
[ 69 41 30]
[ 54 34 23]
[ 91 82 75]
[ 99 96 89]]
[[ 13 4 0]
[ 28 15 7]
[ 48 25 17]
[115 85 74]
[198 160 147]
[195 157 144]
[125 91 79]
[ 70 44 31]
[ 75 61 52]
[ 83 74 65]]
[[ 52 39 31]
[ 40 23 15]
[ 94 67 58]
[160 126 114]
[164 125 110]
[179 137 123]
[178 140 127]
[ 98 68 57]
[ 88 70 60]
[ 94 81 72]]
[[ 38 23 18]
[ 35 16 9]
[ 96 67 59]
[172 138 128]
[170 130 118]
[164 122 110]
[159 121 110]
[ 83 53 43]
[103 84 77]
[106 93 85]]
[[ 31 17 14]
[ 30 12 8]
[ 28 0 0]
[ 95 62 53]
[158 120 111]
[152 112 104]
[135 98 90]
[110 81 75]
[ 99 81 77]
[ 98 87 83]]
[[ 13 4 5]
[ 31 20 18]
[ 88 67 64]
[120 91 87]
[123 88 84]
[130 93 87]
[159 128 125]
[216 192 190]
[180 168 168]
[166 160 160]]
[[ 52 46 48]
[113 103 104]
[179 159 158]
[176 148 145]
[166 132 130]
[187 152 150]
[206 176 174]
[227 206 205]
[223 212 216]
[215 210 214]]]
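The two outputs are identical, which confirms the source-code analysis: Decode takes the raw, encoded bytes of an image file and returns the decoded image as an HWC uint8 array (height x width x RGB channels).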
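Comparing printed arrays by eye only works for tiny images like this 10x10 one. A small helper such as the hypothetical compare_images below checks shape, dtype and pixel values programmatically; in practice you would pass it the Decode output and the PIL array from the two scripts above. Here it is exercised on small synthetic arrays so that it runs without MindSpore or an image file.

```python
import numpy as np


def compare_images(a, b):
    """Return (equal, message) describing whether two decoded images match."""
    a = np.asarray(a)
    b = np.asarray(b)
    if a.shape != b.shape:
        return False, "shape mismatch: {} vs {}".format(a.shape, b.shape)
    if a.dtype != b.dtype:
        return False, "dtype mismatch: {} vs {}".format(a.dtype, b.dtype)
    if not np.array_equal(a, b):
        # Cast to a signed type first so that uint8 subtraction cannot wrap around.
        diff = np.abs(a.astype(np.int64) - b.astype(np.int64))
        return False, "max abs pixel difference: {}".format(diff.max())
    return True, "identical"


ms_out = np.array([[[0, 11, 17], [0, 2, 7]]], dtype=np.uint8)
pil_out = ms_out.copy()
print(compare_images(ms_out, pil_out))  # (True, 'identical')
```

For the batched output in 5.2, the image tensor would first need converting and un-batching, e.g. sample[0].asnumpy()[0], before comparing it with the PIL array in 5.3.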
5. Practical Example
The following example uses MindRecord to generate a dataset and then read it back.
I prepared 5,000 images for this example; prepare your own.
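5.1 Generating the Data
The data generation code is shown below, where:
- data_dir is the directory that stores the images
- image_list_file is a file listing the image names, one per line
- mindrecord_dir is the directory where the MindRecord files are saved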
```python
import codecs
import os

from mindspore.mindrecord import FileWriter


def generate_dataset(data_dir, image_list_file, mindrecord_dir, num_train_shard=4, num_test_shard=2):
    # "image" stores the raw (still encoded) file bytes; "label" is an int32 class id.
    data_schema = {
        "image": {"type": "bytes"},
        "label": {"type": "int32"}
    }
    train_writer = FileWriter(
        file_name=os.path.join(mindrecord_dir, "train.mindrecord"), shard_num=num_train_shard)
    test_writer = FileWriter(
        file_name=os.path.join(mindrecord_dir, "test.mindrecord"), shard_num=num_test_shard)
    train_writer.add_schema(data_schema, "train")
    test_writer.add_schema(data_schema, "test")

    num_all_samples = 0
    num_train_samples = 0
    num_test_samples = 0
    # Buffer a number of samples and write them in one call to speed up writing.
    # With so few samples in total, the speedup is hardly noticeable here.
    train_tmp_samples = []
    test_tmp_samples = []

    with codecs.open(image_list_file, "r", "UTF8") as image_list_fp:
        for line in image_list_fp:
            line = line.strip()
            if not line:
                continue
            # Check that the image exists
            image_path = os.path.join(data_dir, line)
            if not os.path.exists(image_path):
                print("image: {} not exists!".format(line), flush=True)
                continue
            # Read the raw image bytes (they are not decoded here)
            image_fp = codecs.open(image_path, "rb")
            image_data = image_fp.read()
            image_fp.close()
            num_all_samples += 1
            # Fake label; a real project would use real labels
            label_data = num_all_samples % 10
            sample = {
                "image": image_data,
                "label": label_data,
            }
            # Split into training and test sets at a 4:1 ratio
            if num_all_samples % 5 == 0:
                test_tmp_samples.append(sample)
                num_test_samples += 1
                if num_test_samples % 10 == 0:
                    test_writer.write_raw_data(test_tmp_samples)
                    test_tmp_samples = []
            else:
                train_tmp_samples.append(sample)
                num_train_samples += 1
                if num_train_samples % 10 == 0:
                    train_writer.write_raw_data(train_tmp_samples)
                    train_tmp_samples = []

    # Flush whatever is left in the buffers
    if train_tmp_samples:
        train_writer.write_raw_data(train_tmp_samples)
    if test_tmp_samples:
        test_writer.write_raw_data(test_tmp_samples)

    train_writer.commit()
    test_writer.commit()

    print("====== number of all samples: {} ".format(num_all_samples), flush=True)
    print("====== number of train samples: {} ".format(num_train_samples), flush=True)
    print("====== number of test samples: {} ".format(num_test_samples), flush=True)
```
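generate_dataset expects image_list_file to exist already. If you only have a folder of images, such a file could be created with a few lines of standard-library Python. write_image_list and its extension filter are assumptions made for this sketch, not part of the original code; sorting the names keeps the line order, and therefore which images land in the test set, reproducible.

```python
import os


def write_image_list(data_dir, image_list_file, extensions=(".jpg", ".jpeg", ".png")):
    """Write one image file name per line, sorted, and return how many were written."""
    names = sorted(
        name for name in os.listdir(data_dir)
        if name.lower().endswith(extensions) and os.path.isfile(os.path.join(data_dir, name))
    )
    with open(image_list_file, "w", encoding="utf-8") as fp:
        for name in names:
            fp.write(name + "\n")
    return len(names)
```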
Testing data generation
Test code:
```python
data_dir = "/Users/kaierlong/Downloads/ms_demos/ms_gan/image/data"
image_list_file = "/Users/kaierlong/Downloads/ms_demos/ms_gan/image/list.txt"
mindrecord_dir = "/Users/kaierlong/Downloads/ms_demos/ms_gan/image/mindrecord"
generate_dataset(data_dir=data_dir, image_list_file=image_list_file, mindrecord_dir=mindrecord_dir)
```
The test prints:
====== number of all samples: 5000
====== number of train samples: 4000
====== number of test samples: 1000
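These counts follow directly from the num_all_samples % 5 == 0 rule: every fifth sample goes to the test set. The sketch below replays only the counting and buffering logic of generate_dataset (no images, no MindRecord) to show why 5,000 samples become 4,000 training and 1,000 test samples, and how many write_raw_data calls the buffer of 10 leads to. count_split is a hypothetical helper for illustration.

```python
def count_split(num_samples, test_every=5, buffer_size=10):
    """Replay generate_dataset's split and buffering; return per-split counts and write calls."""
    counts = {"train": 0, "test": 0}
    writes = {"train": 0, "test": 0}
    pending = {"train": 0, "test": 0}
    for index in range(1, num_samples + 1):  # num_all_samples starts counting at 1
        split = "test" if index % test_every == 0 else "train"
        counts[split] += 1
        pending[split] += 1
        if counts[split] % buffer_size == 0:
            writes[split] += 1
            pending[split] = 0
    for split in pending:  # flush leftovers, like the final "if ..._tmp_samples:" checks
        if pending[split]:
            writes[split] += 1
    return counts, writes


print(count_split(5000))  # ({'train': 4000, 'test': 1000}, {'train': 400, 'test': 100})
```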
Listing mindrecord_dir with the tree command shows:
mindrecord/
├── test.mindrecord0
├── test.mindrecord0.db
├── test.mindrecord1
├── test.mindrecord1.db
├── train.mindrecord0
├── train.mindrecord0.db
├── train.mindrecord1
├── train.mindrecord1.db
├── train.mindrecord2
├── train.mindrecord2.db
├── train.mindrecord3
└── train.mindrecord3.db
0 directories, 12 files
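With shard_num=4 for training and shard_num=2 for testing, each FileWriter produced one data file per shard, named file_name plus the shard index, together with a matching .db index file: 12 files in total. The hypothetical helpers below merely reproduce that naming pattern as observed in the listing (they are not a MindSpore API), which can be handy for checking that no shard is missing before reading the data back.

```python
import os


def expected_mindrecord_files(mindrecord_dir, prefix, shard_num):
    """Data files plus .db index files for one sharded set, following the naming seen above."""
    files = []
    for shard_id in range(shard_num):
        data_file = os.path.join(mindrecord_dir, "{}{}".format(prefix, shard_id))
        files.extend([data_file, data_file + ".db"])
    return files


def missing_mindrecord_files(mindrecord_dir, prefix, shard_num):
    """Return the expected files that are not present on disk."""
    return [path for path in expected_mindrecord_files(mindrecord_dir, prefix, shard_num)
            if not os.path.exists(path)]


# 4 train shards + 2 test shards -> 12 files, matching the tree listing.
names = (expected_mindrecord_files("mindrecord", "train.mindrecord", 4)
         + expected_mindrecord_files("mindrecord", "test.mindrecord", 2))
print(len(names))  # 12
```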
5.2 Reading the Data
The data reading code is shown below; this is where Decode comes in:
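```python
import os

from mindspore.dataset import MindDataset
from mindspore.dataset.vision.c_transforms import Decode


def create_dataset(mindrecord_dir, usage="train", batch_size=1, num_workers=2):
    if usage == "train":
        data_file_name = "train.mindrecord0"
        shuffle = True
    else:
        data_file_name = "test.mindrecord0"
        shuffle = False
    # A single file name is enough: MindDataset finds and loads the other shards
    # of the same MindRecord set in the same directory automatically.
    dataset_path = os.path.join(mindrecord_dir, data_file_name)
    dataset = MindDataset(
        dataset_path, columns_list=["image", "label"],
        num_parallel_workers=num_workers, shuffle=shuffle, num_shards=None, shard_id=None)
    # The "image" column holds raw encoded bytes; Decode turns each one into an HWC uint8 array.
    dataset = dataset.map(
        operations=Decode(), input_columns=["image"], output_columns=["image"], num_parallel_workers=num_workers)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset
```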
Testing data reading
Test code:
```python
mindrecord_dir = "/Users/kaierlong/Downloads/ms_demos/ms_gan/image/mindrecord"
dataset = create_dataset(mindrecord_dir=mindrecord_dir, usage="test")
data_size = dataset.get_dataset_size()
print("====== data size: {} ======".format(data_size), flush=True)
sample = None
for item in dataset:
    sample = item
    break
print("====== sample: ======\n{}".format(sample), flush=True)
```
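The test prints the following. The data size is 1000 because both test shards are loaded and batch_size is 1: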
====== data size: 1000 ======
====== sample: ======
[Tensor(shape=[1, 218, 178, 3], dtype=UInt8, value=
[[[[152, 197, 203],
[150, 197, 203],
[150, 198, 202],
...
[142, 202, 200],
[142, 202, 200],
[142, 202, 200]],
[[152, 197, 203],
[150, 197, 203],
[151, 199, 203],
...
[142, 202, 200],
[142, 202, 200],
[142, 202, 200]],
[[152, 197, 203],
[151, 198, 204],
[151, 199, 203],
...
[141, 203, 202],
[141, 203, 202],
[141, 203, 202]],
...
[[146, 200, 202],
[146, 200, 202],
[146, 200, 202],
...
[145, 199, 201],
[141, 192, 195],
[141, 192, 195]],
[[145, 199, 201],
[146, 200, 202],
[146, 200, 202],
...
[145, 199, 201],
[146, 197, 200],
[146, 197, 200]],
[[145, 199, 201],
[146, 200, 202],
[146, 200, 202],
...
[145, 199, 201],
[148, 199, 202],
[148, 199, 202]]]]), Tensor(shape=[1], dtype=Int32, value= [5])]
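5.3 Verifying the Read Result
Locate the image read in 5.2 and load it with PIL for comparison. If your image list is in the same order as mine, it should be the fifth image: the first test sample is written when num_all_samples reaches 5, so its fake label is 5 % 10 = 5, which matches the label [5] in the output above.
Test code (replace image_file with your actual path):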
```python
import numpy as np
from PIL import Image

image_file = "/Users/kaierlong/Downloads/ms_demos/ms_gan/image/data/000005.jpg"
image = Image.open(image_file)
image_data = np.asarray(image)
print("====== PIL output: ======\n{}".format(image_data), flush=True)
```
The test prints:
====== PIL output: ======
[[[152 197 203]
[150 197 203]
[150 198 202]
...
[142 202 200]
[142 202 200]
[142 202 200]]
[[152 197 203]
[150 197 203]
[151 199 203]
...
[142 202 200]
[142 202 200]
[142 202 200]]
[[152 197 203]
[151 198 204]
[151 199 203]
...
[141 203 202]
[141 203 202]
[141 203 202]]
...
[[146 200 202]
[146 200 202]
[146 200 202]
...
[145 199 201]
[141 192 195]
[141 192 195]]
[[145 199 201]
[146 200 202]
[146 200 202]
...
[145 199 201]
[146 197 200]
[146 197 200]]
[[145 199 201]
[146 200 202]
[146 200 202]
...
[145 199 201]
[148 199 202]
[148 199 202]]]
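The pixel values read in 5.2 match those loaded with PIL in 5.3; the only difference is the leading batch dimension of size 1 that batch() adds in 5.2.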
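Summary
This article examined the Decode operator in vision.c_transforms: it expects the raw, encoded bytes of an image file (a bytes object or a 1-D NumPy array of those bytes) and returns the decoded HWC uint8 image. A data generation and reading example with MindRecord then showed how to use it in practice.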
This is an original article; copyright belongs to the author. Do not reproduce it without permission.