在运行如下代码时遇到报错
it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
RuntimeError: index_copy(): self and source expected to have the same dtype, but got (self) Int and (source) Long
原因是因为数据类型不对,一个时int,一个是long,修改一下数据类型就行。将代码修改为如下代码即可运行。
it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind).int())