Some notes from reading transformers.models.pix2struct.image_processing_pix2struct.Pix2StructImageProcessor. I had always been curious about how the image patches are produced, so I walked through the code. The main flow is step 1 through step 5 in the sketch below (the intermediate steps can be skipped; step 3 is not fully captured in the screenshot). Each row in step 5 is simply one patch of the step-1 (original) image flattened out; it is not what I had imagined, i.e. the 3 channels at the corresponding patch positions flattened directly. This makes me wonder: when doing OCR, for example, won't the pixels at positions 9, 10, 13 and 14 be forcibly split apart? Can the model still recognize them correctly?
Here is a sketch of my notes:
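To make the patch-flattening step concrete, here is a small toy re-implementation of the idea. This is only my own sketch, not the library code: the real Pix2StructImageProcessor additionally rescales the image so that at most max_patches patches fit, normalizes it, pads the sequence, and uses its own torch_extract_patches helper (so the exact ordering of values inside each flattened row may differ). The helper name toy_flatten_patches is hypothetical. Each output row is one 16x16 patch laid out flat, prefixed with its 1-based row and column index:

```python
import torch
import torch.nn.functional as F

def toy_flatten_patches(image: torch.Tensor, patch_h: int = 16, patch_w: int = 16) -> torch.Tensor:
    # image: (channels, height, width); assumes height/width are divisible by the patch size
    c, h, w = image.shape
    n_rows, n_cols = h // patch_h, w // patch_w
    # unfold with stride == kernel size cuts the image into non-overlapping patches
    patches = F.unfold(image.unsqueeze(0), (patch_h, patch_w), stride=(patch_h, patch_w))
    # (1, c*patch_h*patch_w, n_rows*n_cols) -> one flattened patch per row
    patches = patches.squeeze(0).transpose(0, 1)  # (n_rows*n_cols, c*patch_h*patch_w)
    # 1-based row/column index of every patch, stored in two extra leading columns
    row_idx = torch.arange(1, n_rows + 1).repeat_interleave(n_cols).unsqueeze(1).float()
    col_idx = torch.arange(1, n_cols + 1).repeat(n_rows).unsqueeze(1).float()
    return torch.cat([row_idx, col_idx, patches], dim=1)  # (n_rows*n_cols, 2 + c*patch_h*patch_w)

img = torch.rand(3, 64, 64)
flat = toy_flatten_patches(img)
print(flat.shape)  # torch.Size([16, 770]): 16 patches, 2 index columns + 3*16*16 pixel values
```

The two leading columns are exactly the row/column indices that the embedding code further down strips off before the linear projection.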
Finally, the row and column indices are first stripped out of the result above, the remaining patch pixels are passed through a linear layer (projecting to 768 dimensions here), and the row/column position embeddings are added on top to obtain the final image embeddings. The code is as follows:
def forward(self, flattened_patches: torch.Tensor) -> torch.Tensor:
    # the row and column indices are stored in the first and second position of the flattened_patches
    # flattened_patches: `batch_size`, `seq_len`, `hidden_size` + 2
    row_indices = flattened_patches[:, :, 0].long()
    col_indices = flattened_patches[:, :, 1].long()

    flattened_patches = flattened_patches[:, :, 2:]

    embeddings = self.patch_projection(flattened_patches)
    row_embeddings = self.row_embedder(row_indices)
    col_embeddings = self.column_embedder(col_indices)

    # sum all embeddings together
    embeddings = embeddings + row_embeddings + col_embeddings
    embeddings = self.dropout(embeddings)

    return embeddings
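As a quick sanity check on the shapes, the embedding module can be instantiated and fed a dummy flattened_patches tensor. This is a sketch under my assumptions: the import path transformers.models.pix2struct.modeling_pix2struct.Pix2StructVisionEmbeddings is where I read the forward above, and I assume the default Pix2StructVisionConfig has hidden_size and patch_embed_hidden_size both equal to 768, so each row carries 2 + 16*16*3 = 770 values.

```python
import torch
from transformers import Pix2StructVisionConfig
from transformers.models.pix2struct.modeling_pix2struct import Pix2StructVisionEmbeddings

# Dummy flattened_patches: (batch, seq_len, 2 + 16*16*3); the first two values of each
# row are the 1-based row/column indices, the rest are the flattened patch pixels.
config = Pix2StructVisionConfig()          # assumed defaults: hidden_size=768, patch_embed_hidden_size=768
embedder = Pix2StructVisionEmbeddings(config)

flattened_patches = torch.rand(1, 2048, 770)
flattened_patches[:, :, 0] = 1.0           # placeholder row indices
flattened_patches[:, :, 1] = 1.0           # placeholder column indices

out = embedder(flattened_patches)
print(out.shape)                           # expected: torch.Size([1, 2048, 768])
```

So whatever the patch layout inside each 770-dim row, after this module every patch becomes a single 768-dim token, and its position is carried only by the added row/column embeddings.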