def batch_loss(encoder, decoder, X, Y, loss, *, debug=False):
    """Return the mean masked per-token loss of one seq2seq batch.

    The encoder consumes ``X``. The decoder is then unrolled over the time
    steps of ``Y`` using teacher forcing: the ground-truth token of step t is
    fed as the decoder input of step t+1. Loss terms for positions after a
    sequence's EOS token are masked out, so padding does not contribute.

    Args:
        encoder: object providing ``begin_state(batch_size=...)`` and callable
            as ``encoder(X, state) -> (outputs, state)``.
        decoder: object providing ``begin_state(enc_state)`` and callable as
            ``decoder(dec_input, state, enc_outputs) -> (output, state)``.
        X: source token indices; ``X.shape[0]`` is taken as the batch size.
        Y: label token indices, iterated one time step at a time via ``Y.T``.
            Presumably shaped (batch_size, num_steps) — TODO confirm.
        loss: callable ``loss(dec_output, y)``. Assumed to return one loss
            value per batch element so it can be multiplied by the
            (batch_size,) mask — confirm against the loss actually passed in.
        debug: if True, print the initial mask, the first step's labels, and
            the mask after the first update. This happens once per call.
            Defaults to False so training loops are not flooded with output.

    Returns:
        One-element NDArray: the masked loss sum divided by the number of
        non-padding tokens. NOTE(review): if ``Y`` has no time steps, the
        denominator is 0.
    """
    batch_size = X.shape[0]
    enc_state = encoder.begin_state(batch_size=batch_size)
    enc_outputs, enc_state = encoder(X, enc_state)
    # Initialise the decoder's hidden state from the encoder's final state.
    dec_state = decoder.begin_state(enc_state)
    # At the first time step the decoder's input is BOS for every example.
    dec_input = nd.array([out_vocab.token_to_idx[BOS]] * batch_size)
    # The EOS index is constant; look it up once instead of on every step.
    eos_idx = out_vocab.token_to_idx[EOS]
    # mask is 1 while a sequence is still "live" and 0 once its EOS has been
    # seen, so the loss on the PAD labels that follow EOS is ignored.
    mask, num_not_pad_tokens = nd.ones(shape=(batch_size,)), 0
    if debug:
        print("type(mask)=", type(mask), "origin mask=", mask, "mask.shape=", mask.shape)
    loss_sum = nd.array([0])
    for step, y in enumerate(Y.T):
        if debug and step == 0:
            print("y=", y)
        dec_output, dec_state = decoder(dec_input, dec_state, enc_outputs)
        loss_sum = loss_sum + (mask * loss(dec_output, y)).sum()
        dec_input = y  # teacher forcing: feed the ground-truth token next
        num_not_pad_tokens += mask.sum().asscalar()
        # The mask is updated *after* this step's loss is counted, so the EOS
        # token's own loss is included while every later position is excluded.
        not_eos = y != eos_idx
        mask = mask * not_eos
        if debug and step == 0:
            print("mask =", mask, "mask.shape=", mask.shape,
                  "(y != out_vocab.token_to_idx[EOS])=", not_eos)
    return loss_sum / num_not_pad_tokens
输出如下（注意：mask 与 (y != out_vocab.token_to_idx[EOS]) 的结果都是 NDArray）：
X=
[[ 9. 30. 27. 45. 4. 3. 1.]
[12. 5. 44. 37. 16. 4. 3.]]
<NDArray 2x7 @cpu(0)> Y=
[[ 7. 5. 31. 11. 18. 4. 3.]
[ 7. 5. 11. 27. 29. 4. 3.]]
<NDArray 2x7 @cpu(0)> Y.T=
[[ 7. 7.]
[ 5. 5.]
[31. 11.]
[11. 27.]
[18. 29.]
[ 4. 4.]
[ 3. 3.]]
<NDArray 7x2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[7. 7.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[8. 9.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[8. 8.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[9. 8.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[7. 7.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[8. 8.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[7. 9.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[7. 9.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[9. 7.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[8. 9.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[8. 7.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[8. 9.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[9. 9.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[7. 8.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[7. 9.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[9. 7.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[7. 8.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[7. 8.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[8. 7.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
调用我的bflag1= 1 type(bflag1)= <class 'int'>
type(mask)= <class 'mxnet.ndarray.ndarray.NDArray'> origin mask=
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,)
y=
[9. 8.]
<NDArray 2 @cpu(0)>
mask =
[1. 1.]
<NDArray 2 @cpu(0)> mask.shape= (2,) (y != out_vocab.token_to_idx[EOS])=
[1. 1.]
<NDArray 2 @cpu(0)>
(y != out_vocab.token_to_idx[EOS])= <class 'mxnet.ndarray.ndarray.NDArray'>