在Pytorch中实现tf.extract_image_patches的功能

tf.extract_image_patches是tensorflow用来从一张图像中提取多个patches的,其抽取patches的具体方式请移步一篇知乎的文章,那里介绍得很清楚。
最近尝试在pytorch中来实现tf.extract_image_patches的功能,具体代码如下:

一、tensorflow中tf.extract_image_patches从Tensor中提取patches

1、提取单通道张量

import tensorflow as tf

# Two 6x6 single-channel images in NHWC layout: (N, H, W, C) = (2, 6, 6, 1),
# filled with 0..71 so every pixel value is unique and easy to trace.
images = tf.range(start=0, limit=72, dtype="float32", name="img")
images = tf.reshape(images, (2, 6, 6, 1))

# Extract 3x3 patches at stride 1 with SAME padding.
# Result shape: (2, 6, 6, 9), where 9 = C * ksizes * ksizes.
patches = tf.extract_image_patches(images=images, ksizes=[1, 3, 3, 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding='SAME')
# PyTorch tensors are NCHW; transposing would ease a side-by-side comparison:
# patches_nchw = tf.transpose(patches, perm=[0, 3, 1, 2])

# Session.__enter__ installs itself as the default session, so .eval() works.
with tf.Session() as sess:
    print("img:", images.shape,'\n',images.eval())
    print("\nimg_a:",patches.shape,'\n',patches.eval())
    # print("\nimg_b:",patches_nchw.shape,'\n',patches_nchw.eval())

其输出结果如下图所示(由于太长,仅展示两个Tensor中的第一个):
单通道输出

2、提取多通道张量,这里以4通道为例

# Two 3x3 images with 4 channels: (N, H, W, C) = (2, 3, 3, 4).
img_tf = tf.range(start=0, limit=72, dtype="float32", name="img")
# BUG FIX: the original line was missing its closing parenthesis.
img_tf = tf.reshape(img_tf, (2, 3, 3, 4))

# img_a shape: (2, 3, 3, 36), where 36 = C * ksizes * ksizes
img_a = tf.extract_image_patches(images=img_tf, ksizes=[1, 3, 3, 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding='SAME')
# NCHW view for comparing against the PyTorch result:
# img_b = tf.transpose(img_a, perm=[0, 3, 1, 2])

sess = tf.Session()
with sess.as_default():
    print("img:", img_tf.shape,'\n',img_tf.eval())
    print("\nimg_a:",img_a.shape,'\n',img_a.eval())
    # print("\nimg_b:",img_b.shape,'\n',img_b.eval())

其输出结果如下图所示:
4通道输出

二、在Pytorch中实现tf.extract_image_patches的功能

1、提取单通道张量

import torch
import torch.nn as nn
import torch.nn.functional as F

# 单通道提取
def extract_patche_onechannel(img, kernel_n, stride=1):
    """Extract kernel_n x kernel_n patches from the first channel of img,
    emulating tf.extract_image_patches with padding='SAME' (odd kernel_n).

    img: (N, C, H, W) tensor; only channel 0 is used.
    Returns a (N, kernel_n**2, H_out, W_out) tensor where output channel p
    holds the pixel at patch offset (p // kernel_n, p % kernel_n).
    """
    num_taps = kernel_n * kernel_n
    # An identity matrix reshaped to 4D gives one one-hot kernel per patch
    # position: row p has its single 1 at (p // kernel_n, p % kernel_n),
    # so convolving with it simply picks that tap out of each window.
    taps = torch.eye(num_taps).reshape(num_taps, 1, kernel_n, kernel_n)
    weight = nn.Parameter(data=taps, requires_grad=False)
    same_pad = int((kernel_n - 1) / 2)  # 'SAME' padding for odd kernel_n
    return F.conv2d(img[:, :1, ...], weight, padding=same_pad, stride=stride)
 
# torch.range is deprecated (and includes `end`); torch.arange with an
# explicit float dtype produces the same 72 values 0.0 .. 71.0.
img_torch = torch.arange(0, 72, dtype=torch.float32).reshape([2, 1, 6, 6])
# Output shape: (2, 9, 6, 6), where 9 = C * kernel_n * kernel_n
# (the original comment claimed (2, 9, 3, 3), but SAME padding with
# stride 1 preserves the 6x6 spatial size).
out_torch = extract_patche_onechannel(img_torch, kernel_n=3)

print("out_torch:{}".format(out_torch))
print(img_torch.shape, out_torch.shape)
# Permute to NHWC to compare directly against the TensorFlow output.
print("\nshape:", out_torch.permute(0, 2, 3, 1).shape,'\n',out_torch.permute(0, 2, 3, 1))

部分输出:
pytorch单通道输出
转成tensorflow格式的部分输出:
tensorflow格式的部分输出

2、提取4通道张量

# 4通道提取
# Multi-channel patch extraction
def extract_patche(img, kernel_n, stride=1):
    """Extract sliding patches from a multi-channel image, mimicking
    tf.extract_image_patches with padding='SAME' (for odd kernel_n).

    Improvements over the previous version: works for any channel count C
    (was hard-coded to 4), honors the `stride` argument (it was accepted
    but ignored), and drops the per-sample Python loop and debug print in
    favor of a single grouped convolution.

    Args:
        img: (N, C, H, W) float tensor.
        kernel_n: patch (kernel) side length.
        stride: convolution stride (default 1, matching old behavior).

    Returns:
        (N, C * kernel_n**2, H_out, W_out) tensor whose channels are
        interleaved as [a0, b0, c0, ..., a1, b1, ...]: every channel's tap
        for patch position 0, then every channel's tap for position 1, etc.,
        matching TensorFlow's depth-last patch layout.
    """
    N, C, H, W = img.shape
    kk = kernel_n * kernel_n
    # One one-hot kernel per patch position: kernel p has a single 1 at
    # (p // kernel_n, p % kernel_n), so convolving simply picks that tap.
    kernel = torch.zeros(kk, 1, kernel_n, kernel_n, dtype=img.dtype)
    for idx in range(kk):
        kernel[idx, 0, idx // kernel_n, idx % kernel_n] = 1
    # Replicate the kernels once per input channel and use a grouped conv
    # so every channel is filtered independently of the others.
    weight = kernel.repeat(C, 1, 1, 1)  # (C * kk, 1, kernel_n, kernel_n)
    pad = int((kernel_n - 1) / 2)  # 'SAME' padding for odd kernel_n
    out = F.conv2d(img, weight, padding=pad, stride=stride, groups=C)
    # Grouped conv yields channels ordered [a0..a{kk-1}, b0..b{kk-1}, ...];
    # reorder to the TF-style interleave [a0, b0, ..., a1, b1, ...].
    _, _, Ho, Wo = out.shape
    out = out.view(N, C, kk, Ho, Wo).transpose(1, 2).reshape(N, C * kk, Ho, Wo)
    return out

# torch.range is deprecated (and includes `end`); torch.arange with an
# explicit float dtype produces the same 72 values 0.0 .. 71.0.
img_torch = torch.arange(0, 72, dtype=torch.float32).reshape([2, 4, 3, 3])
# Output shape: (2, 36, 3, 3), where 36 = C * kernel_n * kernel_n
out_torch = extract_patche(img_torch, kernel_n=3)

print("img_torch:{}\nout_torch:{}".format(img_torch, out_torch))
print(img_torch.shape, out_torch.shape)
# Permute to NHWC to compare directly against the TensorFlow output.
print("\nshape:", out_torch.permute(0, 2, 3, 1).shape,'\n',out_torch.permute(0, 2, 3, 1))

部分输出:
pytorch部分输出
转成tensorflow格式的部分输出:
tensorflow格式的部分输出

  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值