MXNet 中用 NDArray * NDArray(逐元素相乘)作为掩码、实现“与”运算的用法

def batch_loss(encoder, decoder, X, Y, loss):
    """Compute the masked, per-token average loss for one mini-batch.

    The encoder is run once on ``X``; the decoder is then stepped through
    the columns of ``Y`` with teacher forcing. A 0/1 ``mask`` vector (one
    entry per sequence) is multiplied into each step's loss so that steps
    after a sequence's EOS token stop contributing.

    Debug output is printed once per call: the initial mask, the first
    target column, and the mask after the first step.

    Args:
        encoder: Callable model exposing ``begin_state(batch_size=...)``.
        decoder: Callable model exposing ``begin_state(enc_state)``.
        X: Source batch; ``X.shape[0]`` is used as the batch size.
        Y: Target batch; iterated as ``Y.T``, i.e. one column per step
            (assumes a (batch, steps) layout -- TODO confirm with caller).
        loss: Callable ``loss(dec_output, y)`` returning one value per
            sequence (presumably; it is multiplied by the mask vector).

    Returns:
        The summed masked loss divided by the number of unmasked
        (sequence, step) positions.
    """
    n_seqs = X.shape[0]
    enc_outputs, enc_state = encoder(X, encoder.begin_state(batch_size=n_seqs))
    # The decoder's initial hidden state is derived from the encoder state.
    dec_state = decoder.begin_state(enc_state)
    # The decoder's input at the very first time step is BOS.
    bos_idx = out_vocab.token_to_idx[BOS]
    dec_input = nd.array([bos_idx] * n_seqs)
    # `mask` is used to ignore the loss on positions that are padding (PAD).
    mask = nd.ones(shape=(n_seqs,))
    num_not_pad_tokens = 0

    # One-shot debug trace of the starting state.
    trace_stage = 1
    print("调用我的bflag1=", trace_stage, "type(bflag1)=", type(trace_stage))
    print("type(mask)=", type(mask), "origin mask=", mask, "mask.shape=", mask.shape)

    total_loss = nd.array([0])
    for step, y in enumerate(Y.T):
        trace_this_step = step == 0  # debug prints fire on the first step only
        if trace_this_step:
            print("y=", y)
        dec_output, dec_state = decoder(dec_input, dec_state, enc_outputs)
        # Out-of-place accumulation, as in the original code.
        total_loss = total_loss + (mask * loss(dec_output, y)).sum()
        dec_input = y  # teacher forcing: feed the ground-truth token next
        num_not_pad_tokens += mask.sum().asscalar()
        # After EOS, the rest of the sequence is PAD, so zero that sequence's
        # mask entry. The update happens after the loss above, so the EOS
        # step itself is still counted.
        not_eos = y != out_vocab.token_to_idx[EOS]
        mask = mask * not_eos
        if trace_this_step:
            print("mask =", mask, "mask.shape=", mask.shape,
                  "(y != out_vocab.token_to_idx[EOS])=", not_eos)
    return total_loss / num_not_pad_tokens

输出(注意:mask 与 (y != out_vocab.token_to_idx[EOS]) 的结果都是 NDArray,元素取值为 0 或 1,因此二者逐元素相乘等价于逻辑“与”运算):

X= 
[[ 9. 30. 27. 45.  4.  3.  1.]
 [12.  5. 44. 37. 16.  4.  3.]]
<NDArray 2x7 @cpu(0)> Y= 
[[ 7.  5. 31. 11. 18.  4.  3.]
 [ 7.  5. 11. 27. 29.  4.  3.]]
<NDArray 2x7 @cpu(0)> Y.T= 
[[ 7.  7.]
 [ 5.  5.]
 [31. 11.]
 [11. 27.]
 [18. 29.]
 [ 4.  4.]
 [ 3.  3.]]
<NDArray 7x2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[7. 7.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[8. 9.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[8. 8.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[9. 8.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[7. 7.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[8. 8.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[7. 9.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[7. 9.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[9. 7.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[8. 9.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[8. 7.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[8. 9.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[9. 9.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[7. 8.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[7. 9.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[9. 7.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[7. 8.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[7. 8.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[8. 7.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask= 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y= 
[9. 8.]
<NDArray 2 @cpu(0)>
mask = 
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])= 
[1. 1.]
<NDArray 2 @cpu(0)>
(y != out_vocab.token_to_idx[EOS])= <class 'mxnet.ndarray.ndarray.NDArray'>
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值