ViT Source Code Walkthrough

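This post is a record of my own learning: a detailed, debugger-driven walkthrough of the code flow. Take a look if you're interested.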

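The Transformer is a new kind of feature-extraction module. In vision, attention means focusing on the important parts of the image (in ViT, the important patches). The ViT architecture is shown in the figure below.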

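In NLP, text is turned into a sequence and then processed with self-attention and multi-head attention. Vision tasks can likewise turn an image into a sequence. In the figure, the image is first split evenly into 9 patches, each treated as a separate unit, and every patch computes its relationship with every other patch. Compared with NLP there is one extra step: unfolding the image into a sequence. The first step for text is word embedding; the first step for images is extracting a feature vector for each patch.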
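The "unfold the image into a sequence" step can be sketched with plain tensor reshapes. This is a minimal illustration assuming PyTorch. `image_to_patch_sequence` is a hypothetical helper, not part of the ViT code below, which does the same job with a convolution. It cuts an image into non-overlapping tiles and flattens each tile into one sequence element. The toy call reproduces the 3x3 split from the figure.

```python
import torch

def image_to_patch_sequence(img: torch.Tensor, patch: int) -> torch.Tensor:
    """Split (B, C, H, W) images into non-overlapping patch x patch tiles and
    flatten each tile, returning a (B, num_patches, C * patch * patch) sequence."""
    B, C, H, W = img.shape
    assert H % patch == 0 and W % patch == 0, "image size must be divisible by patch"
    # (B, C, H/p, p, W/p, p): split each spatial axis into (grid index, offset inside patch)
    x = img.reshape(B, C, H // patch, patch, W // patch, patch)
    # (B, H/p, W/p, C, p, p): bring the grid indices to the front
    x = x.permute(0, 2, 4, 1, 3, 5)
    # (B, num_patches, C*p*p): one flattened vector per patch, in row-major grid order
    return x.reshape(B, (H // patch) * (W // patch), C * patch * patch)

img = torch.randn(2, 3, 224, 224)
seq = image_to_patch_sequence(img, 16)
print(seq.shape)  # torch.Size([2, 196, 768])

# Toy version of the 3x3 split in the figure: a 6x6 image cut into 2x2 tiles gives 9 patches
toy = torch.arange(36.0).reshape(1, 1, 6, 6)
print(image_to_patch_sequence(toy, 2).shape)  # torch.Size([1, 9, 4])
```

With 16x16 patches on a 224x224 RGB image you get 14 x 14 = 196 patches of 3 x 16 x 16 = 768 raw values each.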

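First the image is split into regions, then a convolution extracts a feature vector from each region. Each convolution kernel produces one feature, so with n kernels every position gets an n-dimensional vector. Once these vectors are available, attention layers are stacked on top, the features are aggregated, and the result is finally used for classification or other downstream tasks.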
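You can check the "n kernels give an n-dimensional vector per position" idea directly. A Conv2d whose kernel_size and stride both equal the patch size does the same thing as flattening each patch and applying one shared Linear layer, where each of the n kernels becomes one output dimension. The sketch below is a minimal check, assuming PyTorch and the ViT-B/16 numbers (patch 16, 768 output channels). The unfold-based patch cutting is my own illustration, not the repository's code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
patch, hidden = 16, 768
conv = nn.Conv2d(3, hidden, kernel_size=patch, stride=patch)   # same shape as patch_embeddings
linear = nn.Linear(3 * patch * patch, hidden)

# Reuse the conv weights: each of the 768 kernels becomes one row of the linear weight
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(hidden, -1))
    linear.bias.copy_(conv.bias)

img = torch.randn(2, 3, 224, 224)

# Path 1: convolution, then flatten the 14x14 grid into a sequence of 196 tokens
out_conv = conv(img).flatten(2).transpose(1, 2)                      # (2, 196, 768)

# Path 2: cut into patches, flatten each patch in (C, kh, kw) order, apply the Linear layer
patches = img.unfold(2, patch, patch).unfold(3, patch, patch)        # (2, 3, 14, 14, 16, 16)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(2, 196, -1)      # (2, 196, 768)
out_linear = linear(patches)

print(out_conv.shape, torch.allclose(out_conv, out_linear, atol=1e-4))
```

The two paths should agree up to floating-point rounding. That is why a strided convolution is a convenient way to implement the per-patch projection.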

Next, let's step through the code in the debugger.

The main entry function

def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--name", required=True,
                        help="Name of this run. Used for monitoring.")
    parser.add_argument("--dataset", choices=["cifar10", "cifar100"], default="cifar10",
                        help="Which downstream task.")
    parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16",
                                                 "ViT-L_32", "ViT-H_14", "R50-ViT-B_16"],
                        default="ViT-B_16",
                        help="Which variant to use.")
    parser.add_argument("--pretrained_dir", type=str, default="checkpoint/ViT-B_16.npz",
                        help="Where to search for pretrained ViT models.")
    parser.add_argument("--output_dir", default="output", type=str,
                        help="The output directory where checkpoints will be written.")

    parser.add_argument("--img_size", default=224, type=int,
                        help="Resolution size")
    parser.add_argument("--train_batch_size", default=16, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=64, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--eval_every", default=100, type=int,
                        help="Run prediction on validation set every so many steps. "
                             "Will always run one evaluation at the end of training.")

    parser.add_argument("--learning_rate", default=3e-2, type=float,
                        help="The initial learning rate for SGD.")
    parser.add_argument("--weight_decay", default=0, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--num_steps", default=10000, type=int,
                        help="Total number of training steps to perform.")
    parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
                        help="How to decay the learning rate.")
    parser.add_argument("--warmup_steps", default=500, type=int,
                        help="Step of training to perform learning rate warmup for.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")

    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O2',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    args = parser.parse_args()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1:  # non-distributed: a single process (n_gpu counts every visible GPU)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             timeout=timedelta(minutes=60))
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s" %
                   (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1), args.fp16))

    # Set seed
    set_seed(args)  # fix the random seeds so that initialization is reproducible

    # Model & Tokenizer Setup
    args, model = setup(args)

    # Training
    train(args, model)

The main function first parses the arguments, then calls set_seed(args) to fix the random seeds. The seeding function is:

def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

args, model = setup(args) passes the arguments to the model constructor and initializes the model:

def setup(args):
    # Prepare model
    config = CONFIGS[args.model_type]  # look up this variant's hyperparameters in the CONFIGS dict

    num_classes = 10 if args.dataset == "cifar10" else 100  # number of classes: 10 for CIFAR-10, 100 for CIFAR-100

    model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes)
    model.load_from(np.load(args.pretrained_dir))
    model.to(args.device)
    num_params = count_parameters(model)

    logger.info("{}".format(config))
    logger.info("Training parameters %s", args)
    logger.info("Total Parameter: \t%2.1fM" % num_params)
    print(num_params)
    return args, model
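setup logs the parameter count as "%2.1fM", but count_parameters itself is not shown in this post. Below is a minimal sketch. It assumes the helper counts trainable parameters and reports them in millions, which fits that log format. Treat the function body as hypothetical rather than the repository's exact code.

```python
import torch.nn as nn

def count_parameters(model: nn.Module) -> float:
    """Hypothetical helper: number of trainable parameters, in millions."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000

layer = nn.Linear(768, 10)  # like a CIFAR-10 head: 768*10 weights + 10 biases = 7690 parameters
print(count_parameters(layer))
```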

VisionTransformer (the Vision Transformer): set a breakpoint and step into its constructor __init__.

class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier

        self.transformer = Transformer(config, img_size, vis)  # pass in the config and the image size
        self.head = Linear(config.hidden_size, num_classes)

    def forward(self, x, labels=None):
        x, attn_weights = self.transformer(x)
        print(x.shape)
        logits = self.head(x[:, 0])
        print(logits.shape)

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
            return loss
        else:
            return logits, attn_weights
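VisionTransformer.forward sends only x[:, 0] to the linear head. That is the final vector of the cls token. Below is a small sketch with random tensors standing in for the encoder output (B = 16, 768 dims, 10 CIFAR-10 classes). It also shows what a zero-initialized head does. The weight loading in load_from is not shown here, so the sketch assumes that zero_head=True zeroes the head. In that case training starts from uniform predictions with loss log(10).

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
B, hidden, num_classes = 16, 768, 10
encoded = torch.randn(B, 197, hidden)      # stand-in for the Transformer output
labels = torch.randint(0, num_classes, (B,))

head = nn.Linear(hidden, num_classes)
logits = head(encoded[:, 0])               # classify only token 0, the cls token -> (16, 10)
loss = nn.CrossEntropyLoss()(logits.view(-1, num_classes), labels.view(-1))

# A zeroed head outputs all-zero logits (a uniform prediction),
# so the starting loss is log(num_classes) regardless of the encoder output
nn.init.zeros_(head.weight)
nn.init.zeros_(head.bias)
zero_loss = nn.CrossEntropyLoss()(head(encoded[:, 0]), labels)
print(logits.shape, zero_loss.item(), math.log(num_classes))
```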

The Transformer class

class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        embedding_output = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(embedding_output)
        return encoded, attn_weights

The Embeddings class

class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings.
    """
    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            n_patches = (img_size[0] // 16) * (img_size[1] // 16)
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])  # patch_size: size of each region (patch) that gets its own feature vector
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
                                         width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        print(x.shape)
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)
        print(cls_tokens.shape)
        if self.hybrid:
            x = self.hybrid_model(x)
        x = self.patch_embeddings(x)#Conv2d: Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        print(x.shape)
        x = x.flatten(2)
        print(x.shape)
        x = x.transpose(-1, -2)
        print(x.shape)
        x = torch.cat((cls_tokens, x), dim=1)
        print(x.shape)

        embeddings = x + self.position_embeddings
        print(embeddings.shape)
        embeddings = self.dropout(embeddings)
        print(embeddings.shape)
        return embeddings

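patch_size is the size of each region that gets its own feature vector. n_patches is the total number of patches: (image width / patch width) x (image height / patch height). patch_embeddings is a convolution whose stride equals patch_size, so the patches it extracts features from do not overlap. position_embeddings is the positional encoding. It has n_patches + 1 entries because of the extra global cls token. The cls_token vector has hidden_size dimensions, the same as the patch feature vectors, so the two can be concatenated into one sequence.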
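The sketch below makes the shape bookkeeping concrete by mirroring the non-hybrid branch of Embeddings.__init__ for a few settings. embedding_shapes is a hypothetical helper. The patch sizes 16/32 come from the variant names (ViT-B_16, ViT-B_32), and 384 is just an example of a larger input resolution.

```python
import torch
import torch.nn as nn

def embedding_shapes(img_size: int, patch_size: int, hidden_size: int):
    """Mirror the non-hybrid branch of Embeddings.__init__ and report the shapes."""
    n_patches = (img_size // patch_size) * (img_size // patch_size)
    position_embeddings = nn.Parameter(torch.zeros(1, n_patches + 1, hidden_size))  # +1 for cls
    cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
    return n_patches, tuple(position_embeddings.shape), tuple(cls_token.shape)

print(embedding_shapes(224, 16, 768))  # (196, (1, 197, 768), (1, 1, 768))
print(embedding_shapes(224, 32, 768))  # (49, (1, 50, 768), (1, 1, 768))
print(embedding_shapes(384, 16, 768))  # (576, (1, 577, 768), (1, 1, 768))
```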

The Encoder class

class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)  # final normalization after the last block
        for _ in range(config.transformer["num_layers"]):
            layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states):
        print(hidden_states.shape)
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights

Batch Norm vs. Layer Norm

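1. Batch Norm is used mostly in CNNs. Each feature is normalized with statistics computed across the samples in a batch, so every sample is pulled toward the group. This assumes that the same position means the same thing in every sample. In NLP that does not hold: the token at a given position in one sentence has no particular relation to the token at the same position in another.

2. Layer Norm instead normalizes each sample using its own features. In a Transformer, that means each token's hidden vector is normalized independently.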
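The difference comes down to which axes the statistics are taken over. The sketch below assumes PyTorch and a (batch, tokens, hidden) = (4, 197, 768) activation. LayerNorm (with the affine scale/shift disabled to isolate the normalization) is checked against a hand-written per-token version. BatchNorm1d in training mode instead computes one mean and variance per feature, pooled over the batch and token axes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 197, 768) * 3 + 1   # (batch, tokens, hidden), deliberately not standardized

# LayerNorm: statistics over the last dim, separately for every token of every sample
ln = nn.LayerNorm(768, eps=1e-6, elementwise_affine=False)
y_ln = ln(x)

# The same normalization written out by hand
mu = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
manual = (x - mu) / torch.sqrt(var + 1e-6)
print(torch.allclose(y_ln, manual, atol=1e-5))

# BatchNorm1d expects (batch, channels, length): one mean/var per feature across batch AND tokens
bn = nn.BatchNorm1d(768, affine=False)
y_bn = bn(x.transpose(1, 2)).transpose(1, 2)   # training mode -> uses batch statistics
print(y_ln.mean(dim=-1).abs().max())           # ~0 for every individual token
print(y_bn.mean(dim=(0, 1)).abs().max())       # ~0 for every feature over the whole batch
```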

The Block class

class Block(nn.Module):
    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)  # feed-forward network (fully connected layers)
        self.attn = Attention(config, vis)

    def forward(self, x):
        print(x.shape)
        h = x
        x = self.attention_norm(x)
        print(x.shape)
        x, weights = self.attn(x)
        x = x + h
        print(x.shape)

        h = x
        x = self.ffn_norm(x)
        print(x.shape)
        x = self.ffn(x)
        print(x.shape)
        x = x + h
        print(x.shape)
        return x, weights

The Attention class

class Attention(nn.Module):
    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        print(new_x_shape)
        x = x.view(*new_x_shape)
        print(x.shape)
        print(x.permute(0, 2, 1, 3).shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        print(hidden_states.shape)
        mixed_query_layer = self.query(hidden_states)#Linear(in_features=768, out_features=768, bias=True)
        print(mixed_query_layer.shape)
        mixed_key_layer = self.key(hidden_states)
        print(mixed_key_layer.shape)
        mixed_value_layer = self.value(hidden_states)
        print(mixed_value_layer.shape)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        print(query_layer.shape)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        print(key_layer.shape)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        print(value_layer.shape)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        print(attention_scores.shape)
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        print(attention_scores.shape)
        attention_probs = self.softmax(attention_scores)
        print(attention_probs.shape)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)
        print(attention_probs.shape)

        context_layer = torch.matmul(attention_probs, value_layer)
        print(context_layer.shape)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        print(context_layer.shape)
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        print(context_layer.shape)
        attention_output = self.out(context_layer)
        print(attention_output.shape)
        attention_output = self.proj_dropout(attention_output)
        print(attention_output.shape)
        return attention_output, weights

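The forward methods are what actually run once train(args, model) starts training. Set a breakpoint in each module's forward and step through with the debugger. Execution first enters VisionTransformer.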

The input has shape (16, 3, 224, 224): a training batch of 16, 3 RGB channels, and 224x224 images.

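The input image is embedded first. The convolution takes 3 input channels and outputs 768-dimensional vectors. Its 16x16 kernel means the image is grouped into 16x16 patches. The stride is also 16, so the patches do not overlap during feature extraction. The convolution turns (16, 3, 224, 224) into (16, 768, 14, 14). flatten(2) and transpose(-1, -2) then rearrange this into a sequence of 196 tokens, (16, 196, 768). The figure below shows stepping into Embeddings.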

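cls_token has shape (1, 1, 768). expand broadcasts it along the batch dimension to B, which is 16 here, so every sample gets its own cls_token: (16, 1, 768). After concatenation with the patch tokens, the sequence is (16, 197, 768).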
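The steps of Embeddings.forward can be replayed outside the model to watch the shapes change. This is a sketch assuming the non-hybrid ViT-B/16 setup with the same batch of 16. The modules are freshly constructed stand-ins with the same shapes as in Embeddings, and dropout is omitted. Note that expand returns a view sharing the original cls_token storage. The actual copy only happens in torch.cat.

```python
import torch
import torch.nn as nn

B, hidden, patch = 16, 768, 16
x = torch.randn(B, 3, 224, 224)                          # same shape as the debug input
patch_embeddings = nn.Conv2d(3, hidden, kernel_size=patch, stride=patch)
cls_token = nn.Parameter(torch.zeros(1, 1, hidden))
position_embeddings = nn.Parameter(torch.zeros(1, 197, hidden))

cls_tokens = cls_token.expand(B, -1, -1)   # (16, 1, 768): a broadcast view, no copy yet
h = patch_embeddings(x)                     # (16, 768, 14, 14): one 768-d vector per patch
h = h.flatten(2)                            # (16, 768, 196): the 14x14 grid becomes a sequence
h = h.transpose(-1, -2)                     # (16, 196, 768): tokens first, features last
h = torch.cat((cls_tokens, h), dim=1)       # (16, 197, 768): cls token prepended at index 0
embeddings = h + position_embeddings        # (16, 197, 768): position table broadcasts over B
print(cls_tokens.shape, embeddings.shape)
```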

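Next we step into Encoder. The embeddings computed above, of shape (16, 197, 768), are the hidden states, and the attention layers start here.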

Feature extraction begins. Each Block first applies LayerNorm and then attention, stepping into self.attn, i.e. the Attention class.

The forward method of Attention

First, q, k, and v are computed from the hidden states by the query, key, and value linear layers.

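transpose_for_scores does the multi-head split. Afterwards q, k, and v each have shape (16, 12, 197, 64): 12 heads of 64 dimensions. Next comes the dot product of q with the transpose of k. attention_scores has shape (16, 12, 197, 197), because every token is compared with every other token. The scores are then divided by sqrt(64) = 8. Dot products grow in magnitude as the dimension grows, and this scaling removes the effect of the feature length so that the softmax does not saturate.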
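The effect of the sqrt(d) scaling can be checked numerically. The sketch below uses random unit-variance tensors as stand-ins for the projected q and k. The real ones come from the query/key Linear layers, so their statistics will differ. split_heads is a hypothetical helper that reproduces the view + permute in transpose_for_scores. For independent zero-mean, unit-variance entries, a 64-term dot product has variance 64 (standard deviation 8). Dividing by 8 brings it back to about 1.

```python
import math
import torch

torch.manual_seed(0)
B, heads, tokens, head_dim = 16, 12, 197, 64
hidden = heads * head_dim                       # 768

def split_heads(t: torch.Tensor) -> torch.Tensor:
    # (B, 197, 768) -> (B, 197, 12, 64) -> (B, 12, 197, 64), like transpose_for_scores
    return t.view(t.size(0), t.size(1), heads, head_dim).permute(0, 2, 1, 3)

q = split_heads(torch.randn(B, tokens, hidden))
k = split_heads(torch.randn(B, tokens, hidden))

scores = torch.matmul(q, k.transpose(-1, -2))   # (16, 12, 197, 197): every token vs every token
scaled = scores / math.sqrt(head_dim)           # divide by sqrt(64) = 8
probs = torch.softmax(scaled, dim=-1)           # each row is a distribution over the 197 tokens

print(scores.shape)
print(scores.std().item())   # roughly 8 before scaling
print(scaled.std().item())   # roughly 1 after scaling
```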

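The softmax weights are then used to take a weighted sum of V. After that the heads are merged back together: (16, 12, 197, 64) → (16, 197, 12, 64) → (16, 197, 768), followed by the output projection self.out.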
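Here is a minimal sketch of the weighting and head merge with random stand-in tensors, assuming PyTorch. permute produces a non-contiguous tensor, which is why Attention.forward calls contiguous() before view. As a sanity check, the sketch uses identity attention (each token attends only to itself). In that case split → weight → merge should return the input unchanged.

```python
import torch

B, heads, tokens, head_dim = 2, 12, 197, 64
hidden = heads * head_dim

x = torch.randn(B, tokens, hidden)                               # stand-in for the value projection
v = x.view(B, tokens, heads, head_dim).permute(0, 2, 1, 3)       # (2, 12, 197, 64)

probs = torch.softmax(torch.randn(B, heads, tokens, tokens), dim=-1)
context = torch.matmul(probs, v)          # (2, 12, 197, 64): per-head weighted sum of V rows

# Merge heads as in Attention.forward: permute back, make contiguous, then view to 768
merged = context.permute(0, 2, 1, 3).contiguous()               # (2, 197, 12, 64)
merged = merged.view(*merged.size()[:-2], hidden)                # (2, 197, 768)
print(merged.shape)

# Identity attention: the round trip should reproduce the input
eye = torch.eye(tokens).expand(B, heads, tokens, tokens)
roundtrip = torch.matmul(eye, v).permute(0, 2, 1, 3).contiguous().view(B, tokens, hidden)
print(torch.allclose(roundtrip, x))
```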

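After attention, execution returns to Block for the residual connection. ffn is the fully connected feed-forward network, and it is followed by a second residual connection.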
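The full pre-norm pattern in Block.forward can be sketched as follows: LayerNorm, then attention, then a residual add; LayerNorm, then MLP, then a second residual add. The Mlp class is not shown in this post, so MlpSketch assumes a common ViT-style feed-forward (Linear → GELU → Dropout → Linear → Dropout). It also assumes mlp_dim = 3072 and dropout 0.1 for ViT-B. For brevity, PyTorch's built-in nn.MultiheadAttention is swapped in for the post's Attention class. Its returned weights are averaged over heads by default.

```python
import torch
import torch.nn as nn

class MlpSketch(nn.Module):
    """Assumed ViT-style feed-forward: Linear -> GELU -> Dropout -> Linear -> Dropout."""
    def __init__(self, hidden=768, mlp_dim=3072, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(hidden, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, hidden)
        self.act = nn.GELU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        return self.drop(self.fc2(self.drop(self.act(self.fc1(x)))))

class BlockSketch(nn.Module):
    """Pre-norm block in the same order as Block.forward."""
    def __init__(self, hidden=768, heads=12):
        super().__init__()
        self.attention_norm = nn.LayerNorm(hidden, eps=1e-6)
        self.ffn_norm = nn.LayerNorm(hidden, eps=1e-6)
        # Swapped in: PyTorch's nn.MultiheadAttention instead of the post's Attention class
        self.attn = nn.MultiheadAttention(hidden, heads, batch_first=True)
        self.ffn = MlpSketch(hidden)

    def forward(self, x):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x, x, x)   # self-attention: query = key = value
        x = x + h                         # first residual connection
        h = x
        x = self.ffn(self.ffn_norm(x))
        x = x + h                         # second residual connection
        return x, weights

block = BlockSketch().eval()              # eval() disables dropout
out, w = block(torch.randn(2, 197, 768))
print(out.shape, w.shape)                 # block output and head-averaged attention weights
```

The residual adds are why each block's output keeps the (B, 197, 768) shape, so the blocks can be stacked num_layers times in Encoder.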
