二维张量dim=1

在PyTorch中处理二维张量(矩阵)时,‌dim=1表示沿着列的横向维度操作‌,也就是逐行处理每一行的所有列元素。以下是详细解释:

直观理解:

想象一个Excel表格:

  • 行(dim=0)‌:垂直方向(从顶部到底部)
  • (dim=1)‌:水平方向(从左到右)

具体操作示例:

import torch

# 创建一个2x3矩阵(2行,3列)
data = torch.tensor([
    [1, 2, 3],  # 第0行
    [4, 5, 6]   # 第1行
])

# dim=1 操作(跨列处理)
max_indices = torch.argmax(data, dim=1)  # 每行找最大值的位置
row_sums = torch.sum(data, dim=1)        # 每行求和
row_means = torch.mean(data, dim=1)      # 每行平均值

print("原始矩阵:\n", data)
print("每行最大值位置(dim=1):", max_indices)  # 输出: tensor([2, 2])
print("每行求和(dim=1):", row_sums)        # 输出: tensor([6, 15])
print("每行平均值(dim=1):", row_means)      # 输出: tensor([2., 5.])

关键特性:

  1. 操作方向‌:

    • dim=1:横向操作 ➡️ 处理每行的所有列元素
    • 相当于:‌"对每一行进行操作"
  2. 维度变化‌:

    # 原始shape: (2, 3)
    # dim=1操作后: 
    #   (2, 3) --[sum(dim=1)]--> (2,) 
    #   行维度保留,列维度被压缩
    
  3. 常见场景‌:

    • 神经网络输出处理(如代码示例中的Q值矩阵)
    • 批量数据中每个样本的特征统计
    • 分类任务中每张图像的预测分数处理

示例代码:

q_values = torch.tensor([[0.1, 0.8, 0.1]])  # 1行3列(1个状态,3个动作的Q值)

torch.argmax(q_values, dim=1) 
# 沿着列方向(dim=1)找最大值位置
# 结果: tensor([1]) → 选择第2个动作(索引从0开始)

记忆技巧:

dim=1 = ‌"行操作"‌ = 把每行看作一个整体处理
(就像Excel中:每行是一个记录,每列是不同特征)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值