# 高维度计算 (high-dimensional computation)
import math

import torch
from torch import nn
class DotProductSimilarity(nn.Module):
    """Similarity computed as a dot product over the last dimension.

    The two inputs are multiplied element-wise and summed along the final
    axis. The result therefore has the broadcast shape of the inputs with the
    last dimension removed.

    Args:
        scale_output: If True (the default), divide the dot product by
            ``sqrt(d)``, where ``d`` is the size of the last dimension of
            ``tensor_1``. This is the scaled dot product.
    """

    def __init__(self, scale_output=True):
        super().__init__()
        self.scale_output = scale_output

    def forward(self, tensor_1, tensor_2):
        """Return the (optionally scaled) dot product along the last axis.

        Args:
            tensor_1: First input tensor. Its last-dimension size sets the
                scale factor.
            tensor_2: Second input tensor, broadcastable against ``tensor_1``.

        Returns:
            A tensor with the broadcast shape of the inputs minus the last
            dimension. If the last dimension has size 0, the result is all
            zeros rather than NaN.
        """
        result = (tensor_1 * tensor_2).sum(dim=-1)
        # The original passed ``tensor_1.size()`` (a torch.Size tuple) to
        # math.sqrt, which raises TypeError. The scale must use only the size
        # of the reduced (last) dimension.
        dim = tensor_1.size(-1)
        # Skip scaling for an empty last axis: the sum is already 0, and
        # dividing by sqrt(0) would turn it into NaN.
        if self.scale_output and dim > 0:
            # Out-of-place division avoids mutating a tensor that may be part
            # of an autograd graph.
            result = result / math.sqrt(dim)
        return result
# Demo: score two all-ones (2, 1, 720, 1280) tensors against each other, then
# use the resulting scores to reweight input_1 element-wise.
input_1 = torch.ones((2, 1, 720, 1280), requires_grad=True)
input_2 = torch.ones((2, 1, 720, 1280), requires_grad=True)
# NOTE(review): input_3 is not used anywhere below.
input_3 = torch.ones((2, 1, 3, 5), requires_grad=True)

con = DotProductSimilarity()
# The similarity reduces the last axis, so CS is 3-D here.
CS = con(input_1, input_2)
print(CS.shape)

# Append a trailing singleton axis so CS broadcasts across the width axis.
CS = CS.unsqueeze(-1)
s = input_1 * CS
print(*(t.shape for t in (input_1, input_2, CS, s)))
print(s)