../d2l/torch.py中的lambda表达式

astype

astype = lambda x, *args, **kwargs: x.type(*args, **kwargs)
cmp = d2l.astype(y_hat, y.dtype) == y

解释:x接受第0个参数y_hat,args接收其它后面的参数y.dtype(这里是torch.int64),x.type是将x的元素强制转换成某个属性。综合起来这个lambda的意思是将y_hat的元素类型设置为和y.dtype一样的类型。

reduce_sum

reduce_sum = lambda x, *args, **kwargs: x.sum(*args, **kwargs)
d2l.reduce_sum(d2l.astype(cmp, y.dtype))

解释:等价于d2l.astype(cmp, y.dtype).sum();

另外,bool类型是直接可以加和的,例子如下:

a = torch.tensor([[False, True], [True, True]])
a.sum()
output: tensor(3)

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值