方法一:reshape
如果我们将已经展平的一维向量,恢复成原来的特征图形状为[batch_size, channels, height, width]的张量,可以写成以下代码:
import torch
# 定义已经展平的一维向量x
x = torch.randn(10, 64)
# 将一维向量x恢复为原来的结构[batch_size, channels, height, width]
x = x.reshape(10, 4, 4, 4) # 假设特征图的大小为:[4, 4]
其中,torch.randn(10, 64)定义了一个 batch_size 为 10,特征图大小为 4x4,通道数为 64 的特征图。x.reshape(10, 4, 4, 4)将一维向量 x 恢复回原来的结构。
需要注意的是,使用 reshape函数 需要保证恢复后的尺寸与之前卷积神经网络的输出结构匹配,不然会导致后续操作中的维度错误。如果在卷积神经网络中使用了池化或升采样等操作,恢复的特征图的大小也需要相应修改。
方法二:x.view
import torch
x = torch.randn(2, 3, 4, 5)
flatten_x = torch.flatten(x)
print(flatten_x.shape) # torch.Size([120])
# 使用view将其恢复为4D的张量
x_reconstruct = flatten_x.view(2, 3, 4, 5)
print(x_reconstruct.shape) # torch.Size([2, 3, 4, 5])
可以看到,通过使用view函数,我们可以方便地将一维向量恢复成原来的张量结构。需要注意的是,使用view函数后得到的新张量与原来的张量共享相同的数据存储,这也就意味着修改其中一个张量的值会同时影响到另一个张量的值。因此,在处理过程中需要特别注意。