pytorch gather

y
Out[34]: tensor([0, 2])
y_hat
Out[35]: 
tensor([[0.1000, 0.3000, 0.6000],
        [0.3000, 0.2000, 0.5000]])

y_hat.gather(1, y.view(-1, 1))

聚合方向y_hat的维度1,聚合位置:

a[0][y.view(-1,1)[0]] = 0.1

a[1][y.view(-1,1)[1]] = 0.5

发布了23 篇原创文章 · 获赞 5 · 访问量 1万+
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 大白 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览