pytorch 中遇到的 code

  1. masked_copy_(mask, source)
    将mask中值为1元素对应的source中位置的元素复制到本tensor中。mask应该有和本tensor相同数目的元素。
a = torch.zeros(3, 4).byte()
index = torch.LongTensor([0])
mask = a.index_fill_(0, index, 1)
print(mask)

source = torch.randn(3, 4)
print(source)

target = torch.ones(3,4)
target.masked_scatter_(mask, source)
print(target)

tensor([[1, 1, 1, 1],
        [0, 0, 0, 0],
        [0, 0, 0, 0]], dtype=torch.uint8)
tensor([[ 0.2118,  1.3267,  0.9646,  0.8180],
        [ 1.7216,  0.4548, -1.3549,  1.5633],
        [-1.1666, -0.5386,  0.2197,  0.0596]])
tensor([[0.2118, 1.3267, 0.9646, 0.8180],
        [1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000]])'''
  1. log_sum_exp(vec, m_size), 以下两个式子得到的结果相同,后者的处理是为了防止上溢问题
vec = torch.FloatTensor([[[-1.1880,  0.2350, -1.0619],
                          [ 0.9061,  0.7829, -1.1289]],

                         [[ 0.8700, -0.6614,  1.5453],
                          [-0.0818,  0.7460,  3.2312]]])
res_ = torch.log(torch.sum(torch.exp(vec),1))

#tensor([[ 1.0223,  1.2392, -0.4017],
#        [ 1.1965,  0.9650,  3.4012]])
def log_sum_exp(vec, m_size):
    """
    结果和右式相同:torch.log(torch.sum(torch.exp(vec),1))
    直接计算可能会出现 exp(999)=INF 上溢问题
    所以 考虑 torch.max(vec, 1)这部分, 以避免 上溢问题

    Args:
        vec: size=(batch_ size, vanishing_dim, hidden_dim)
        m_size: hidden_dim
    Returns:
        size=(batch_size, hidden_dim)
        
    """
    _, idx = torch.max(vec, 1)  # B * 1 * M ,为了防止 log(过大值max),所有值减去每列最大值
    max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size)  # B * M
    return max_score.view(-1, m_size) + torch.log(torch.sum(
        torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值