Training your own network with SpikingJelly --- quantization --- testing

[screenshots]

The second one =================

[screenshots]
However, I found that everything has to be dequantized in the end, because PyTorch only supports floating-point computation.

[screenshot]

https://blog.csdn.net/lai_cheng/article/details/118961420
PyTorch quantization roughly falls into three categories: post-training dynamic quantization, post-training static quantization, and quantization enabled during training (quantization-aware training). That article walks through post-training static quantization using an engineering project (pose estimation).
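
As a point of reference (this snippet is mine, not from the linked article), post-training dynamic quantization is the quickest of the three to try on an ordinary float model; a minimal sketch on a toy MLP:

import torch
import torch.nn as nn

# A toy float model. Dynamic quantization swaps its Linear layers for int8
# versions whose activations are quantized on the fly at inference time.
float_model = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))
float_model.eval()

quantized_model = torch.quantization.quantize_dynamic(
    float_model, {nn.Linear}, dtype=torch.qint8
)

x = torch.randn(1, 28 * 28)
print(quantized_model(x).shape)  # torch.Size([1, 10])

Note that this path only covers standard modules such as nn.Linear and nn.LSTM; SpikingJelly's neuron layers are not handled by it, and the rest of this post uses a manual fixed-point scheme instead.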

I asked again:
I want to modify this for 16-bit quantization. How should I change it?

class SNN(nn.Module):
    def __init__(self, tau):
        super().__init__()

        self.layer = nn.Sequential(
            layer.Flatten(),
            layer.Linear(28 * 28, 10, bias=False),
            neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan()),
        )

    def forward(self, x: torch.Tensor):
        return self.layer(x)

[screenshots]
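
The answers I got are in the screenshots above. Purely as a reference, here is a minimal sketch of one way to simulate 16-bit (int16) fixed-point weights while keeping float tensors that PyTorch can run; the helper fake_quantize_int16 and the scale of 10000 are my own assumptions, not code from those answers.

import torch
import torch.nn as nn
from main import SNN  # the same SNN class as above

def fake_quantize_int16(net: nn.Module, scale: float = 10000.0) -> nn.Module:
    # Round each weight to the nearest 1/scale step, clamp it to the int16 range,
    # then divide the scale back out so inference still runs in floating point.
    with torch.no_grad():
        for p in net.parameters():
            q = torch.clamp(torch.round(p * scale), -32768, 32767)
            p.copy_(q / scale)
    return net

net = fake_quantize_int16(SNN(tau=2.0), scale=10000.0)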

=================
Testing [my question: I put the model-testing part into a separate script; how should I write it?]
[screenshot]

import torch
import torch.nn.functional as F
import torchvision
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import time
from main import SNN  # make sure the SNN class is imported correctly from your main.py (or wherever it is defined)
from torch.utils.tensorboard import SummaryWriter

from spikingjelly.activation_based import neuron, encoding, functional, surrogate, layer

#python -m main -tau 2.0 -T 50 -device cuda:0 -b 64 -epochs 3 -data-dir \mnist -opt adam -lr 1e-3 -j 2

def test_model(model_path, data_dir, device='cuda:0', T=50,epoch_test = 3):
    start_epoch = 0
    out_dir = '.\\out_dir'
    writer = SummaryWriter(out_dir, purge_step=start_epoch)

    # Load the model
    net = SNN(tau=2.0)  # initialize your model with the appropriate parameters
    checkpoint = torch.load(model_path, map_location=device)
    net.load_state_dict(checkpoint['net'])
    net.to(device)
    net.eval()

    # Load the test dataset
    test_dataset = torchvision.datasets.MNIST(
        root=data_dir,
        train=False,
        transform=ToTensor(),
        download=True
    )
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

    for epoch in range(start_epoch, epoch_test):  
        # Initialize the performance metrics
        test_loss = 0
        test_acc = 0
        test_samples = 0
        start_time = time.time()

        encoder = encoding.PoissonEncoder()

        with torch.no_grad():
            for img, label in test_loader:
                img = img.to(device)
                label = label.to(device)
                label_onehot = F.one_hot(label, 10).float()
                out_fr = 0.
                for t in range(T):
                    encoded_img = encoder(img)  # Poisson-encode the input (encoder is defined above)
                    out_fr += net(encoded_img)
                out_fr = out_fr / T
                loss = F.mse_loss(out_fr, label_onehot)

                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (out_fr.argmax(1) == label).float().sum().item()
                # SpikingJelly's neurons are stateful; reset the network after each batch
                # so membrane potentials do not carry over to the next batch.
                functional.reset_net(net)

        test_time = time.time() - start_time
        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}')
        print(f'Test completed in {test_time:.2f} seconds.')
        print(f'test_samples={test_samples}')

if __name__ == '__main__':
    model_path = 'logs\\T50_b64_adam_lr0.001\\checkpoint_max.pth'  # path to the trained checkpoint
    data_dir = 'data'  # dataset root directory
    test_model(model_path, data_dir)
Results with T = 50:

Test Loss: 0.0167, Test Accuracy: 0.9198
Test completed in 5.56 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9186
Test completed in 4.79 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9185
Test completed in 4.77 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9194
Test completed in 4.79 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9188
Test completed in 4.72 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9192
Test completed in 4.74 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9193
Test completed in 4.74 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9189
Test completed in 4.74 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9188
Test completed in 4.76 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9192
Test completed in 4.74 seconds.
test_samples=10000

With T = 5 the results are as follows (the call for this run is sketched after the results). T only affects what the network gets to see; longer is not necessarily better, and the accuracy levels off.

Test Loss: 0.0205, Test Accuracy: 0.9064
Test completed in 2.04 seconds.
test_samples=10000
Test Loss: 0.0205, Test Accuracy: 0.9050
Test completed in 1.24 seconds.
test_samples=10000
Test Loss: 0.0207, Test Accuracy: 0.9055
Test completed in 1.23 seconds.
test_samples=10000
Test Loss: 0.0203, Test Accuracy: 0.9080
Test completed in 1.24 seconds.
test_samples=10000
Test Loss: 0.0205, Test Accuracy: 0.9074
Test completed in 1.35 seconds.
test_samples=10000
Test Loss: 0.0206, Test Accuracy: 0.9045
Test completed in 1.37 seconds.
test_samples=10000
Test Loss: 0.0207, Test Accuracy: 0.9058
Test completed in 1.40 seconds.
test_samples=10000
Test Loss: 0.0206, Test Accuracy: 0.9049
Test completed in 1.40 seconds.
test_samples=10000
Test Loss: 0.0205, Test Accuracy: 0.9063
Test completed in 1.47 seconds.
test_samples=10000
Test Loss: 0.0207, Test Accuracy: 0.9047
Test completed in 1.35 seconds.
test_samples=10000
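
For reference, the T = 5 run above is the same test script with a smaller simulation length; a sketch of the call, assuming the paths used earlier:

if __name__ == '__main__':
    model_path = 'logs\\T50_b64_adam_lr0.001\\checkpoint_max.pth'
    data_dir = 'data'
    # Same checkpoint, but the Poisson encoder is only sampled 5 times per image.
    test_model(model_path, data_dir, T=5)
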
Quantization

import torch

with open('model_params.txt', 'r') as file:
    lines = file.readlines()

with open('model_params_quantized.txt', 'w') as file:
    for line in lines:
        # strip the newline and split the line on commas
        values = line.strip().split(',')
        for val in values:
            float_val = float(val.strip())
            quantized_val = int(round(float_val * 10000))  # quantize to int32 (fixed-point, scale 10000)
            file.write(f"{quantized_val}\n")
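
The model_params.txt read above is the plain-text dump of the float parameters from the earlier post and is not reproduced here. A hypothetical dump script that would be consistent with the read-back code below might look like this (the comma-separated, one-tensor-per-line layout is an assumption):

import torch

checkpoint = torch.load('logs\\T50_b64_adam_lr0.001\\checkpoint_max.pth',
                        map_location='cpu')

# One tensor per line, values comma-separated, in state-dict order, so that the
# quantized values can later be written back in the same order.
with open('model_params.txt', 'w') as file:
    for name, param in checkpoint['net'].items():
        if isinstance(param, torch.Tensor):
            file.write(','.join(str(v) for v in param.flatten().tolist()) + '\n')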


After quantization, write the numbers back into the checkpoint:
import torch

# Load the original checkpoint_max.pth
model_path = 'logs\\T50_b64_adam_lr0.001\\checkpoint_max.pth'
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

# Read back the quantized values
with open('model_params_quantized.txt', 'r') as file:
    quantized_values = [int(line.strip()) for line in file.readlines()]

# Write the quantized values back into the model parameters
index = 0
for name, param in checkpoint['net'].items():
    if isinstance(param, torch.Tensor):
        numel = param.numel()
        quantized_param = torch.tensor(quantized_values[index:index+numel]).view(param.size())
        checkpoint['net'][name] = quantized_param
        index += numel

# Save the new checkpoint
torch.save(checkpoint, 'logs\\T50_b64_adam_lr0.001\\checkpoint_max_quantized.pth')


# Inspect the quantized state dict
model_state_dict = checkpoint['net']
for name, param in model_state_dict.items():
    print(f"{name}: {param}")
    print(f"{name}: {param.size()}")
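
As noted at the top of the post, PyTorch can only run inference in floating point, so to actually use these integer weights the scale of 10000 has to be divided back out (dequantization). A minimal sketch, assuming the same scale; the output file name checkpoint_max_dequantized.pth is my own:

import torch

model_path = 'logs\\T50_b64_adam_lr0.001\\checkpoint_max_quantized.pth'
checkpoint = torch.load(model_path, map_location='cpu')

scale = 10000.0
for name, param in checkpoint['net'].items():
    if isinstance(param, torch.Tensor):
        # Cast the integers back to float and undo the fixed-point scale.
        checkpoint['net'][name] = param.float() / scale

torch.save(checkpoint, 'logs\\T50_b64_adam_lr0.001\\checkpoint_max_dequantized.pth')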

Accuracy after int32 quantization drops:
Test Loss: 0.1182, Test Accuracy: 0.6758
Test completed in 2.10 seconds.
test_samples=10000
Test Loss: 0.1181, Test Accuracy: 0.6765
Test completed in 1.23 seconds.
test_samples=10000
Test Loss: 0.1181, Test Accuracy: 0.6789
Test completed in 1.25 seconds.
test_samples=10000
Test Loss: 0.1180, Test Accuracy: 0.6785
Test completed in 1.30 seconds.
test_samples=10000
Test Loss: 0.1181, Test Accuracy: 0.6755
Test completed in 1.35 seconds.
test_samples=10000
Test completed in 1.39 seconds.
test_samples=10000
Test Loss: 0.1183, Test Accuracy: 0.6800
Test completed in 1.35 seconds.
test_samples=10000
Test Loss: 0.1185, Test Accuracy: 0.6750
Test completed in 1.38 seconds.
test_samples=10000