torch中的替换操作

目录

1.1通过比较操作得到布尔矩阵

1.2布尔矩阵作为索引

1.3布尔矩阵的强转


 1.通过比较替换

1.1通过比较操作得到布尔矩阵

a = torch.rand((5, 6), dtype=torch.float32)
print(a)
print(a > 0.5)
-----------------------------------------------------------------------------------------
tensor([[0.7172, 0.0393, 0.0810, 0.7734, 0.0044, 0.3191],
        [0.3387, 0.2647, 0.3805, 0.0339, 0.5828, 0.2851],
        [0.9486, 0.1131, 0.4608, 0.9621, 0.2813, 0.1919],
        [0.8794, 0.8339, 0.7018, 0.3440, 0.8190, 0.2513],
        [0.4571, 0.0413, 0.0869, 0.0827, 0.2947, 0.1362]])
tensor([[ True, False, False,  True, False, False],
        [False, False, False, False,  True, False],
        [ True, False, False,  True, False, False],
        [ True,  True,  True, False,  True, False],
        [False, False, False, False, False, False]])

1.2布尔矩阵作为索引

当布尔矩阵出现在下标位置时,充当索引的角色,可以进行赋值、数学运算等操作。

a[a > 0.5] = 1

对应于True位置的数值替换为1.0 

a = torch.rand((5, 6), dtype=torch.float32)
index = a > 0.5
print(a)
print(a > 0.5)
a[a > 0.5] = 1
print(a)

-------------------------------------------------------------------------------------------
tensor([[0.7869, 0.9450, 0.6493, 0.3837, 0.0440, 0.7851],
        [0.0643, 0.2934, 0.4264, 0.5007, 0.4667, 0.8721],
        [0.7292, 0.8509, 0.0858, 0.7933, 0.3860, 0.4047],
        [0.6582, 0.5485, 0.4721, 0.8394, 0.6033, 0.9711],
        [0.4886, 0.2380, 0.7981, 0.9216, 0.4049, 0.7381]])
tensor([[ True,  True,  True, False, False,  True],
        [False, False, False,  True, False,  True],
        [ True,  True, False,  True, False, False],
        [ True,  True, False,  True,  True,  True],
        [False, False,  True,  True, False,  True]])
tensor([[1.0000, 1.0000, 1.0000, 0.3837, 0.0440, 1.0000],
        [0.0643, 0.2934, 0.4264, 1.0000, 0.4667, 1.0000],
        [1.0000, 1.0000, 0.0858, 1.0000, 0.3860, 0.4047],
        [1.0000, 1.0000, 0.4721, 1.0000, 1.0000, 1.0000],
        [0.4886, 0.2380, 1.0000, 1.0000, 0.4049, 1.0000]])

1.3布尔矩阵的强转

(decoder_out > 0.05).type(torch.int16)

通过强转也可以将布尔矩阵转为数值矩阵,然后在做数值运算。

a = torch.rand((5, 6), dtype=torch.float32)
print(a)
print((a > 0.5).type(torch.int16))

--------------------------------------------------------------------------------------
tensor([[0.6869, 0.0706, 0.1450, 0.2567, 0.9260, 0.2848],
        [0.8413, 0.1677, 0.2624, 0.7488, 0.5229, 0.0857],
        [0.9411, 0.9034, 0.0416, 0.4502, 0.9404, 0.3534],
        [0.9312, 0.5303, 0.3516, 0.2819, 0.3869, 0.4441],
        [0.0093, 0.5280, 0.7881, 0.1460, 0.9870, 0.1433]])
tensor([[1, 0, 0, 0, 1, 0],
        [1, 0, 0, 1, 1, 0],
        [1, 1, 0, 0, 1, 0],
        [1, 1, 0, 0, 0, 0],
        [0, 1, 1, 0, 1, 0]], dtype=torch.int16)

2.通过块索引替换

类似于python的列表索引切片,tensor也可以这样做。

a = torch.rand((5, 6), dtype=torch.float32)
print(a)
a[0:2, 0:2] = 1
print(a)

-------------------------------------------------------------------------------------------
tensor([[0.2457, 0.9803, 0.6518, 0.7468, 0.8744, 0.9395],
        [0.5727, 0.4081, 0.2449, 0.3435, 0.5808, 0.9078],
        [0.6032, 0.8980, 0.8591, 0.9664, 0.1635, 0.8161],
        [0.0347, 0.3219, 0.0546, 0.2887, 0.6355, 0.8978],
        [0.0848, 0.8553, 0.6150, 0.6221, 0.9916, 0.0254]])
