Pytorch基础语法

这篇博客详细记录了PyTorch使用中常见的问题,包括数据类型问题、NaN数据处理、梯度影响因素、L2正则化、模型参数查看与初始化、View与Reshape函数的区别、matmul函数的矩阵运算规则以及模型参数的更改方法。通过实例解析,帮助读者理解和解决PyTorch编程中的常见挑战。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

这里主要记录平时使用pytorch遇到的一些问题,目录有些混乱,日后再来整理

Next: pytorch的数据类型问题

问题描述
在用pytorch训练网络的时候,有时候一不注意就可能产生数据类型报错的问题,最经常遇到的是形如Expected object of scalar type Double but got scalar type Float for argument这样的报错

问题解决
上述问题的主要原因是因为,数据类型没有得到统一,因此pytorch不能将两种不同数据类型的torch.Tensor进行运算操作,因此解决方法是:在创建数据时,将数据的类型定义为同一数据类型

解决方法如下

  • 方法1
# from numpy
x = np.random.randn((4,4))

# numpy to torch.Tensor
# with dtype torch.float32
x = torch.tensor(x, dtype=torch.float32)
  • 方法2:
    在模型训练时,会输入input,调用torch.tensor对象的float(), long()方法则依次将数据转换成相应的结构
model.train()
model.forward(x.float())  # float for x
loss = torch.nn.CrossEntropyLoss(x ,y.long())  # long for label

问题解决

更多的可以看这篇文章

Next:关于nan数据的问题

问题描述
在用LSTM训练股票数据时,有的股票在设置的开始日期时就会有停盘,这个时候该股票是不含有价格信息的,因此值为nan,如果不处理这个nan数据,直接放到LSTM中进行训练,由于LSTM cell是沿着时间滑动的,因此某一天的nan值必然会影响到后续的数据,导致最终输出的output, h, c全部都是nan,代码如下:

a = np.array([[np.nan, 1], [4, 2]])
a = np.expand_dims(a, axis=1)
a = torch.tensor(a, dtype=torch.float32)
print(a.size())
print('\n',a)

# out
# (seq_lenth, batch, features),可以看到在第一天出现了nan值
torch.Size([2, 1, 2])

tensor([[[nan, 1.]],
       [[4., 2.]]])

输入LSTM,最终会导致输出全部都是nan

lstm = torch.nn.LSTM(2, 4)
out, (h,c) = lstm.forward(a)
h

# out
tensor([[[nan, nan, nan, nan]]], grad_fn=<StackBackward>)

但是注意:nan只会出现在拥有nan的那一个batch,其余的batch不会受到影响,例如以下:

a = np.array([[np.nan, 1], [4, 2]])
b = np.array([[2, 3], [4, 5]])
x = np.dstack((a,b))
print(x.shape)
x = np.transpose(x, axes=(0,2,1))
x = torch.tensor(x, dtype=torch.float32)
print(x.size())
print('\n',x)

lstm = torch.nn.LSTM(2, 4)
out, (h,c) = lstm.forward(x)
print(h)

# out
(2, 2, 2)
torch.Size([2, 2, 2])

tensor([[[nan, 1.],
         [2., 3.]],
        [[4., 2.],
         [4., 5.]]])

# 可以看到,只是在第一条batch全是nan
tensor([[[    nan,     nan,     nan,     nan],
         [ 0.0822, -0.1743, -0.0854, -0.0826]]], grad_fn=<StackBackward>)

问题解决
将nan全部填充为0,然后创建mask标签,这样就不会产生梯度了,也不会将nan计算到损失函数中


Next:哪些操作会影响梯度?

问题描述
在尝试Pytorch的自动求导功能时,想知道究竟哪些操作会影响梯度,哪些不会影响,也算做一个测试吧。

问题解决
结论:只有数值运算操作会影响梯度,其余的拼接、复制、维度变化都不会影响pytorch对梯度的求取

  • 拼接concat操作不会影响
a = torch
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值