torch 对参数处理后回填

文章介绍了如何在PyTorch中处理深度学习模型的参数,首先下载参数并将其整合为向量,然后在每个元素上添加噪声,接着还原噪声向量到模型参数并进行校验。使用`torch.randn`生成噪声并利用`torch.testing.assert_allclose`确保添加噪声后的参数与原始参数在指定误差范围内一致。
摘要由CSDN通过智能技术生成

先从网络中下载参数,将参数所有元素整合在一个向量中,对每个元素添加一定的噪声,接下来还原向量中的元素到参数的对应位置,并用加噪后的参数回填到模型中,最后检查一遍回填情况。

假设在函数里面构造模型

import torch
import torch.nn as nn


class MyModel(nn.Module):
   def __init__(self):
       super(MyModel, self).__init__()
       self.fc1 = nn.Linear(256, 128)
       self.relu = nn.ReLU()
       self.fc2 = nn.Linear(128, 10)

   def forward(self, x):
       x = self.fc1(x)
       x = self.relu(x)
       x = self.fc2(x)
       return x

model = MyModel()

# 获取深度网络的所有参数
parameters = list(model.parameters())

# 将参数转换为一个向量
param_vector = torch.cat([p.view(-1) for p in parameters], dim=0)
print(param_vector[:20])



# 在向量的每个元素上随机添加噪声
noise = torch.randn_like(param_vector)
noisy_param_vector = param_vector + noise

# print(noise.shape, noisy_param_vector.shape)
print(noisy_param_vector[:20])



# 将添加噪声的向量还原到对应位置的深度网络的参数中
start_idx = 0
for i, param in enumerate(parameters):
    # 计算当前参数的元素个数
    num_elements = param.numel()
    # 从噪声参数中截取当前参数的元素
    current_noise = noisy_param_vector[start_idx:start_idx + num_elements]
    # 重新调整形状为当前参数的形状
    current_noise = current_noise.view(param.shape)
    # 将还原后的参数加载到深度网络中
    param.data.copy_(current_noise)

    # 更新起始索引
    start_idx += num_elements


parametersa = list(model.parameters())

# 将参数转换为一个向量
param_vectora = torch.cat([p.view(-1) for p in parametersa], dim=0)
print(param_vectora[:20])

假设已有模型

# 下载和上传参数

def renew_model_param(DeepNet):


    # 获取深度网络的所有参数
    parameters = list(DeepNet.parameters())

    # 将参数转换为一个向量
    param_vector = torch.cat([p.view(-1) for p in parameters], dim=0)
    # print(param_vector[:20])


    # 在向量的每个元素上随机添加噪声
    noise = torch.randn_like(param_vector)
    noisy_param_vector = param_vector + noise

    # print(noise.shape, noisy_param_vector.shape)
    # print(noisy_param_vector[:20])



    # 将添加噪声的向量还原到对应位置的深度网络的参数中
    start_idx = 0
    for i, param in enumerate(parameters):
        # 计算当前参数的元素个数
        num_elements = param.numel()
        # 从噪声参数中截取当前参数的元素
        current_noise = noisy_param_vector[start_idx:start_idx + num_elements]
        # 重新调整形状为当前参数的形状
        current_noise = current_noise.view(param.shape)
        # 将还原后的参数加载到深度网络中
        param.data.copy_(current_noise)

        # 更新起始索引
        start_idx += num_elements


    parametersa = list(DeepNet.parameters())

    # 将参数转换为一个向量
    param_vectora = torch.cat([p.view(-1) for p in parametersa], dim=0)
    # print(param_vectora[:20])

    # 断言两个张量的值在一定误差范围内相等
    torch.testing.assert_allclose(noisy_param_vector, param_vectora, rtol=1e-3, atol=1e-3)

输出

tensor([ 0.0411,  0.0229,  0.0049, -0.0484,  0.0459,  0.0118, -0.0562, -0.0538,
         0.0094, -0.0068,  0.0483, -0.0045, -0.0097,  0.0123,  0.0309,  0.0003,
        -0.0542, -0.0463, -0.0452, -0.0189], grad_fn=<SliceBackward0>)
tensor([-0.4091, -1.7616,  0.1588, -0.3146, -1.9634, -0.8395,  0.0965, -1.9828,
        -0.9309, -0.9996, -1.1171,  0.0702,  0.5957, -0.2442,  0.8054,  0.3737,
        -0.6276,  1.9177, -0.2753,  0.9166], grad_fn=<SliceBackward0>)
tensor([-0.4091, -1.7616,  0.1588, -0.3146, -1.9634, -0.8395,  0.0965, -1.9828,
        -0.9309, -0.9996, -1.1171,  0.0702,  0.5957, -0.2442,  0.8054,  0.3737,
        -0.6276,  1.9177, -0.2753,  0.9166], grad_fn=<SliceBackward0>)

在 PyTorch 中,你可以使用 torch.testing.assert_allclose() 函数来断言两个张量的值是否在一定的误差范围内相等。这对于由于浮点数精度等原因可能出现的微小差异非常有用。以下是使用这个函数的示例:

import torch

# 创建两个张量
tensor1 = torch.tensor([1.0, 2.0, 3.0])
tensor2 = torch.tensor([1.001, 2.001, 2.999])

# 断言两个张量的值在一定误差范围内相等
torch.testing.assert_allclose(tensor1, tensor2, rtol=1e-3, atol=1e-3)

在上面的示例中,rtol 参数表示相对误差容忍度,atol 参数表示绝对误差容忍度。你可以根据需要调整这些参数的值,以便于在实际情况下检查两个张量是否相等。如果两个张量的值在指定的误差范围内相等,那么程序会继续执行,否则会引发一个断言错误。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

高山莫衣

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值