tensor([[1.0000, 1.0000, 0.6518, 0.7468, 0.8744, 0.9395],
        [1.0000, 1.0000, 0.2449, 0.3435, 0.5808, 0.9078],
        [0.6032, 0.8980, 0.8591, 0.9664, 0.1635, 0.8161],
        [0.0347, 0.3219, 0.0546, 0.2887, 0.6355, 0.8978],
        [0.0848, 0.8553, 0.6150, 0.6221, 0.9916, 0.0254]])

3.clip操作

pytorch官方的 api 输入一个tensor input,设定最小值min 最大值max,超过最大值的数值会被设置为最大值,小于最小值的数值被设置为最小值。

torch.clip(input, min=None, max=None, *, out=None) → Tensor

相同的还有torch.clamp API,作用与clip相同。

torch.clamp(inputmin=Nonemax=None*out=None) → Tensor 

a = torch.randn(4)
print(a)
torch.clamp(a, min=-0.5, max=0.5)
print(a)

-------------------------------------------------------------------------------------------

tensor([-1.7120,  0.1734, -0.0478, -0.0922])

tensor([-0.5000,  0.1734, -0.0478, -0.0922])

 

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: 我可以提供一些建议,但无法保证它可以在没有改变原模型的情况下替换 torch.gather 操作。首先,可以尝试使用 torch.index_select 来实现同样的功能。其次,可以考虑使用 torch.masked_select 来替换 torch.gather,具体可以参考 PyTorch 的文档。 ### 回答2: 在PyTorch,如果想要在不改变原模型的情况下替换forward函数torch.gather操作,可以使用torch.index_select函数来实现相同的功能。torch.index_select函数接受一个tensor和一个维度索引作为参数,返回按照指定维度索引的元素。 首先,我们需要理解torch.gather操作的作用。torch.gather可以按照指定的维度,在一个tensor进行索引,并返回相应的值。例如,对于一个大小为(3, 4)的tensor,我们可以通过torch.gather(tensor, 0, index)来按照第0个维度的索引index来获取对应值。 下面是一个示例代码,展示如何使用torch.index_select替换forward函数torch.gather操作: ```python import torch import torch.nn as nn class MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() self.weights = nn.Parameter(torch.randn(3, 4)) def forward(self, index): # 使用torch.gather操作 output = torch.gather(self.weights, 0, index) return output def replace_forward(self, index): # 使用torch.index_select替换torch.gather操作 output = torch.index_select(self.weights, 0, index) return output ``` 在上面的示例代码,MyModel类的forward函数使用了torch.gather操作,而replace_forward函数则使用了torch.index_select来实现相同的功能。这样,我们可以在不改变原模型的情况下替换forward函数torch.gather操作。 ### 回答3: 在不改变原模型的情况下,我们可以通过使用其他的操作替换`torch.gather`。 `torch.gather`操作通常用于根据索引从输入的张量提取特定元素。它的一般形式是`torch.gather(input, dim, index, out=None)`,其`input`是输入张量,`dim`是提取索引的维度,`index`是包含提取索引的张量。 为了替换`torch.gather`操作,我们可以使用`torch.index_select`和`torch.unsqueeze`来实现相似的功能。 首先,我们可以使用`torch.index_select`操作来选择指定维度上的索引。这个操作的一般形式是`torch.index_select(input, dim, index, out=None)`,其`input`是要选择的张量,`dim`是选择的维度,`index`是包含索引的一维张量。 然后,我们可以使用`torch.unsqueeze`操作来在选择的维度上增加一个维度。这个操作的一般形式是`torch.unsqueeze(input, dim, out=None)`,其`input`是要增加维度的张量,`dim`是要增加的维度。 综上所述,为了替换`torch.gather`操作,我们可以使用以下代码: ```python import torch class Model(torch.nn.Module): def __init__(self): super(Model, self).__init__() def forward(self, input, index): # 替换 torch.gather 的操作 output = torch.index_select(input, 1, index.unsqueeze(1)).squeeze(1) return output ``` 在上面的代码,我们使用`torch.index_select`选择了指定维度`dim=1`上的索引,并使用`torch.unsqueeze`增加了一个维度。最后,我们使用`squeeze`操作将这个额外的维度去除,以匹配`torch.gather`操作的输出。 这样,我们就在不改变原模型的情况下替换了`torch.gather`操作,实现了相似的功能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值