MINICPM-V2_6图像得到embedding-代码解读

目的

基于上一篇MINICPM-V2_6图像预处理流程-代码解读将输入图片得到了input_ids、attention_mask、pixel_values、image_sizes、image_bound、tgt_sizes,但是要怎么通过这些得到图片对应的embedding呢?
这里接着从MINICPM-V2_6入手,了解如何从图像得到embedding的过程

随机位置编码

Randomized Positional Encodings Boost Length Generalization of Transformers
随机位置编码
因为图片的像素不统一,所以位置编码需要设置的比较大(L=2000)。假设图片对应的长度为N(N=40),训练阶段原本长度为N的序列对的位置序列是[0,1,⋯,N−2,N−1],现在改为从{0,1,⋯,L−2,L-1}中均匀地选N个点(0,50,100,。。。),作为当前序列的位置序列。这就解决了预测阶段的位置编码没有被训练过的问题。
这里的代码里用了这个思想,但是这里会复杂得多,因为这里是将2D的位置均匀的映射到[70*70]上面

代码

基本变量

import torch
from torch import nn

# 继承上篇的结果
# 图片块大小
tgt_sizes = torch.tensor([[28, 37],
        [39, 26],
        [39, 26]])
        
# 图片在inputs_id中的位置
image_bound = [torch.tensor([[ 18,  82],
        [ 84, 148],
        [150, 214]])]

# pixel_values,即最后得到的images
img1 = torch.randn(3,14,14504)
img2 = torch.randn(3,14,14196)
img3 = torch.randn(3,14,14196)
pixel_values_list = [[img1, img2, img3]]

max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])# 最大的块尺寸=28*37=1036,这里对应的是上一篇中的原始图片大小

# 将图片patch补齐到统一尺寸,作为一个batch
all_pixel_values = []
for pixel_values in pixel_values_list:
    all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])
    # pixel_values[0] 3*14*14504
    # i.flatten(end_dim=1) 把前两维铺平 42,14504
    # i.flatten(end_dim=2) 把前三维铺平 609168
    # i.flatten(end_dim=1).permute(1, 0) 转置 14504,42
    # all_pixel_values包含三个矩阵[14504*42,14196*42,14196*42]

all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True, padding_value=0.0)# [3, 14504, 42] 会将不够的位置后面补齐0,这里的3指的是

B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)# [3, 3, 14, 14504] 第一个3是图片数量,第二个3是通道数,14是patch_size,14504是最大的块尺寸*14
# 因为有些patch的像素点是经过pad的,所以需要找到哪些块是经过pad的
patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool)# 3,1,1306
for i in range(B):# 将图片块的位置填充为True
    patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
# 即第一张图片对应的patch_attn_mask[0,0,:]=True
# 第二、三张图片对应的patch_attn_mask[1:,0,39*26:]均为False

到这里就得到all_pixel_values,patch_attn_mask, tgt_sizes
可以进行下一步去得到embedding了

函数定义

def get_embedding(all_pixel_values):
    """
    输入:all_pixel_values 经过拼接后的像素点
    输出:embeddings
    demo:
    batch_size = 3
    num_channels = 3
    patch_size = 14
    h,w = 28,37
    num = patch_size * h * w
    all_pixel_values = torch.randn(batch_size, num_channels, patch_size, num)
    embeddings = get_embedding(all_pixel_values)
    # batch_size,1 * h * w,1152
    """
    num_channels = 3
    embed_dim = 1152
    patch_size = 14
    batch_size = 1
    patch_embedding = nn.Conv2d(
                in_channels=num_channels,# 3
                out_channels=embed_dim,# 1152
                kernel_size=patch_size,# 14
                stride=patch_size,# 14
                padding="valid",
            )# 像素点到embedding的过程是通过这个卷积操作完成的
    
    patch_embeds = patch_embedding(all_pixel_values)# patch_embeds是卷积后的patch [3, 1152, 1, 1036] 1036 = 14504/14 1 = 14/14
    embeddings = patch_embeds.flatten(2).transpose(1, 2)# batch_size,1*1036,1152
    # 到这里,像素点就完成了到embedding到过程
    return embeddings

def get_position_embedding(all_pixel_values, patch_attn_mask, tgt_sizes):
    """
    输入:all_pixel_values 经过拼接后的像素点
         patch_attn_mask mask矩阵
         tgt_sizes 图片尺寸大小
    输出:位置embeddings
    demo:
    batch_size = 3
    num_channels = 3
    patch_size = 14
    h,w = 28,37
    num = patch_size * h * w
    all_pixel_values = torch.randn(batch_size, num_channels, patch_size, num)
    tgt_sizes = torch.tensor([[28, 37],
        [39, 26],
        [39, 26]])
    patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool)# 3,1,1306
    for i in range(B):# 将图片块的位置填充为True
        patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
    embeddings = get_position_embedding(all_pixel_values, patch_attn_mask, tgt_sizes)
    # batch_size,1 * h * w,1152
    """
    embed_dim = 1152
    num_patches_per_side = 70
    num_positions = num_patches_per_side**2
    position_embedding = nn.Embedding(num_positions, embed_dim)# 4900*1152
    batch_size = all_pixel_values.size(0)# all_pixel_values原始图片大小 batch_size=3
    max_im_h, max_im_w = all_pixel_values.size(2), all_pixel_values.size(3)# max_im_h=14,max_im_w=14504
    max_nb_patches_h, max_nb_patches_w = max_im_h // patch_size, max_im_w // patch_size# 1 1036
    position_ids = torch.full(
                size=(
                    batch_size,
                    max_nb_patches_h * max_nb_patches_w,
                ),
                fill_value=0,
            )# 3,1 * 1036 全0
    boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)
    # 从1/70开始到1,间隔是1/70,共69个数,注意torch.arange是左闭右开的,不包含1
    # [0.0143, 0.0286,...,0.9714, 0.985]    
    for batch_idx, p_attn_mask in enumerate(patch_attn_mask):
        if tgt_sizes is not None:
            nb_patches_h = tgt_sizes[batch_idx][0]# 28
            nb_patches_w = tgt_sizes[batch_idx][1]# 37
        else:
            nb_patches_h = patch_attn_mask[:, 0].sum()
            nb_patches_w = patch_attn_mask[0].sum()
        fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)# 生成从0到1 - 1e-6,间隔是1 / nb_patches_h,共28个数
        fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)# 生成从0到1 - 1e-6,间隔是1 / nb_patches_w,共37个数
        bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
        # 从boundaries中找到fractional_coords_h该差值的地方
        # 根据boundaries序列返回fractional_coords_h中每个元素的区间索引
        # boundaries = [0.0143, 0.0286, 0.0429, 0.0571, 0.0714, 0.0857, 0.1000, 0.1143, 0.1286...]
        # fractional_coords_h = [0.0000, 0.0357, 0.0714, 0.1071, 0.1429, 0.1786, 0.2143, 0.2500, 0.2857...]
        # 0.0000<0.0143,找到索引0
        # 0.0286<0.0357<0.0429,找到索引2
        # 0.0714=0.0714<0.0857,找到索引5
        # 0.1000=0.1071<0.1143,找到索引7
        # [ 0,  2,  5,  7, 10, 12, 15, 17, 20, 22, 25, 27, 30, 32, 35, 37, 40, 42, 45, 47, 50, 52, 55, 57, 60, 62, 65, 67]
        bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
        # [ 0,  1,  3,  5,  7,  9, 11, 13, 15, 17, 18, 20, 22, 24, 26, 28, 30, 32, 34, 35, 37, 39, 41, 43, 45, 47, 49, 51, 52, 54, 56, 58, 60, 62, 64, 66, 68]
        pos_ids = (bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w).flatten()
        # bucket_coords_h[:, None] * num_patches_per_side得到0 140 350...
        # bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w 
        # [   0,    1,    3,  ...,   64,   66,   68],
        # [ 140,  141,  143,  ...,  204,  206,  208],
        # [ 350,  351,  353,  ...,  414,  416,  418] 。。。
        # (bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w).flatten()得到
        # tensor([   0,    1,    3,  ..., 4754, 4756, 4758])
        # 这样把位置id就映射到4900上了
        position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
        # 这个看起来复杂的过程是将[28,37]找到对应的位置id
        # 这里使用的是随机位置编码的方法
        # 这里是将2D均匀映射到[70,70]上面
        """
        tensor([[   0,    1,    3,  ..., 4754, 4756, 4758],
                [   0,    2,    5,  ...,    0,    0,    0],
                [   0,    2,    5,  ...,    0,    0,    0]])
        """
    embeddings = position_embedding(position_ids)# 3,1036,1152
    return embeddings

函数调用

ori_embeddings = get_embedding(all_pixel_values)# batch_size,1 * h * w,1152
pos_embeddings = get_position_embedding(all_pixel_values, patch_attn_mask, tgt_sizes)# 3,h*w,1152
embeddings = ori_embeddings + pos_embeddings# 3,h*w,1152

额外说几句

为什么这里得到pos_ids会这么复杂呢?
将2D的[h,w]对应到2D的[70,70],按照我一开始想的,那将[h,w]拉平到h*w,直接映射到70*70多简单啊,但是看了代码就发现不是这么做的
代码中是将每一行单独映射到70,每一列也单独映射到70,这么说有点空
给个demo:
按照第一张图片的宽是37,高是28
每一行的尺寸是37,均匀映射到70对应的位置是[0 1 3 5 …]
每一列的尺寸是28,均匀映射到70对应的位置是[0 2 5 7…]
因为每一行都有70个位置,所以每一列的位置id都需要乘上70得到这一列的真实列位置id
[ 0 2 5 7 . . . 67 ] ∗ 70 = [ 0 140 350 490 . . . 4690 ] \begin{bmatrix} &0\\ &2\\ &5\\ &7\\ &...\\ &67 \end{bmatrix} *70=\begin{bmatrix} &0\\ &140\\ &350\\ &490\\ &...\\ &4690 \end{bmatrix} 0257...67 70= 0140350490...4690
那真实列id+每一行的映射id就得到了2D位置编码
[ 0 140 350 490 . . . 4690 ] + [ 0 1 3 5 . . . 68 ] = [ 0 1 3 . . . 68 140 141 143 . . . 208 350 351 353 . . . 418 490 491 493 . . . 558 . . . 4690 4691 4693 . . . 4758 ] \begin{bmatrix} &0\\ &140\\ &350\\ &490\\ &...\\ &4690 \end{bmatrix} +\begin{bmatrix} &0 &1 &3 &5 &... &68 \end{bmatrix}=\begin{bmatrix} &0&1&3&...&68\\ &140&141&143&...&208\\ &350&351&353&...&418\\ &490&491&493&...&558\\ &...\\ &4690&4691&4693&...&4758 \end{bmatrix} 0140350490...4690 +[0135...68]= 0140350490...46901141351491469131433534934693...............682084185584758
将这个位置编码拉平就是pos_ids了,随后就可以根据pos_ids得到对应的位置编码了

参考

modeling_minicpmv
modeling_navit_siglip

  • 15
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值