Masked AutoEncoders (MAE)代码学习

目录

 

本文说明

1.models_mae(MAE模型)

2.main_pretrain(预训练)


本文说明

本人学习所用,主要是对MAE模型官方(facebookresearch)实现代码进行一些注释

(MAE)论文地址

官方代码地址

注释参考chatGPT

1.models_mae(MAE模型)


from functools import partial

import torch
import torch.nn as nn

from timm.models.vision_transformer import PatchEmbed, Block

from util.pos_embed import get_2d_sincos_pos_embed


class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3,   #定义了一个类的初始化方法(__init__)
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)  #一个模块,将输入图像分割成小块并将它们嵌入到高维空间
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))   #一个可学习的参数,代表类别标记。它被添加到嵌入的patch列表的前面
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding   加到patch嵌入上的位置嵌入,以保留位置信息。这些是固定的正弦嵌入,不是学习得到的。

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
            for i in range(depth)])    #构成编码器的Block模块的序列。每个块通常包含一个多头自注意力机制和一个MLP
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)  #一个线性层,将编码器嵌入投影到解码器的嵌入维度

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))  #一个可学习的参数,代表训练中用于遮蔽patch的标记

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch   一个线性层,将解码器的输出投影回原始的patch大小,实质上重构图像。
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()  #一个可能在类的其他地方定义的方法,用于初始化模型的权重,未在提供的代码片段中显示

    def initialize_weights(self):  #负责初始化模型中各个参数的权重。这个方法通过特定的初始化策略,确保模型在开始训练之前处于一个合适的状态
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) #根据位置嵌入的维度和patch数量的平方根计算正弦余弦嵌入,如果有类别标记(cls_token),则在嵌入中考虑它
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) #使用copy_方法将这些初始化的嵌入值复制到模型的位置嵌入参数中

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data #patch_embed.proj.weight是将图像分块后,每个块嵌入到高维空间的线性投影层(通常实现为一个卷积层)的权重
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) #使用xavier_uniform_方法对这个投影层的权重进行初始化,这是一种常用的权重初始化方法,旨在保持输入和输出的方差一致,有助于模型的稳定训练

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02) #类别标记 #使用具有标准差为0.02的正态分布对这些标记进行初始化。这样的初始化有助于在模型训练初期,让这些标记与其他嵌入有一定的区分度
        torch.nn.init.normal_(self.mask_token, std=.02) #遮蔽标记

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)  #通过应用_init_weights函数递归地对模型中所有的线性层(nn.Linear)和层规范化(nn.LayerNorm)进行初始化

    def _init_weights(self, m):  #用于初始化线性层(nn.Linear)和层规范化(nn.LayerNorm)的权重和偏置
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs): #将输入的图像分割成小块(patches)
        """
        imgs: (N, 3, H, W) #输入是一个四维张量imgs,代表一批图像,其维度为(N, 3, H, W),其中N是批大小,3代表图像的颜色通道(RGB),H和W分别是图像的高度和宽度。
        x: (N, L, patch_size**2 *3) #输出是一个三维张量x,其维度为(N, L, patch_size**2 * 3),其中L是经过切分后得到的每张图像的patch总数,patch_size**2 * 3是每个patch的维度,乘以3是因为每个颜色通道都被考虑在内。
        """
        p = self.patch_embed.patch_size[0] #patch的大小
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 #确保图像能够被平均分割成大小为p x p的patches

        h = w = imgs.shape[2] // p  #计算分割后的维度
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) #(N, C, h, p, w, p)
        x = torch.einsum('nchpwq->nhwpqc', x) #(N, h, w, p, p, C) 'nchpwq->nhwpqc'表示将通道维度从第二个位置移动到最后,同时保持patches在一起。这一步骤实质上是在对每个patch进行重新组织,使其排列成一个线性序列。
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) #(N, L, patch_size**2 *3)
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]
        
        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio)) #确定保留的元素数量  列长度L乘以(1减去掩码比例mask_ratio)
        
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1] #[N, L] 生成随机噪声,这些噪声值将用于随机选择哪些元素被掩码
        
        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove 通过对每个样本的噪声向量进行排序,得到一个索引序列ids_shuffle。这些索引决定了元素将如何被重新排列,以便随机选取一部分元素进行保留。
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep] #根据ids_keep(即ids_shuffle的前len_keep个元素),使用torch.gather从每个样本中选取要保留的元素
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) #形成掩码后的序列x_masked

        # generate the binary mask: 0 is keep, 1 is remove 生成二进制掩码:创建一个全为1的掩码(表示移除),然后将前len_keep个位置设为0(表示保留)。通过将这个掩码按ids_restore索引进行重新排序,我们得到了与原始序列对应的掩码,其中0表示元素被保留,1表示元素被掩码。
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        # x_masked: 经过随机掩码处理后的序列,只包含被选中保留的元素。
        # mask: 一个二进制掩码,指示了哪些元素被保留(0表示保留),哪些被掩码(1表示掩码)。
        # ids_restore: 用于将x_masked和mask恢复到原始排序的索引。
        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        # embed patches 嵌入补丁 将输入图像x转换为一系列补丁(patches)的嵌入表示
        x = self.patch_embed(x)

        # add pos embed w/o cls token  添加位置嵌入(Add Position Embeddings without CLS Token)
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio   对补丁进行掩码操作。将补丁的长度从 length 转换为 length * mask_ratio,
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token  创建一个特殊的CLS标记,并将其添加到补丁序列的开头。这通常是为了处理分类任务,在这些任务中,CLS标记的最终隐藏状态用作整个序列的表示。CLS标记通常用于聚合整个序列的信息,其在序列处理任务中的表现被用作整个输入的表示
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks  通过循环遍历 self.blocks 中的每个Transformer块,对序列进行变换。这些块通常包括多头自注意力和前馈神经网络层,用于捕捉序列中的信息
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        # 通过self.norm对序列进行归一化处理。这通常是为了稳定训练过程,确保各层的激活在合适的范围内。
        return x, mask, ids_restore
        # x: 经过Transformer编码器处理后的序列,包含了添加了CLS标记的补丁的嵌入表示。
        # mask: 表示哪些补丁被掩码(隐藏)的二进制掩码,用于后续的自监督学习任务中指导模型重建被掩码的部分。
        # ids_restore: 掩码操作中用于恢复序列原始顺序的索引,允许模型在处理完掩码数据后将输出恢复到与原始输入相匹配的顺序。

    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep保留, 1 is remove掩码,
        """
        target = self.patchify(imgs) #将原始图像imgs转换成补丁表示,作为损失计算的目标。
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        #计算预测补丁与目标补丁之间的均方误差(MSE)损失
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches 通过将平均损失与掩码相乘,只保留被掩码补丁的损失,将这些损失求和并除以掩码中值为1的元素总数,得到所有被掩码补丁的平均损失
        return loss #值反映了模型对被掩码补丁预测的准确性

    def forward(self, imgs, mask_ratio=0.75):
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) #将输入图像imgs通过编码器,执行随机掩码操作,并返回编码后的潜在表示latent、掩码mask和用于恢复原始顺序的索引ids_restore
        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3] 解码器的任务是根据潜在表示重建被掩码的补丁
        loss = self.forward_loss(imgs, pred, mask)  #这一步骤评估了模型在重建被掩码部分方面的性能
        return loss, pred, mask


def mae_vit_base_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_large_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_huge_patch14_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


# set recommended archs
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks




2.main_pretrain(预训练)

import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import timm

# assert timm.__version__ == "0.3.2"  # version check
import timm.optim.optim_factory as optim_factory

import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler

import models_mae

from engine_pretrain import train_one_epoch


def get_args_parser():
    parser = argparse.ArgumentParser('MAE pre-training', add_help=False)
    parser.add_argument('--batch_size', default=64, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
    parser.add_argument('--epochs', default=400, type=int)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # Model parameters
    parser.add_argument('--model', default='mae_vit_base_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')

    parser.add_argument('--input_size', default=224, type=int,
                        help='images input size')

    parser.add_argument('--mask_ratio', default=0.75, type=float,
                        help='Masking ratio (percentage of removed patches).')

    parser.add_argument('--norm_pix_loss', action='store_true',
                        help='Use (per-patch) normalized pixels as targets for computing loss')
    parser.set_defaults(norm_pix_loss=False)

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')

    parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
                        help='epochs to warmup LR')

    # Dataset parameters
    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')

    parser.add_argument('--output_dir', default='./output_dir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='./output_dir',
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    return parser


def main(args):
    misc.init_distributed_mode(args)  #分布式训练初始化

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) #打印作业目录
    print("{}".format(args).replace(', ', ',\n')) #打印参数

    device = torch.device(args.device)

    # fix the seed for reproducibility 随机数生成器设置种子,以确保结果的可重复性
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True #启用了cuDNN的基准模式

    # simple augmentation 设置图像数据的简单增强处理,并加载训练数据集
    transform_train = transforms.Compose([
            transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3),  # 3 is bicubic  随机大小裁剪图像
            transforms.RandomHorizontalFlip(),  #以一定的概率(默认为0.5)进行随机水平翻转
            transforms.ToTensor(),  #将PIL图像或NumPy ndarray转换为torch.Tensor,并且把像素值范围从[0, 255]归一化到[0.0, 1.0]
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])  #标准化处理,使用指定的均值和标准差对每个通道的数据进行标准化
    dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)  #加载数据集
    print(dataset_train)

    if True:  # args.distributed:
        num_tasks = misc.get_world_size() #获取全局的任务数,即分布式训练中的总进程数
        global_rank = misc.get_rank()  #获取当前进程的全局排名
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )  #创建一个分布式采样器实例,为每个进程分配数据子集,并确保每次迭代时数据都会被打乱
        print("Sampler_train = %s" % str(sampler_train))
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    if global_rank == 0 and args.log_dir is not None: #日志记录的设置
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train, #基于是否进行分布式训练条件下创建的采样器。在分布式训练中,它确保每个进程获得的数据子集是唯一的,而在非分布式训练中,可能使用的是随机采样器。
        batch_size=args.batch_size,
        num_workers=args.num_workers,  #指定了加载数据时使用的子进程数量。增加工作进程可以加速数据加载过程,但也会增加内存消耗。
        pin_memory=args.pin_mem,  #当使用GPU训练时,设置pin_memory=True可以让数据加载更高效
        drop_last=True, #这个参数决定是否在数据集大小不能被批量大小整除时丢弃最后一个不完整的批次
    )
    
    # define the model
    model = models_mae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss) #从models_mae模块中动态地选择并创建一个模型实例  args.model是一个字符串,指定了要使用的模型名称
    # norm_pix_loss = args.norm_pix_loss是传递给模型构造函数的参数,它控制是否对像素损失进行归一化
    # 将模型移动到之前定义的计算设备上(device)
    model.to(device)

    model_without_ddp = model  #处理分布式数据并行(DDP)
    print("Model = %s" % str(model_without_ddp))

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()  #计算的是有效批量大小(effective batch size),考虑了每个进程的批量大小(args.batch_size)、梯度累积迭代次数(args.accum_iter)以及分布式训练中的总进程数(misc.get_world_size())
    
    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256  #做法常见于需要根据批量大小自动缩放学习率的场景,以保持不同配置下训练动态的一致性。

    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)

    print("accumulate grad iterations: %d" % args.accum_iter)  #梯度累积迭代次数
    print("effective batch size: %d" % eff_batch_size)  #计算出的有效批量大小

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module
    
    # following timm: set wd as 0 for bias and norm layers
    param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay) #设置权重衰减
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        train_stats = train_one_epoch(      #函数执行模型在一个epoch上的训练过程,返回训练统计信息。
            model, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            log_writer=log_writer,
            args=args
        )
        if args.output_dir and (epoch % 20 == 0 or epoch + 1 == args.epochs):  #指定的输出目录存在(args.output_dir)并且当前epoch满足20的倍数或者最后一个epoch,保存模型
            misc.save_model(
                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},  #创建了一个字典log_stats,它用于记录训练过程中的统计信息,
                        'epoch': epoch,}

        #将训练过程中的统计信息记录到日志文件中
        if args.output_dir and misc.is_main_process():  #检查是否指定了输出目录(args.output_dir)且当前进程是主进程
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    args = get_args_parser()
    args = args.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)

  • 7
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值