#多输入输出通道
import torch
from d2l import torch as d2l
def corr2d_multi_in(x, k):
    """Cross-correlate a multi-channel input with a multi-channel kernel.

    Each input channel ``x[i]`` is cross-correlated with the matching kernel
    channel ``k[i]`` via ``d2l.corr2d``, and the per-channel results are
    summed into a single 2-D output map.

    Args:
        x: input indexed by channel along its first dimension,
            e.g. shape (c_i, h, w) as in the demo below.
        k: kernel indexed by channel along its first dimension,
            e.g. shape (c_i, k_h, k_w).

    Returns:
        The sum over channels of the 2-D cross-correlations.

    Raises:
        ValueError: if ``x`` and ``k`` have different numbers of channels.
    """
    if len(x) != len(k):
        # zip() would silently stop at the shorter operand and return a
        # partial (wrong) sum instead of failing loudly.
        raise ValueError(
            f"channel mismatch: input has {len(x)} channels, "
            f"kernel has {len(k)}"
        )
    return sum(d2l.corr2d(x_c, k_c) for x_c, k_c in zip(x, k))
# Demo: a 2-channel 3x3 input and a 2-channel 2x2 kernel.
x = torch.tensor([
    [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
    [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
])
k = torch.tensor([
    [[0.0, 1.0], [2.0, 3.0]],
    [[1.0, 2.0], [3.0, 4.0]],
])
print(corr2d_multi_in(x, k))
# Expected output (channel 0 and channel 1 results summed):
# tensor([[ 56.,  72.],
#         [104., 120.]])
def corr2d_multi_in_out(x, k):
    """Compute one output map per output-channel kernel and stack them.

    ``k`` is iterated along its first dimension. Each slice is a full
    multi-input-channel kernel that is passed with ``x`` to
    ``corr2d_multi_in``. The resulting 2-D maps are stacked along a new
    leading (output-channel) dimension.

    Args:
        x: multi-channel input, e.g. shape (c_i, h, w).
        k: stack of kernels, e.g. shape (c_o, c_i, k_h, k_w).

    Returns:
        Tensor with the per-output-channel maps stacked along dim 0.
    """
    per_output = [corr2d_multi_in(x, kernel) for kernel in k]
    return torch.stack(per_output, 0)
# Build a 3-output-channel kernel by stacking k, k + 1 and k + 2 along a new
# leading dimension: shape becomes (c_o, c_i, k_h, k_w) = (3, 2, 2, 2).
k = torch.stack((k, k + 1, k + 2), dim=0)
print(k.shape)
print(corr2d_multi_in_out(x, k))
# Expected output, one 2x2 map per output channel:
# tensor([[[ 56.,  72.],
#          [104., 120.]],
#
#         [[ 76., 100.],
#          [148., 172.]],
#
#         [[ 96., 128.],
#          [192., 224.]]])
# 1x1 convolution
def corr2d_multi_in_out_1x1(x, k):
    """Apply a 1x1 multi-channel convolution as a single matrix product.

    With a 1x1 kernel, each output pixel is a linear combination of the input
    channels at that same pixel. Flattening the spatial dimensions of ``x`` to
    (c_i, h*w) and the kernel to (c_o, c_i) turns the operation into one
    matmul.

    Args:
        x: input of shape (c_i, h, w).
        k: kernel whose first dimension is c_o and whose elements reshape to
            (c_o, c_i), e.g. shape (c_o, c_i, 1, 1).

    Returns:
        Tensor of shape (c_o, h, w).
    """
    c_in, height, width = x.shape
    c_out = k.shape[0]
    weights = k.reshape((c_out, c_in))
    pixels = x.reshape((c_in, height * width))
    return (weights @ pixels).reshape((c_out, height, width))
# Sanity check: the 1x1 matmul shortcut must agree with the general
# multi-input/multi-output cross-correlation on random data
# (3 input channels of 3x3, 2 output channels, 1x1 kernels).
x = torch.normal(0, 1, (3, 3, 3))
k = torch.normal(0, 1, (2, 3, 1, 1))
y1 = corr2d_multi_in_out_1x1(x, k)
y2 = corr2d_multi_in_out(x, k)
# Compare element-wise with a float tolerance. Thresholding the summed
# absolute error is fragile: the two paths accumulate float32 products in a
# different order, and a fixed bound on the total leaves less room per
# element as the tensor grows, so the old check could fail spuriously.
assert torch.allclose(y1, y2, atol=1e-6)
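# 21 - 多输入多输出 动手深度学习
#   (Lesson 21: multiple input and multiple output channels, "Dive into Deep Learning")
# 最新推荐文章于 2024-09-16 21:10:20 发布
#   (scraped blog footer: "latest recommended article published 2024-09-16 21:10:20")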