I've recently been using tf.data.Dataset.map(data_map). I want data_map to blur each image it reads and then concatenate the blurred image with the sharp one along the channel axis to build a new dataset, but at runtime I get the following error:
tensorflow.python.framework.errors_impl.UnimplementedError: 2 root error(s) found.
(0) Unimplemented: 2 root error(s) found.
(0) Unimplemented: {{function_node __inference_Dataset_map_Training_data.data_map_101}} The Conv2D op currently only supports the NHWC tensor format on the CPU. The op was given the format: NCHW
[[{{node Conv2D_2}}]]
[[concat/_16]]
(1) Unimplemented: {{function_node __inference_Dataset_map_Training_data.data_map_101}} The Conv2D op currently only supports the NHWC tensor format on the CPU. The op was given the format: NCHW
[[{{node Conv2D_2}}]]
0 successful operations.
0 derived errors ignored.
[[training_data/IteratorGetNext_1]]
[[training_data/IteratorGetNext_1/_1]]
(1) Unimplemented: 2 root error(s) found.
(0) Unimplemented: {{function_node __inference_Dataset_map_Training_data.data_map_101}} The Conv2D op currently only supports the NHWC tensor format on the CPU. The op was given the format: NCHW
[[{{node Conv2D_2}}]]
[[concat/_16]]
(1) Unimplemented: {{function_node __inference_Dataset_map_Training_data.data_map_101}} The Conv2D op currently only supports the NHWC tensor format on the CPU. The op was given the format: NCHW
[[{{node Conv2D_2}}]]
0 successful operations.
0 derived errors ignored.
[[training_data/IteratorGetNext_1]]
0 successful operations.
0 derived errors ignored.
Dataset.map(data_map) seems to run on the CPU: if I put data_map in a standalone .py file by itself and run it on the GPU, it works perfectly, but as soon as Dataset.map() calls data_map the program throws the error above.
The error says my input tensor is in NCHW format and that the CPU only supports NHWC. So I traced into rgb = tf.image.decode_jpeg(tf.read_file(img_path), channels=3) and found that rgb is a tensor of shape (?, ?, 3). I then used tf.expand_dims(data, 0) to add a batch dimension (otherwise the convolution can't run), giving (1, ?, ?, 3). Is it treating the dimension I added as the channel dimension?
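For context, the blur itself is just a uniform box filter. Here is a NumPy sketch of what data_map is meant to compute per image (my own re-implementation for sanity-checking, assuming stride 1 and zero-padded 'SAME' borders; box_blur is a name I made up):

```python
import numpy as np

def box_blur(img, blur=10):
    """Uniform blur x blur box filter over an (H, W, C) image,
    zero-padded so the output keeps the input's spatial size
    (mirroring conv2d's 'SAME' padding at stride 1)."""
    h, w, c = img.shape
    pad_top = (blur - 1) // 2        # 'SAME' splits the blur-1 pad rows,
    pad_bot = blur - 1 - pad_top     # putting the extra row at the bottom
    padded = np.pad(img, ((pad_top, pad_bot), (pad_top, pad_bot), (0, 0)))
    out = np.empty_like(img, dtype=np.float64)
    for i in range(h):
        for j in range(w):
            # mean over the window == window sum * 1/blur**2 (uniform kernel)
            out[i, j] = padded[i:i + blur, j:j + blur].mean(axis=(0, 1))
    return out
```

On a constant image the interior stays constant while the borders darken, because the zero padding is averaged into the window.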
My code is below:
# training_data.py — input images are (64, 64, 3)
import numpy as np
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(data_path_list)
dataset = dataset.map(self.data_map, num_parallel_calls=8)
training_dataset = dataset.batch(n_batch).prefetch(64)
training_iterator = training_dataset.make_one_shot_iterator()
training_batch = training_iterator.get_next()

def data_map(self, img_path):
    rgb = tf.image.decode_jpeg(tf.read_file(img_path), channels=3)  # read the image
    # blur the image
    data = tf.cast(rgb, tf.float32)
    data = tf.expand_dims(data, 0)  # add a batch dimension
    blur = 10  # blur factor: controls how strong the blur is
    filter_ = np.array([[1 / (blur ** 2)] * (blur ** 2)]).reshape(blur, blur, 1, 1)  # uniform blur kernel
    temp = []
    data_list = tf.split(data, 3, 3)  # split the image into single-channel images along the channel axis
    for data in data_list:  # blur each channel separately
        rgb_blur = tf.nn.conv2d(data, filter_, [1, 1, 1, 1], 'SAME')
        temp.append(rgb_blur)
    rgb_blur = tf.concat(temp, 3)  # re-join the blurred single-channel images along the channel axis
    rgb_blur = tf.squeeze(rgb_blur, [0])  # drop the batch dimension
    rgb = tf.concat([rgb_blur, rgb], 2)  # concatenate the blurred and sharp images along the channel axis
    return rgb
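As an aside, the per-channel split/conv/concat loop could be collapsed into a single tf.nn.depthwise_conv2d call, which applies one 2-D kernel to each input channel independently. This is just a sketch of that variant (blur_nhwc is a name I made up; it mirrors my kernel and 'SAME' padding above, and is not itself a fix for the NCHW error):

```python
import numpy as np
import tensorflow as tf

blur = 10
# One uniform blur x blur kernel per channel; depthwise_conv2d expects
# shape [filter_height, filter_width, in_channels, channel_multiplier].
kernel = np.full((blur, blur, 3, 1), 1.0 / blur ** 2, dtype=np.float32)

def blur_nhwc(data):
    """data: a float32 NHWC tensor (N, H, W, 3); returns the blurred batch."""
    return tf.nn.depthwise_conv2d(data, kernel, [1, 1, 1, 1], 'SAME')
```

Interior pixels of a constant image are unchanged, since the uniform kernel sums to 1 per channel.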
Does anyone know where my problem is and how to fix it? Many thanks!