Pytorch:关于epoch、batch_size和batch_idx(iteration )的一些理解(深度学习)

本文介绍了深度学习中重要的概念epoch、batch_size和iteration,通过PyTorch中的CIFAR10数据集举例,解释了它们的关系和作用。batch_size表示每次训练的数据量,batch_idx(或iteration)代表训练的步数,epoch则是一个完整的数据集遍历。理解这些参数对于优化模型训练至关重要。

前言

在新手搭建神经网络时,常弄不清epoch、batch_size、iteration和batch_idx(iteration )的区别。
这里以torchvision自带的CIFAR10数据集来举例,通过代码操作来直观地对这几个概念进行理解。
声明,这里batch_idx==iteration

数据准备

首先加载数据集:

import torch
import torch.nn as nn
import torchvision

train_dataset = torchvision.datasets.CIFAR10(root="data/",train=True,download=False)
test_dataset = torchvision.datasets.CIFAR10(root
common: run_label: "run_1" accum_freq: 1 accum_after_epoch: -1 log_freq: 200 auto_resume: true mixed_precision: true dataset: root_train: "/media/Datasets/VOCdevkit" root_val: "/media/Datasets/VOCdevkit" name: "pascal" category: "segmentation" train_batch_size0: 12 val_batch_size0: 12 eval_batch_size0: 1 workers: 12 persistent_workers: false pin_memory: false pascal: use_coco_data: true coco_root_dir: "/media/Datasets/coco_preprocess" image_augmentation: random_resize: enable: true min_size: 256 max_size: 1024 random_crop: enable: true mask_fill: 255 resize_if_needed: true random_horizontal_flip: enable: true sampler: name: "batch_sampler" bs: crop_size_width: 512 crop_size_height: 512 loss: category: "segmentation" ignore_idx: 255 segmentation: name: "cross_entropy" optim: name: "adamw" weight_decay: 0.01 no_decay_bn_filter_bias: false adamw: beta1: 0.9 beta2: 0.999 scheduler: name: "cosine" is_iteration_based: false max_epochs: 50 warmup_iterations: 500 warmup_init_lr: 0.00009 cosine: max_lr: 0.0009 # [2.7e-3 * N_GPUS^2 x (BATCH_SIZE_GPU0/ 32) * 0.02 ] # 0.02 comes from this fact 0.1 (ResNet SGD LR)/0.002 (MIT ADAMW LR) min_lr: 1.e-6 model: segmentation: name: "encoder_decoder" lr_multiplier: 10 seg_head: "deeplabv3" output_stride: 16 classifier_dropout: 0.1 activation: name: "relu" deeplabv3: aspp_dropout: 0.1 aspp_sep_conv: false aspp_out_channels: 256 aspp_rates: [6, 12, 18] classification: name: "mobilevit_v3" classifier_dropout: 0.1 mit: mode: "small_v3" ffn_dropout: 0.0 attn_dropout: 0.0 dropout: 0.1 number_heads: 4 no_fuse_local_global_features: false conv_kernel_size: 3 activation: name: "swish" pretrained: "results/mobilevitv3_small_e300_7930/run_1/checkpoint_ema_best.pt" normalization: name: "sync_batch_norm" momentum: 0.1 activation: name: "relu" inplace: false layer: global_pool: "mean" conv_init: "kaiming_normal" linear_init: "normal" conv_weight_std: false ema: enable: true momentum: 0.0005 ddp: enable: true rank: 0 world_size: -1 dist_port: 30786 stats: name: [ "loss", "iou"] checkpoint_metric: "iou" checkpoint_metric_max: true 帮我逐行详细解释这段代码
03-08
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值