PyTorch equivalents of TF 1.x functions (Python 3.6+)
- tf.nn.l2_loss(variable) <-> torch.norm(tensor) ** 2 / 2
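import tensorflow as tf
import torch
a = tf.Variable([[3.0, 4.0], [3.0, 4.0]])
a_l2_loss = tf.nn.l2_loss(a)  # sum(a ** 2) / 2
a_tor = torch.tensor([[3.0, 4.0], [3.0, 4.0]], dtype=torch.float32)
a_torch = torch.norm(a_tor) ** 2 / 2  # squared Frobenius norm, halved
# In TF 1.x graph mode this prints a symbolic Tensor; evaluate it in a
# tf.Session (after initializing variables) to see the value.
print(a_l2_loss)
print(a_torch)  # approximately 25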
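As a quick PyTorch-only check of the formula above, here is a minimal sketch. The helper name `l2_loss` is my own, not a PyTorch API; it assumes the documented definition of tf.nn.l2_loss, sum(t ** 2) / 2, and compares it with the norm-based expression.

```python
import torch


def l2_loss(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: half the sum of squared elements (no sqrt)."""
    return t.pow(2).sum() / 2


a_tor = torch.tensor([[3.0, 4.0], [3.0, 4.0]], dtype=torch.float32)

via_sum = l2_loss(a_tor)               # (9 + 16 + 9 + 16) / 2
via_norm = torch.norm(a_tor) ** 2 / 2  # same quantity, up to rounding

print(via_sum.item())                  # 25.0
print(torch.isclose(via_sum, via_norm).item())
```

The sum-of-squares form skips the square root and the re-squaring, so it avoids a small rounding error and a badly behaved gradient when the tensor is all zeros.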
- tf.nn.embedding_lookup <-> torch.index_select
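a = tf.Variable([[3.0, 4.0], [5.0, 6.0]])
a_tor = torch.tensor([[3.0, 4.0], [5.0, 6.0]], dtype=torch.float32)
index_tf = tf.constant([1])  # indices for the TF side
index = torch.tensor([1])  # index_select needs a 1-D integer tensor
print(a)
print(a_tor)
a_select = tf.nn.embedding_lookup(a, index_tf)  # row 1 of a
b_select = torch.index_select(a_tor, 0, index)  # row 1 of a_tor
print(a_select)
print(b_select)  # tensor([[5., 6.]])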
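One caveat with the mapping above: torch.index_select only accepts a 1-D index, while tf.nn.embedding_lookup takes ids of any shape. The sketch below shows two alternatives that seem to cover the multi-dimensional case, plain tensor indexing and torch.nn.functional.embedding. The variable names (`table`, `ids_1d`, `ids_2d`) are illustrative only.

```python
import torch
import torch.nn.functional as F

table = torch.tensor([[3.0, 4.0], [5.0, 6.0]])
ids_1d = torch.tensor([1])
ids_2d = torch.tensor([[0, 1], [1, 1]])  # a batch of id sequences

# 1-D ids: all three forms return the same rows.
row_select = torch.index_select(table, 0, ids_1d)
row_index = table[ids_1d]
row_embed = F.embedding(ids_1d, table)

# 2-D ids: index_select would reject these, the other two still work.
batch_index = table[ids_2d]
batch_embed = F.embedding(ids_2d, table)

print(row_select)         # tensor([[5., 6.]])
print(batch_embed.shape)  # torch.Size([2, 2, 2])
```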
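- Big pitfall, I spent 4 hours hunting a bug because of this, o(╥﹏╥)o: tf.nn.dropout takes the probability of keeping an element, while torch.nn.Dropout takes the probability of dropping it.
prob = 1 - keep_prob
tf.nn.dropout(tensor, keep_prob) <-> torch.nn.Dropout(prob)(tensor)
= torch.nn.functional.dropout(tensor, prob)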
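A minimal sketch of the dropout conversion, assuming a TF 1.x keep_prob of 0.8 (an arbitrary value chosen for illustration). PyTorch expects the drop probability, so the value has to be inverted. Note also that torch.nn.functional.dropout defaults to training=True, whereas the module form is turned off by .eval().

```python
import torch
import torch.nn.functional as F

keep_prob = 0.8            # what TF 1.x tf.nn.dropout would receive
drop_prob = 1 - keep_prob  # what PyTorch expects

torch.manual_seed(0)
x = torch.ones(10000)

layer = torch.nn.Dropout(p=drop_prob)  # modules start in training mode
y_module = layer(x)
y_func = F.dropout(x, p=drop_prob, training=True)

# Fraction of surviving elements should be close to keep_prob.
kept_fraction = (y_module != 0).float().mean().item()
print(kept_fraction)

# Survivors are scaled by 1 / (1 - drop_prob) = 1 / keep_prob.
survivors = y_module[y_module != 0]
print(survivors[0].item())

# In eval mode the module is the identity.
layer.eval()
y_eval = layer(x)
print(torch.equal(y_eval, x))  # True
```

Passing keep_prob straight into torch.nn.Dropout would silently drop 80% of the activations instead of 20%. Nothing crashes, which is probably why this kind of bug takes so long to find.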