A detailed look at nn.Unfold() and nn.Fold() in PyTorch

1. The nn.Unfold() function

Description: in image processing we often use convolutions, but sometimes all we need is the sliding-window part of that operation, i.e. cutting an image into patches, without multiplying a kernel against the pixel values. This is exactly what nn.Unfold() is for: it extracts the sliding local blocks from a batch of images, the same windows a convolution's kernel filter would slide over.

torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)

The input to this function has shape (bs, C, H, W), where bs is the batch size and C is the number of channels.
The output has shape (bs, C x kernel_size[0] x kernel_size[1], L), where L is the number of patches obtained by sliding a kernel_size window over the image or feature map.
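L follows the same arithmetic as a convolution's output size: per spatial dimension it is floor((size + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1), and L is the product over the spatial dimensions. A quick sanity check for the 4x4 example used below (the helper name is just for illustration):

# Number of window positions along one spatial dimension (conv output-size formula)
def num_blocks(size, kernel, dilation=1, padding=0, stride=1):
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# 4x4 input, kernel_size=(2, 2), stride=2, padding=0: 2 positions per side
L = num_blocks(4, 2, stride=2) * num_blocks(4, 2, stride=2)
print(L)  # 4 patches per image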

import torch
import torch.nn as nn

batches_img = torch.rand(1, 2, 4, 4)  # simulated image data of shape (bs, 2, 4, 4), with C = 2 channels
print("batches_img:\n", batches_img)

nn_Unfold = nn.Unfold(kernel_size=(2, 2), dilation=1, padding=0, stride=2)
patche_img = nn_Unfold(batches_img)
print("patche_img.shape:", patche_img.shape)
print("patch_img:\n", patche_img)

(Output omitted; the values are random. patche_img.shape is torch.Size([1, 8, 4]): C*k*k = 2*2*2 = 8 rows and L = 4 patch columns.)
The main use of this function is to split an image into separate patches, which can be done with the following code:

# The code above produces patche_img of shape (bs, C*k*k, L), where L is the number of patches each image is split into
reshape_patche_img = patche_img.view(batches_img.shape[0], batches_img.shape[1], 2, 2, -1)
print(reshape_patche_img.shape)  # [bs, C, k, k, L]
reshape_patche_img = reshape_patche_img.permute(0, 4, 1, 2, 3)  # [N, L, C, k, k]
print(reshape_patche_img.shape)

Result: the two prints show torch.Size([1, 2, 2, 2, 4]) and torch.Size([1, 4, 2, 2, 2]).
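If you also need to know where each patch sits in the original image, L can be factored back into a patch grid. A minimal sketch for the same 4x4 / kernel 2x2 / stride 2 setup (the grid sizes h_p and w_p are written out by hand here):

h_p = w_p = 2  # a 4x4 image with a 2x2 kernel and stride 2 gives 2 patches per side
grid_patches = reshape_patche_img.reshape(batches_img.shape[0], h_p, w_p,
                                          batches_img.shape[1], 2, 2)  # [bs, h_p, w_p, C, k, k]
print(grid_patches.shape)  # torch.Size([1, 2, 2, 2, 2, 2])

Because nn.Unfold enumerates blocks row by row (the width index moves fastest), splitting L into (h_p, w_p) this way preserves the spatial layout of the patches.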

2. The nn.Fold() function

This function is the inverse of nn.Unfold(): it combines an array of sliding local blocks back into a tensor of the original spatial shape.

# Restore the original (bs, C, H, W) tensor from the patches produced by nn.Unfold above
fold = torch.nn.Fold(output_size=(4, 4), kernel_size=(2, 2), stride=2)
inputs_restore = fold(patche_img)
print(inputs_restore)
print(inputs_restore.size())

(Output omitted: inputs_restore is identical to batches_img and has shape torch.Size([1, 2, 4, 4]).)
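One caveat: nn.Fold sums the values of overlapping blocks, so fold(unfold(x)) reproduces x exactly only when the blocks do not overlap, as in the kernel 2x2 / stride 2 example above. For overlapping blocks, the usual fix (also shown in the PyTorch documentation) is to divide by a divisor map obtained by folding an unfolded tensor of ones. A rough sketch:

# Overlapping case: a 3x3 kernel with stride 1, so interior pixels are covered by several blocks
x = torch.rand(1, 2, 4, 4)
unfold = nn.Unfold(kernel_size=(3, 3), stride=1)
fold = nn.Fold(output_size=(4, 4), kernel_size=(3, 3), stride=1)

summed = fold(unfold(x))                    # overlapping contributions are summed
divisor = fold(unfold(torch.ones_like(x)))  # how many blocks cover each pixel
print(torch.allclose(summed / divisor, x))  # True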

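Finally, here is a larger example in which nn.Unfold and F.fold work together inside an attention module; the code below follows the outlook-attention pattern used in VOLO-style models. unfold gathers the k x k neighbourhood of values around each (pooled) position, a learned attention map is applied within every window, and F.fold scatters the weighted values back onto the H x W grid. The class name and the imports are assumptions added so the snippet is self-contained.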
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class OutlookAttention(nn.Module):  # class name assumed; the snippet only provides the two methods
    def __init__(self, dim, num_heads, kernel_size=3, padding=1, stride=1,
                 qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.scale = qk_scale or head_dim ** -0.5

        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn = nn.Linear(dim, kernel_size ** 4 * num_heads)  # one k*k x k*k attention map per head
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride)
        self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)

    def forward(self, x):
        B, H, W, C = x.shape  # channels-last input

        v = self.v(x).permute(0, 3, 1, 2)  # (B, C, H, W)
        h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)
        # unfold gathers the k*k neighbourhood of values around each of the h*w positions
        v = self.unfold(v).reshape(B, self.num_heads, C // self.num_heads,
                                   self.kernel_size * self.kernel_size,
                                   h * w).permute(0, 1, 4, 3, 2)  # B, H, N, k*k, C/H

        # attention weights are predicted directly from the (pooled) features
        attn = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        attn = self.attn(attn).reshape(B, h * w, self.num_heads,
                                       self.kernel_size * self.kernel_size,
                                       self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 4)  # B, H, N, k*k, k*k
        attn = attn * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).permute(0, 1, 4, 3, 2).reshape(B, C * self.kernel_size * self.kernel_size, h * w)
        # F.fold scatters the weighted k*k windows back onto the H x W grid (overlaps are summed)
        x = F.fold(x, output_size=(H, W), kernel_size=self.kernel_size,
                   padding=self.padding, stride=self.stride)

        x = self.proj(x.permute(0, 2, 3, 1))
        x = self.proj_drop(x)
        return x
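A quick shape check (the dim / num_heads values here are arbitrary; note the module expects channels-last input):

attn_layer = OutlookAttention(dim=192, num_heads=6)
x = torch.rand(1, 14, 14, 192)  # (B, H, W, C)
out = attn_layer(x)
print(out.shape)  # torch.Size([1, 14, 14, 192])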