pytorch常用mask命令


前言

在这里插入图片描述
mask是深度学习里面常用的操作,最近在研究transformer的pytorch代码,总能看到各种mask的命令,在这里总结一下

1.Tensor.masked_fill_(mask, value)

Fills elements of self tensor with value where mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor.

Parameters
mask (BoolTensor) – the boolean mask
value (float) – the value to fill in with

举个例子

import torch
mask = torch.tensor([[1, 0, 0], [0, 1, 0],  [0, 0, 1]]).bool()
# tensor([[ True, False, False],
#         [False,  True, False],
#         [False, False,  True]])
a = torch.randn(3,3)
a.masked_fill(mask, 0)
# tensor([[ 0.0000,  0.6781,  0.6532],
#         [-1.2078,  0.0000,  0.4964],
#         [ 0.2192, -0.6276,  0.0000]])
a.masked_fill(~mask, 0)#可以对mask取反
# tensor([[-0.4438,  0.0000,  0.0000],
#         [ 0.0000,  1.3907,  0.0000],
#         [ 0.0000,  0.0000,  2.2462]])

2.torch.masked_select(input, mask, *, out=None) → Tensor

Returns a new 1-D tensor which indexes the input tensor according to the boolean mask mask which is a BoolTensor.
The shapes of the mask tensor and the input tensor don’t need to match, but they must be broadcastable.

(注意)The returned tensor does not use the same storage as the original tensor

Parameters
input (Tensor) – the input tensor.
mask (BoolTensor) – the tensor containing the binary mask to index with

举个例子

import torch
x = torch.randn(3,4)
# tensor([[ 0.2914, -0.1056,  0.4946,  0.2926],
#         [-1.0920, -0.2156,  3.0989, -0.9067],
#         [-0.1522,  1.9527,  0.1660,  0.8310]])
mask = x > 0.5
# tensor([[ 0.2914, -0.1056,  0.4946,  0.2926],
#         [-1.0920, -0.2156,  3.0989, -0.9067],
#         [-0.1522,  1.9527,  0.1660,  0.8310]])
torch.masked_select(x, mask)
# tensor([3.0989, 1.9527, 0.8310])

3.Tensor.masked_scatter_(mask, source)

Tensor.masked_scatter_(mask, source)
Copies elements from source into self tensor at positions where the mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor. The source should have at least as many elements as the number of ones in mask

source大小和mask至少一样,能够被广播到Tensor上,或者source和Tensor一样
作用就是把source里mask是true的位置挑出来给Tensor

Parameters
mask (BoolTensor) – the boolean mask
source (Tensor) – the tensor to copy from

举个例子

import torch
mask = torch.BoolTensor([[1, 0, 0], [0, 1, 0],  [0, 0, 1]])
# tensor([[ True, False, False],
#         [False,  True, False],
#         [False, False,  True]])
a = torch.randn(2,3,3)
s = torch.ones_like(a)
a.masked_scatter(mask, s)
# tensor([[[ 1.0000, -0.1560, -0.7760],
#          [-0.5192,  1.0000, -0.1709],
#          [ 0.2091,  0.5650,  1.0000]],

#         [[ 1.0000,  0.0623, -0.1447],
#          [-1.2910,  1.0000, -1.2722],
#          [-0.7864, -0.1118,  1.0000]]])
  • 3
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch 是一个流行的深度学习框架,下面是一些 PyTorch 常用命令: 1. 张量创建和操作: - `torch.tensor(data)`:根据给定数据创建张量。 - `torch.zeros(shape)`:创建指定形状的全零张量。 - `torch.ones(shape)`:创建指定形状的全一张量。 - `torch.rand(shape)`:创建指定形状的随机张量。 - `torch.Tensor.size()`:获取张量的形状。 - `torch.Tensor.view(shape)`:改变张量的形状。 2. 张量运算: - `torch.add(tensor1, tensor2)`:将两个张量相加。 - `torch.sub(tensor1, tensor2)`:将一个张量减去另一个张量。 - `torch.mul(tensor1, tensor2)`:将两个张量相乘。 - `torch.div(tensor1, tensor2)`:将一个张量除以另一个张量。 - `torch.mm(tensor1, tensor2)`:执行矩阵乘法操作。 3. 自动求导: - `tensor.requires_grad_(True)`:启用张量的自动求导功能。 - `tensor.backward()`:计算张量的梯度。 - `optimizer = torch.optim.SGD(parameters, lr=0.01)`:定义一个优化器,如随机梯度下降(SGD)。 - `optimizer.step()`:执行优化器的一步更新。 4. 模型构建和训练: - 定义模型类和前向传播函数。 - 定义损失函数,如交叉熵损失。 - 定义优化器。 - 在训练循环中执行前向传播、计算损失、反向传播、优化器更新等操作。 这只是一些 PyTorch 常用命令的简单示例,PyTorch 还提供了许多其他功能和命令。你可以查阅官方文档以获取更多详细信息。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值