torch.nn.Unfold 会把 batch 中每个滑动窗口内的数据按 C、Kernel_H、Kernel_W 的顺序展平成一列(输出形状为 N × (C·Kernel_H·Kernel_W) × L),详细解释参考:
PyTorch中torch.nn.functional.unfold函数使用详解
本文主要是把 Unfold 返回的 tensor 的中间维度(大小为 C·Kernel_H·Kernel_W)还原成图像 patches。
# -*- coding:utf-8 -*-
import cv2
import torch
import numpy as np
def _read_image(path):
    """Read *path* with ``cv2.imread`` and return the image array.

    ``cv2.imread`` signals a missing or unreadable file by returning ``None``
    rather than raising; turn that into an explicit error here so the failure
    is reported at the load site instead of later inside torch.

    Raises:
        FileNotFoundError: if OpenCV could not read *path*.
    """
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f'cannot read image: {path}')
    return image


img1 = _read_image('../128128/1.png')
img2 = _read_image('../128128/2.png')
# Stack into a single ndarray before handing it to torch: building a tensor
# from a Python list of ndarrays is very slow (PyTorch warns about it).
# np.stack requires both images to have the same shape.
# (N, H, W, C) -> (N, C, H, W)
batch = torch.from_numpy(np.stack([img1, img2])).permute(0, 3, 1, 2)
print(batch.shape)  # [2,3,128,128]

# Each image is cut into ver_block_num rows x hor_block_num columns of patches.
hor_block_num = 2
ver_block_num = 4
n, c, h, w = batch.shape
# Explicit checks instead of assert, which is stripped under `python -O`.
if h % ver_block_num:
    raise ValueError(f'height {h} is not divisible by ver_block_num={ver_block_num}')
if w % hor_block_num:
    raise ValueError(f'width {w} is not divisible by hor_block_num={hor_block_num}')
patch_h, patch_w = h // ver_block_num, w // hor_block_num
# Select which of the two equivalent splitting methods to run.
use_unfold = True
if use_unfold:
    # Method 2: nn.Unfold with kernel == stride cuts non-overlapping
    # (patch_h x patch_w) windows. Its output is (n, C*patch_h*patch_w, L),
    # where L = ver_block_num * hor_block_num windows.
    # The uint8 batch is cast to float before unfolding (Unfold is not
    # expected to accept uint8 input).
    split_block = torch.nn.Unfold(kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w))
    patches = split_block(batch.float())
    print(patches.shape)  # [2,6144,8]
    # Bring the window axis forward, then split it into (row, col) and the
    # flattened feature axis into (C, patch_h, patch_w).
    patches = patches.transpose(1, 2).view(n, ver_block_num, hor_block_num, c, patch_h, patch_w)
    print(patches.shape)  # [2,4,2,3,32,64]
    # Move channels last (HWC layout for OpenCV) and cast back to uint8 pixels.
    patches = patches.permute(0, 1, 2, 4, 5, 3).byte()
    print(patches.shape)  # [2,4,2,32,64,3]
else:
    # Method 1: split H into (ver_block_num, patch_h) and W into
    # (hor_block_num, patch_w) with a view, then reorder the axes.
    patches = batch.view(n, c, ver_block_num, patch_h, hor_block_num, patch_w)
    print(patches.shape)  # [2,3,4,32,2,64]
    patches = patches.permute(0, 2, 4, 3, 5, 1)
    print(patches.shape)  # [2,4,2,32,64,3]
# Save patches: every patch is written to
# ../128128/<image number, 1-based>_y<row>_x<col>.png
for i, img in enumerate(patches):
    for y in range(ver_block_num):
        for x in range(hor_block_num):
            patch_filename = '../128128/%d_y%d_x%d.png' % (i + 1, y, x)
            # The patch is a slice of a permuted tensor; pass OpenCV a plain
            # C-contiguous HWC array regardless of the tensor's strides.
            patch = np.ascontiguousarray(img[y, x].numpy())
            # cv2.imwrite reports failure through its return value instead of
            # raising, so check it rather than silently losing the patch.
            if not cv2.imwrite(patch_filename, patch):
                raise OSError(f'failed to write patch: {patch_filename}')
效果如下: