我在进行Pytorch网络向mindspore迁移时,发现mindspore里缺少一个einsum算子。经过查阅资料发现,numpy中有类似的api。所以就想在网络构建时调用numpy中的
einsum算子替代。
我看mindspore文档中说是,在construct函数中不能调用第三方算子,于是我先在init函数中初始化了一个np.einsum对象,然后在construct函数中调用该算子。
class SpatialGraphConv(nn.Cell):
def __init__(self, in_channels, out_channels, max_graph_distance):
super(SpatialGraphConv, self).__init__()
# spatial class number (distance = 0 for class 0, distance = 1 for class 1, ...)
self.s_kernel_size = max_graph_distance + 1
# weights of different spatial classes
self.gcn = nn.Conv2d(in_channels=in_channels, out_channels=out_channels*self.s_kernel_size,
kernel_size=1, has_bias=True, pad_mode='valid')
self.einsum = np.einsum
def construct(self, x, A):
# numbers in same class have same weight
x = self.g