直观理解 torch.nn.Unfold

torch.nn.Unfold 是把 batch 中的数据按 C、Kernel_H、Kernel_W 的顺序打包(输出形状为 [N, C×Kernel_H×Kernel_W, L],L 为滑动窗口的个数),详细解释参考:

PyTorch中torch.nn.functional.unfold函数使用详解

本文主要是把 Unfold 返回的 tensor 的中间维度(C×Kernel_H×Kernel_W)还原成一个个 patch。

# -*- coding:utf-8 -*-
import cv2
import torch
import numpy as np

# Load the two example images. cv2.imread signals a missing or unreadable file
# by returning None instead of raising, so fail early with a clear message.
img1 = cv2.imread('../128128/1.png')
img2 = cv2.imread('../128128/2.png')
for img_path, img in (('../128128/1.png', img1), ('../128128/2.png', img2)):
    if img is None:
        raise FileNotFoundError('cannot read image: %s' % img_path)

# Stack into a single ndarray first: building a tensor directly from a list of
# ndarrays is slow and makes PyTorch emit a UserWarning. np.stack also rejects
# images whose shapes differ.
# Layout goes from (N, H, W, C) as loaded by OpenCV to (N, C, H, W).
batch = torch.from_numpy(np.stack([img1, img2])).permute(0, 3, 1, 2)
print(batch.shape) # [2,3,128,128] for the 128x128 example images

# Grid layout: number of patches per row (horizontal) and per column (vertical).
hor_block_num = 2
ver_block_num = 4

n, c, h, w = batch.shape

# The image must tile exactly into the grid. Explicit errors are used instead
# of assert so the check still runs under `python -O`.
if h % ver_block_num:
    raise ValueError('image height %d is not divisible by ver_block_num %d'
                     % (h, ver_block_num))
if w % hor_block_num:
    raise ValueError('image width %d is not divisible by hor_block_num %d'
                     % (w, hor_block_num))

patch_h, patch_w = h // ver_block_num, w // hor_block_num

# Split every image of `batch` into a ver_block_num x hor_block_num grid of
# patches. Both methods end with `patches` laid out as
# (n, ver_block_num, hor_block_num, patch_h, patch_w, c).
# The shape comments below hold for the 128x128 example images.
use_unfold = True
if use_unfold:
    # Method 2: torch.nn.Unfold with non-overlapping windows
    # (kernel size == stride == patch size).
    patch_size = (patch_h, patch_w)
    split_block = torch.nn.Unfold(kernel_size=patch_size, stride=patch_size)
    # NOTE(review): the float() cast is presumably needed because unfold does
    # not accept uint8 input -- confirm; .byte() below undoes it.
    unfolded = split_block(batch.float())
    patches = unfolded
    print(patches.shape) # [2,6144,8]
    # Bring the window index in front of the packed (c * patch_h * patch_w)
    # axis, then unpack both: windows -> (row, col), packed -> (c, h, w).
    windows_first = unfolded.transpose(1, 2)
    patches = windows_first.view(
        n, ver_block_num, hor_block_num, -1, patch_h, patch_w)
    print(patches.shape) # [2,4,2,3,32,64]
    # Move channels last so each patch is an (h, w, c) image.
    patches = patches.permute(0, 1, 2, 4, 5, 3).byte()
    print(patches.shape) # [2,4,2,32,64,3]
else:
    # Method 1: pure reshaping. Split H into (grid row, patch_h) and
    # W into (grid col, patch_w) ...
    grid_shape = (n, c, ver_block_num, patch_h, hor_block_num, patch_w)
    patches = batch.view(*grid_shape)
    print(patches.shape) # [2,3,4,32,2,64]
    # ... then reorder to (n, row, col, patch_h, patch_w, c).
    patches = patches.permute(0, 2, 4, 3, 5, 1)
    print(patches.shape) # [2,4,2,32,64,3]

# Save every patch as '<image number>_y<row>_x<col>.png' next to the inputs.
# Image numbers are 1-based to match the input file names (1.png, 2.png).
for i, img in enumerate(patches):
    for y in range(ver_block_num):
        for x in range(hor_block_num):
            patch_filename = '../128128/%d_y%d_x%d.png' % (i + 1, y, x)
            patch = img[y, x].numpy()
            # cv2.imwrite reports failure through its boolean return value
            # (e.g. when the target directory does not exist), so check it
            # instead of silently losing the patch.
            if not cv2.imwrite(patch_filename, patch):
                raise IOError('failed to write patch: %s' % patch_filename)

效果如下:

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值