tf.extract_image_patches是tensorflow用来从一张图像中提取多个patches的,其实现抽取patches的方式请移步一篇知乎的文章,那里介绍得很清楚。
最近尝试在pytorch中来实现tf.extract_image_patches的功能,具体代码如下:
一、tensorflow中tf.extract_image_patches从Tensor中提取patches
1、提取单通道张量
import tensorflow as tf

# Build two single-channel 6x6 tensors in NHWC layout: (N, H, W, C) -> (2, 6, 6, 1).
img_tf = tf.range(start=0, limit=72, dtype="float32", name="img")
img_tf = tf.reshape(img_tf, (2, 6, 6, 1))
# Extract 3x3 patches with SAME padding.
# img_a has shape (2, 6, 6, 9), where 9 = C * ksizes * ksizes.
img_a = tf.extract_image_patches(images=img_tf, ksizes=[1, 3, 3, 1],
                                 strides=[1, 1, 1, 1], rates=[1, 1, 1, 1],
                                 padding='SAME')
# PyTorch tensors use NCHW; uncomment to transpose img_a into that layout
# for a more direct comparison with the pytorch implementation below.
# img_b = tf.transpose(img_a, perm=[0, 3, 1, 2])
sess = tf.Session()
with sess.as_default():
    print("img:", img_tf.shape, '\n', img_tf.eval())
    print("\nimg_a:", img_a.shape, '\n', img_a.eval())
    # print("\nimg_b:", img_b.shape, '\n', img_b.eval())
其输出结果如下图所示(由于太长,仅展示两个Tensor中的第一个):
2、提取多通道张量,这里以4通道为例
# Build two 4-channel 3x3 tensors in NHWC layout: (2, 3, 3, 4).
img_tf = tf.range(start=0, limit=72, dtype="float32", name="img")
img_tf = tf.reshape(img_tf, (2, 3, 3, 4))  # fixed: closing parenthesis was missing
# img_a has shape (2, 3, 3, 36), where 36 = C * ksizes * ksizes.
img_a = tf.extract_image_patches(images=img_tf, ksizes=[1, 3, 3, 1],
                                 strides=[1, 1, 1, 1], rates=[1, 1, 1, 1],
                                 padding='SAME')
# Uncomment to transpose into NCHW for comparison with pytorch.
# img_b = tf.transpose(img_a, perm=[0, 3, 1, 2])
sess = tf.Session()
with sess.as_default():
    print("img:", img_tf.shape, '\n', img_tf.eval())
    print("\nimg_a:", img_a.shape, '\n', img_a.eval())
    # print("\nimg_b:", img_b.shape, '\n', img_b.eval())
其输出结果如下图所示:
二、在Pytorch中实现tf.extract_image_patches的功能
1、提取单通道张量
import torch
import torch.nn as nn
import torch.nn.functional as F
# Single-channel patch extraction.
def extract_patche_onechannel(img, kernel_n, stride=1):
    """Mimic tf.extract_image_patches for the first channel of a NCHW tensor.

    Args:
        img: tensor of shape (N, C, H, W); only channel 0 is used,
            matching the original implementation's img[:, :1, ...] slice.
        kernel_n: patch size (assumed odd so SAME-style padding is symmetric).
        stride: step between neighbouring patches.

    Returns:
        Tensor of shape (N, kernel_n * kernel_n, H_out, W_out). Output
        channel i * kernel_n + j holds the input pixel at kernel offset
        (i, j) — the same ordering the original produced by stacking
        kernel_n * kernel_n one-hot conv2d kernels.
    """
    N, C, H, W = img.shape
    pad = (kernel_n - 1) // 2  # SAME padding for odd kernel sizes
    out_h = (H + 2 * pad - kernel_n) // stride + 1
    out_w = (W + 2 * pad - kernel_n) // stride + 1
    # F.unfold gathers every kernel_n x kernel_n patch in one call, replacing
    # the original trick of convolving with k*k one-hot kernels; unfold
    # enumerates kernel offsets row-major, the same order the one-hot kernels
    # were stacked in, so the outputs are identical.
    cols = F.unfold(img[:, :1, ...], kernel_size=kernel_n, padding=pad, stride=stride)
    # (N, k*k, L) -> (N, k*k, H_out, W_out); L is row-major over the output grid.
    return cols.view(N, kernel_n * kernel_n, out_h, out_w)
