nn.Unfold or F.unfold
Extracts sliding local blocks from a batched input tensor.
将一个分批次的输入张量按局部块状区域逐次展开。
核心思想
作用于image-like的输入张量或特征图。作用后的特征图特征为原始特征图空间邻域展开、并逐通道连接后得到的新的按照邻域信息-通道信息排列组合的特征图。
图示
[
1.11 1.12
1.21 1.22
2.11 2.12
2.21 2.22
> x.yz: x=通道 y=高 z=宽
> (1, 1)位置点的特征为(1.11 2.11)
]
-> nn.Unfold(kernel_size=3, padding=1)
unfold后(1, 1)位置点的特征为
通道1邻域展开 [0 0 0 0 1.11 1.12 0 1.21 1.22]
衔接
通道2邻域展开 [0 0 0 0 2.11 2.12 0 2.21 2.22]
也即
[0 0 0 0 1.11 1.12 0 1.21 1.22] + [0 0 0 0 2.11 2.12 0 2.21 2.22]
[0 0 0 0 1.11 1.12 0 1.21 1.22 0 0 0 0 2.11 2.12 0 2.21 2.22]
代码解释
import torch
from torch import nn
# 创建一个测试输入
data = torch.tensor([[[1.11, 1.12], [1.21, 1.22]], [[2.11, 2.12], [2.21, 2.22]]])
assert data.shape == (2, 2, 2)
# unfold函数要求输入数据为四维张量,当输入为image-like的张量时,其功能可解释为于im2col
# 为测试数据添加batch维
data = data[None, ...]
assert data.shape == (1, 2, 2, 2)
# 创建局部块提取核,与卷积核参数类似,决定局部块的大小和提取方式
unfold = nn.Unfold(kernel_size=3, padding=1)
unfold_data = unfold(data)
# 其中2的含义与上述shape的输出依次对应(通道、高、宽),3表示局部块核尺寸
assert unfold_data.shape == (1, 3 * 3 * 2, 2 * 2)
# 将最后两个维度恢复为原始图像尺寸,通道层对应逐通道展开后局部块区域的一维特征表示,此时特征图表示为unfold前特征图对应点位的局部块展开特征表示
unfold_data = unfold_data.view(1, 3 * 3 * 2, 2, 2)