错误如图:
a[:,:2]的shape为(2,2);b[:,:2]的shape为(3,2)。比较的时候出错如下
解决方法:
也就是在a的中间加一个维度,且该维度长度为1:
原因在于:torch.max本质上还是要比较两个形状相同的张量的,当两个张量不相同的时候,需要通过张量的广播机制,将两者先变成张量一致,在进行比较。而这里涉及到张量的广播,其原理参考:https://blog.csdn.net/luoganttcc/article/details/117925855
张量广播原理简述:广播其实就是复制粘贴,当两个张量shape不同时,若满足以下两个条件之一,则可以进行张量的广播使得两者形状相同,再进行运算
1:张量后缘维度的轴长相符,就是说从后往前数张量的维度相同。比如:
shape[1,2,3],shape[2,3]从后往前数维度分别为3,2是一致的,这个时候可以进行广播(先增加一个轴,再在这个轴上复制原来的数据)
2:另一种是有一方的长度为1。这个条件其实是为第一个条件做准备的。比如:
a= shape[1,2,3],b = shape[1,3]。先把b的1广播到2变成[2,3],再继续广播成[1,2,3]