import torch.nn as nn
import torch
# 创建三维tensor
a = torch.randn(3,4,5)
print(a.shape)
print(a)
# 升维,升成四维
a = torch.unsqueeze(a, 0)
print(a.shape)
print(a)
# AdaptiveAvgPool2d(X) 是将W H 使用平均池化降为X维
avg = nn.AdaptiveAvgPool2d(1)
b = avg(a)
print(a.shape)
print(b)
print(b.shape)