python打印变量懒人版 locals

python print的时候如果想要print x=xvalue格式变量名得输出两边,下面是懒人搞法,但后缀有些长,如果有更好的方法欢迎推荐:

import torch
import numpy as np


def get_attn_pad_mask(seq_q, seq_k):
    '''
    seq_q: [batch_size, len_q]
    seq_k: [batch_size, len_k]
    seq_len could be src_len or it could be tgt_len
    seq_len in seq_q and seq_len in seq_k maybe not equal
    '''
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) is PAD token
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k], False is masked
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]

def get_attn_subsequence_mask(seq):
    '''
    seq: [batch_size, tgt_len]
    '''
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    subsequence_mask = np.triu(np.ones(attn_shape), k=1) # Upper triangular matrix
    subsequence_mask = torch.from_numpy(subsequence_mask).byte()
    return subsequence_mask

#不太work,输出的变量名成val了
def print_val(val):
    tmp = locals()
    print('{val}'.format(**{k: '{}={}'.format(k, v) for (k, v) in locals().items()}))

if __name__ == '__main__':
    seq = torch.Tensor([[2.0,3.0,5.0],[1.0,2.0,0.0]])
    pad_mask = get_attn_pad_mask(seq,seq)
    print(f'pad_mask={pad_mask}')
    print('pad_mask={}'.format(pad_mask))
    print('{pad_mask}'.format(**{k: '{}={}'.format(k, v) for (k, v) in locals().items()}))

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值