问题描述
分析原因
使用torch.sparse.mm进行矩阵相乘,DV2和H两个均为Sparse矩阵。而torch不支持Sparse和Sparse的矩阵相乘,也不支持Dense和Sparse的矩阵相乘,只支持Sparse和Dense的矩阵相乘。
sparse:torch.sparse.FloatTensor、torch.sparse.LongTensor等
dense:torch.FloatTensor、torch.LongTensor等
sparse矩阵可以通过to_dense()转换为dense类型
同样,dense矩阵可以通过to_sparse()转换为sparse类型
解决办法
DV2_H = torch.sparse.mm(DV2, H)
修改为
DV2_H = torch.sparse.mm(DV2, H.to_dense())