pytorch的框架使用记录
文章平均质量分 82
pytorch的框架使用记录
magic_ll
有空就记记,没空就休息
展开
-
【pytorch记录】SummaryWriter保存日志
在pytorch框架中,关于日志的保存,其中一种方式就是借鉴使用了tensorboard的库。所以我们需要在环境中安装tensorboard库,然后再在工程中进行该库的调用1 安装与导入或者导入2 添加需要保存标量数据从源码中我们能看到核心的三个参数为前三个。通俗的讲分别代表tag:图的标签名,唯一标识scalar_value:y轴数据,标量数据的具体数值global_step:x轴数据,要记录的全局步长值多项标题记录方法,其中:main_tag —— 该图的标签。原创 2023-03-20 14:35:26 · 1186 阅读 · 1 评论 -
【pytorch记录】模型的分布式训练DataParallel、DistributedDataParallel
使用多GPU对神经网络进行训练时,pytorch有相应的api将模型放到多GPU上运行。torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])两者的区别:nn.DataParallel使用单进程控制,将模型和数据加载到多个GPU中gpus=[0,1]torch.nn.DataParallel(model.cuda(), decice_ids=gpus, output_device=gpu原创 2022-06-27 20:01:36 · 1973 阅读 · 4 评论 -
【pytorch记录】torch.utils.data.Dataset、DataLoader、分布式读取并数据
pytorch提供了一个数据读取的方法,使用了 torch.utils.data.Dataset 和 torch.utils.data.DataLoader。要自定义自己数据的方法,就要继承 torch.utils.data.Dataset,实现了数据读取以及数据处理方式,并得到相应的数据处理结果。然后将 Dataset封装到 DataLoader中,可以实现了单/多进程迭代输出数据。1 torch.utils.data.Dataset要自定义自己的 Dataset 类,需要重载两个方式,【_.原创 2022-03-07 11:15:28 · 8003 阅读 · 0 评论 -
【pytorch 记录】pytorch的变量parameter_buffer、self.register_buffer()、self.register_parameter()
在pytorch中模型需要保存下来的参数包括:parameter:反向传播需要被 optimizer 更新的,可以被训练。buffer:反向传播不需要被 optimizer 更新,不可被训练。 这两种参数都会分别保存到 一个OrderDict 的变量中,最终由 model.state_module() 返回进行保存。1 nn.Module的介绍需要先说明下:直接torch.randn(1, 2) 这种定义的变量,没有绑定在pytorch的网络中,训练结束后也就没有在保存在模型中。当我们.原创 2022-05-25 17:27:41 · 2255 阅读 · 2 评论 -
【pytorch记录】自动混合精度训练 torch.cuda.amp
Nvidia 在Volta 架构中引入 Tensor Core 单元,来支持 FP32 和 FP16 混合精度计算。同年提出了一个pytorch 扩展apex,来支持模型参数自动混合精度训练自动混合精度(Automatic Mixed Precision, AMP)训练,是在训练一个数值精度为32的模型时,一部分算子的操作 数值精度为FP16,其余算子的操作精度为FP32。具体的哪些算子使用的精度,amp自动设置好了,不需要用户额外设置。..............................原创 2022-06-28 09:24:20 · 5838 阅读 · 1 评论 -
【pytorch记录】pytorch的分布式 torch.distributed.launch 命令在做什么呢
在查阅pytorch分布式训练,在不了解相关内容时,直接查看api的使用,更云里雾里。反而对着一个完整的例子,会更上手一些。所以在这里整理下,能够快速读懂分布式训练具体实现的流程。1 背景知识神经网络的训练,从硬件的使用,可分为3种情况:多机多卡、单机多卡、单机单卡。(机–主机、卡–显卡)在pytorch进行分布式训练中,会有相应的变量名表示上面的情况:测试脚本:train.pyimport torchimport torch.distributed as distimport osim原创 2022-03-03 20:13:29 · 46548 阅读 · 15 评论 -
【pytorch】使用stat、profile打印网络的参数量、Flops、MAdd、内存使用的情况
pytorch获取网络的参数量、MAdd、Flops安装工具:pip install torchstat使用例子:import sysimport torchimport torch.nn as nnimport numpy as npfrom thop import profilefrom torchstat import statclass Net(nn.Module): def __init__(self): super(Net, self).__ini原创 2021-12-27 10:47:27 · 10183 阅读 · 20 评论