torch.sactter_()的用法简析
明天就要开周会了,今天我的ppt和内容都还没有做,集中注意力学习了一下scatter_()
函数,弄清了二维和独热编码是的工作原理。
如果你经常看到类似下面的东西:
y_vec_ = torch.zeros((self.batch_size, self.class_num)).scatter_(1, y_.type(torch.LongTensor).unsqueeze(1), 1)
经常性的感觉到困惑,查看官方文档,不知道内在如何操作,查看博客,一大堆照抄的内容,标点符号都没有改。今天我可以讲清楚 二维 torch tensor
结构的内容。
官方函数: 将src中的所有值按照index确定的索引写入本tensor中。其中索引是根据给定的dimension,dim按照gather()描述的规则来确定。
注意,index的值必须是在_0_到_(self.size(dim)-1)_之间,
scatter_(input, dim, index, src) → Tensor
参数: - input (Tensor)-源tensor - dim (int)-索引的轴向 - index (LongTensor)-散射元素的索引指数 - src (Tensor or float)-散射的源元素。