记录一下9.7节中红线部分(不熟悉python里边一些用法)
是什么意思?
def sequence_mask(X,valid_len,value=0):
maxlen = X.size(1) #3
print(maxlen)
mask = torch.arange((maxlen),dtype=torch.float32,device=X.device)[None,:] < valid_len[:,None]
X[~mask] = value
return X
X = torch.tensor([[1,2,3],[4,5,6]])
sequence_mask(X,torch.tensor([1,2]))
-------------------------------------------------------------------------------------------
例1:
例2:
综上(个人理解):
[None,:] 可以理解为,保持最内层列数不变,将最外层元素作为基本单元进行分割。
例如(2,3) -> (1,2,3)
(3,) -> (1,3)
[:,None] 可以理解为,保持最外层的行数不变,将最内层作为基本单元进行分割。
例如(2,3) -> (2,1,3)
(3,) ->(3,1)