position = torch.arange(0, seq_length).expand(N,seq_length).to(device)
trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand( N, 1, trg_len, trg_len )
上面两行写法均会出错。
修改后:
position = torch.arange(0, seq_length)
positions = position.expand(N, seq_length).to(myDevice)
trg_mask1 = torch.tril(torch.ones((trg_len, trg_len)))
trg_mask = trg_mask1.expand( N, 1, trg_len, trg_len )