PyTorch 基础学习(3) - 张量的数学操作

下面是关于PyTorch中常见数学操作的概述和教程,包括逐点运算、比较操作、线性代数操作等,突出每个操作的重点用法和示例。

逐点操作 (Pointwise Operations)

1. torch.abs
  • 功能: 计算输入张量的每个元素的绝对值。
  • 用法: torch.abs(input)
  • 示例:
    import torch
    tensor = torch.FloatTensor([-1, -2, 3])
    result = torch.abs(tensor)
    print(result)  # 输出: tensor([1., 2., 3.])
    
2. torch.acos
  • 功能: 返回输入张量每个元素的反余弦。
  • 用法: torch.acos(input)
  • 示例:
    a = torch.tensor([0.5, -1.0, 1.0])
    result = torch.acos(a)
    print(result)  # 输出: tensor([1.0472, 3.1416, 0.0000])
    
3. torch.add
  • 功能: 将一个标量值逐元素加到输入张量上。
  • 用法: torch.add(input, value)
  • 示例:
    a = torch.randn(4)
    result = torch.add(a, 10)
    print(result)
    
4. torch.mul
  • 功能: 用标量值乘以输入张量的每个元素。
  • 用法: torch.mul(input, value)
  • 示例:
    a = torch.randn(3)
    result = torch.mul(a, 100)
    print(result)
    

线性代数操作 (Linear Algebra Operations)

1. torch.mm
  • 功能: 矩阵乘法。
  • 用法: torch.mm(mat1, mat2)
  • 示例:
    mat1 = torch.randn(2, 3)
    mat2 = torch.randn(3, 3)
    result = torch.mm(mat1, mat2)
    print(result)
    
2. torch.inverse
  • 功能: 计算方阵的逆矩阵。
  • 用法: torch.inverse(input)
  • 示例:
    x = torch.rand(3, 3)
    result = torch.inverse(x)
    print(result)
    
3. torch.svd
  • 功能: 奇异值分解。
  • 用法: torch.svd(input)
  • 示例:
    a = torch.randn(5, 3)
    u, s, v = torch.svd(a)
    print(u, s, v)
    

比较操作 (Comparison Operations)

1. torch.eq
  • 功能: 比较两个张量的元素是否相等。
  • 用法: torch.eq(input, other)
  • 示例:
    a = torch.tensor([1, 2, 3])
    b = torch.tensor([1, 1, 4])
    result = torch.eq(a, b)
    print(result)  # 输出: tensor([ True, False, False])
    
2. torch.gt
  • 功能: 比较两个张量的元素是否大于另一个。
  • 用法: torch.gt(input, other)
  • 示例:
    a = torch.tensor([1, 2, 3])
    b = torch.tensor([1, 1, 4])
    result = torch.gt(a, b)
    print(result)  # 输出: tensor([False,  True, False])
    
3. torch.max
  • 功能: 返回输入张量所有元素的最大值。
  • 用法: torch.max(input)
  • 示例:
    a = torch.randn(1, 3)
    result = torch.max(a)
    print(result)
    

下面是一个包含逐点操作、线性代数操作和比较操作的PyTorch综合应用实例。该示例展示了如何处理一个简单的线性回归任务。

应用实例:线性回归

在这个例子中,我们将使用PyTorch来进行一个简单的线性回归任务。我们将生成一些合成数据,使用线性回归模型进行拟合,并展示模型参数的更新过程。

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# 生成合成数据
torch.manual_seed(0)  # 固定随机种子
X = torch.linspace(0, 10, 100).view(-1, 1)  # 输入特征
y = 2 * X + 1 + torch.randn(X.size()) * 2  # 线性关系加噪声

# 定义线性回归模型
class LinearRegressionModel(nn.Module):
    def __init__(self):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(1, 1)  # 输入和输出都是1维

    def forward(self, x):
        return self.linear(x)

# 初始化模型、损失函数和优化器
model = LinearRegressionModel()
criterion = nn.MSELoss()  # 均方误差损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器

# 训练模型
num_epochs = 100
for epoch in range(num_epochs):
    # 前向传播
    outputs = model(X)
    loss = criterion(outputs, y)

    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 每10个epoch输出一次损失
    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# 获取训练后的参数
with torch.no_grad():
    slope = model.linear.weight.item()
    intercept = model.linear.bias.item()
    print(f'Learned parameters: Slope={slope:.4f}, Intercept={intercept:.4f}')

# 可视化结果
predicted = model(X).detach().numpy()
plt.plot(X.numpy(), y.numpy(), 'ro', label='Original data')
plt.plot(X.numpy(), predicted, label='Fitted line')
plt.legend()
plt.show()

输出:
在这里插入图片描述

说明

  1. 数据生成: 我们生成了一组带有噪声的线性数据,线性关系为 (y = 2x + 1)。

  2. 模型定义: 使用nn.Linear定义了一个简单的线性回归模型。

  3. 损失函数和优化器: 使用均方误差作为损失函数,随机梯度下降作为优化器。

  4. 训练过程: 在训练过程中,我们执行前向传播计算损失,然后通过反向传播更新模型参数。

  5. 结果展示: 训练完成后,我们打印出学习到的参数,并使用matplotlib绘制原始数据和拟合直线。

总结

  • 逐点操作用于对张量中的每个元素进行相同的数学运算。
  • 线性代数操作涉及矩阵运算,如矩阵乘法、逆矩阵和分解。
  • 比较操作用于比较张量元素之间的关系。
  • PyTorch中的这些操作通常具有广播机制,允许对不同形状的张量进行运算,只要它们的形状是兼容的。
  • 11
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值