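# unfold extracts sliding blocks from the H*W plane of an N*C*H*W tensor and stacks them; in CV it is mainly used for patch extraction.
# fold is the inverse of unfold, except that overlapping block positions are summed, so the result must be divided by the overlap count to recover the original values.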
import torch
from torch.nn import functional as f
# Unfold demo: slide a 3x3 window with stride 1 over a (1, 3, 15, 15) tensor
# filled with consecutive values, so every extracted entry is easy to trace
# back to its source position.
x = torch.arange(1 * 3 * 15 * 15, dtype=torch.float32).reshape(1, 3, 15, 15)
# x1 stacks the extracted blocks; inspect it with print(x1) when needed.
x1 = f.unfold(x, kernel_size=(3, 3), stride=(1, 1))
# Fold demo: treat a (1, 27, 4) tensor as four flattened 3-channel 3x3 blocks
# and fold them back onto a 3x6 plane with stride 1.
x = torch.arange(1 * 3 * 3 * 3 * 4, dtype=torch.float32).reshape(1, 3 * 3 * 3, 4)
print(x)
x2 = f.fold(x, output_size=(3, 6), kernel_size=(3, 3), stride=(1, 1))
print(x2.shape)
print(x2)
# Disabled exploration of the unfold result x1, kept for reference:
# print(x1[0, :, 0])
# print(x1.shape)
# B, C_kh_kw, L = x1.size()
# x1 = x1.permute(0, 2, 1)
# x1 = x1.view(B, L, -1, 3, 3)
# print(x1)
# Convolution written as unfold + matmul, then compared with conv2d.
inp = torch.arange(1 * 3 * 5 * 5, dtype=torch.float32).reshape(1, 3, 5, 5)
print(inp)
w = torch.arange(1 * 3 * 3 * 3, dtype=torch.float32).reshape(1, 3, 3, 3)
print(w)

# Extract 3x3 blocks with padding 1 so the number of blocks equals 5 * 5.
inp_unf = torch.nn.functional.unfold(inp, kernel_size=(3, 3), padding=1)

# Put the block index on dim 1 and flatten the kernel into a column matrix so
# a single matmul computes one dot product per block.
patches = inp_unf.transpose(1, 2)
w_cols = w.view(w.size(0), -1).t()
out_unf = patches.matmul(w_cols).transpose(1, 2)

print(patches.shape)
print(w_cols.shape)
print(out_unf.shape)

out = out_unf.view(1, 1, 5, 5)
print(out)

# Largest absolute deviation from the reference conv2d result; a value near
# zero means both paths agree.
reference = torch.nn.functional.conv2d(inp, w, padding=1)
print((reference - out).abs().max())