# torch.range is deprecated (and its `end` is inclusive); torch.arange with an
# explicit float dtype is the modern equivalent and yields the same 72 values.
img_torch = torch.arange(0, 72, dtype=torch.float32).reshape(2, 1, 6, 6)
# Output size: (2, 9, 6, 6), where 9 = C * kernel_n * kernel_n
# (the original comment claimed (2, 9, 3, 3), but the input is 6x6).
out_torch = extract_patche_onechannel(img_torch, kernel_n=3)
print("out_torch:{}".format(out_torch))
print(img_torch.shape, out_torch.shape)
# Permute to NHWC so the result can be compared against the tensorflow output.
print("\nshape:", out_torch.permute(0, 2, 3, 1).shape, '\n', out_torch.permute(0, 2, 3, 1))
部分输出:
转成tensorflow格式的部分输出:
2、提取4通道张量
# Multi-channel patch extraction (originally hard-coded to 4 channels).
def extract_patche(img, kernel_n, stride=1):
    """Mimic tf.extract_image_patches for a NCHW tensor with any channel count.

    Generalizes the original 4-channel-only version (which sliced channels out
    one by one and re-interleaved them in a Python loop) and fixes two defects:
    the `stride` argument was accepted but never forwarded to the conv2d calls,
    and a leftover debug print ran for every sample.

    Args:
        img: tensor of shape (N, C, H, W).
        kernel_n: patch size (assumed odd so SAME-style padding is symmetric).
        stride: step between neighbouring patches (now honoured).

    Returns:
        Tensor of shape (N, C * kernel_n * kernel_n, H_out, W_out). Channels
        are ordered kernel-position-major / channel-minor, i.e.
        [a0, b0, c0, d0, a1, b1, ...] — the same interleaving the original
        loop produced, matching the depth-innermost layout of
        tf.extract_image_patches.
    """
    N, C, H, W = img.shape
    k2 = kernel_n * kernel_n
    pad = (kernel_n - 1) // 2  # SAME padding for odd kernel sizes
    out_h = (H + 2 * pad - kernel_n) // stride + 1
    out_w = (W + 2 * pad - kernel_n) // stride + 1
    # (N, C*k*k, L), channel-major ordering: row index = c * k*k + v.
    cols = F.unfold(img, kernel_size=kernel_n, padding=pad, stride=stride)
    cols = cols.view(N, C, k2, out_h, out_w)
    # Swap to kernel-position-major ordering (index = v * C + c), which is
    # exactly what the original per-channel interleaving loop built.
    cols = cols.permute(0, 2, 1, 3, 4).contiguous()
    return cols.view(N, k2 * C, out_h, out_w)
# torch.range is deprecated (and its `end` is inclusive); torch.arange with an
# explicit float dtype is the modern equivalent and yields the same 72 values.
img_torch = torch.arange(0, 72, dtype=torch.float32).reshape(2, 4, 3, 3)
# Output size: (2, 36, 3, 3), where 36 = C * kernel_n * kernel_n.
out_torch = extract_patche(img_torch, kernel_n=3)
print("img_torch:{}\nout_torch:{}".format(img_torch, out_torch))
print(img_torch.shape, out_torch.shape)
# Permute to NHWC so the result can be compared against the tensorflow output.
print("\nshape:", out_torch.permute(0, 2, 3, 1).shape, '\n', out_torch.permute(0, 2, 3, 1))
部分输出:
转成tensorflow格式的部分输出: