Pytorch自定义参数

博客介绍了Pytorch自定义参数相关内容。若想灵活使用模型,可能需自定义参数,如定义参数矩阵A,默认情况下模型不包含A,训练时不更新,移到GPU上也不会跟随。要使模型包含参数,需手动注册参数。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Pytorch自定义参数

如果想要灵活地使用模型,可能需要自定义参数,比如

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.A = torch.randn((2,3),requires_grad=True)
        self.B = nn.Linear(2,2)
	def forward(self,x):
     	pass

这里在模型里定义了一个参数矩阵A,但输出模型的参数会发现

>>>net = Net()
>>>for i in net.parameters():
...    print(i)

Parameter containing:
tensor([[-0.6075,  0.5390],
        [ 0.5895, -0.3631]], requires_grad=True)
Parameter containing:
tensor([-0.4341, -0.1234], requires_grad=True)

模型中并没有A,而且模型训练的时候,也不会更新A,将模型移到GPU上时,A也不会跟着走,如果自定义参数,需要手动注册参数

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        A = torch.randn((2,3),requires_grad=True)
        self.A = torch.nn.Parameter(A)
        self.B = nn.Linear(2,2)
        self.register_parameter("Ablah",self.A)
	def forward(self,x):
	     return x

这样就可以使模型包含参数A了

>>>net = Net()
>>>for i in net.parameters():
...    print(i)

Parameter containing:
tensor([[ 0.5211,  0.2569,  1.1290],
        [-0.5820,  0.1013, -1.3352]], requires_grad=True)
Parameter containing:
tensor([[-0.4867,  0.0765],
        [-0.0178,  0.5943]], requires_grad=True)
Parameter containing:
tensor([0.3423, 0.1557], requires_grad=True)
### 创建 PyTorch 自定义类的方法 #### 1. 自定义 Dataset 类 `torch.utils.data.Dataset` 是 PyTorch 中的核心抽象之一,用于封装数据集。要创建自定义数据集,可以通过继承 `Dataset` 并实现其必要的方法来完成。 以下是实现自定义数据集的一个示例: ```python import os from PIL import Image import torch from torch.utils.data import Dataset class CustomImageDataset(Dataset): def __init__(self, root_dir, transform=None): self.root_dir = root_dir self.transform = transform self.image_files = [] # 遍历目录并收集所有图像文件路径 for category in ['cats', 'dogs']: category_path = os.path.join(root_dir, category) for file_name in os.listdir(category_path): if file_name.endswith(('.png', '.jpg', '.jpeg')): self.image_files.append(os.path.join(category_path, file_name)) def __len__(self): return len(self.image_files) def __getitem__(self, idx): img_path = self.image_files[idx] image = Image.open(img_path).convert('RGB') label = 0 if 'cat' in img_path else 1 if self.transform: image = self.transform(image) return image, label ``` 上述代码展示了如何构建一个简单的猫狗分类数据集[^4]。它实现了三个主要部分: - 初始化函数 (`__init__`):负责读取和存储数据。 - 获取长度函数 (`__len__`):返回数据集中样本的数量。 - 获取单个项函数 (`__getitem__`):根据索引获取特定样本及其标签。 --- #### 2. 自定义神经网络层 除了数据集外,还可以通过继承 `torch.nn.Module` 来创建自定义的神经网络层或模型组件。 下面是一个简单的一维卷积操作扩展的例子: ```python import torch import torch.nn as nn import torch.nn.functional as F class CustomConvLayer(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): super(CustomConvLayer, self).__init__() self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, padding=padding) def forward(self, x): x = self.conv(x) x = F.relu(x) return x ``` 此代码片段展示了一个带有 ReLU 激活函数的自定义一维卷积层[^2]。用户可以根据需求进一步修改该结构以适应更复杂的场景。 --- #### 3. 自定义 C++ 扩展算子 对于性能敏感的应用程序,可能需要编写高效的底层运算符。PyTorch 支持使用 C++ 和 CUDA 编写自定义算子,并将其集成到 Python 环境中。 以下是如何注册和调用自定义算子的简化流程[^3]: ##### (a) 定义 C++ 运算符 假设我们有一个名为 `my_custom_op.cpp` 的文件,其中包含如下内容: ```cpp #include <torch/extension.h> at::Tensor my_custom_add(at::Tensor a, at::Tensor b) { return a + b; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("add", &my_custom_add, "Add two tensors"); } ``` ##### (b) 构建扩展模块 运行以下命令编译 C++ 文件为共享库: ```bash python setup.py install ``` ##### (c) 调用自定义算子 在 Python 中导入并测试新算子: ```python import torch import custom_extension a = torch.tensor([1., 2., 3.]) b = torch.tensor([4., 5., 6.]) result = custom_extension.add(a, b) print(result) # 输出: tensor([5., 7., 9.]) ``` 这种技术允许开发者充分利用 GPU 加速功能优化计算密集型任务。 --- #### 4. 导出自定义模型至 ONNX 格式 如果希望将含自定义算子的模型导出为 ONNX,则需指定额外参数以便正确解析这些特殊节点。 示例如下: ```python import torch model_test = ... # 替换为实际模型对象 ipt = torch.randn(1, 3, 224, 224) # 输入张量形状匹配模型预期输入 model_name = "custom_model.onnx" torch.onnx.export( model_test, ipt, 'results/onnx/' + model_name, input_names=['input'], output_names=['output'], custom_opsets={"custom": 1} # 添加自定义命名空间及版本号 ) ``` 此处的关键在于传递字典形式的 `custom_opsets` 参数给 `export()` 方法。 --- ### 总结 无论是管理复杂的数据源还是设计新颖的架构单元,在 PyTorch 生态系统内均提供了灵活多样的工具链支持开发人员探索创新解决方案。从基础层面理解各个组成部分的工作原理有助于加速项目进展并提升最终成果质量。
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值