分层图像金字塔变压器

本文介绍了一种名为分层图像金字塔变换器(HIPT)的新型视觉变换器,专门用于处理计算病理学中的全幻灯片图像(WSI)。通过自我监督学习,HIPT在多种癌症类型的大型数据集上预训练,表现出色,尤其在癌症亚型分型和生存预测任务中。HIPT利用分层结构进行图像表示学习,逐步聚合从细胞到组织的特征,捕捉不同粒度的信息。
摘要由CSDN通过智能技术生成

文章来源:hierarchical-image-pyramid-transformers

2024 年 2 月 5 日

本文介绍了分层图像金字塔变换器 (HIPT),这是一种新颖的视觉变换器 (ViT) 架构,设计用于分析计算病理学中的十亿像素全幻灯片图像 (WSI)。 HIPT 利用 WSI 固有的层次结构通过自我监督学习来学习高分辨率图像表示。 HIPT 在涵盖 33 种癌症类型的大型数据集上进行预训练,并在多个幻灯片级任务中进行评估,在癌症亚型分型和生存预测方面表现出卓越的性能,展示了自我监督学习模型在捕获肿瘤微环境中关键的归纳偏差和表型方面的潜力。

1

本图展示了计算病理学中使用的全切片图像 (WSI) 的分层结构。左图显示的是多层次方法,在这种方法中,大型组织图像(150,000 x 150,000 像素)被分解成更小、更易于管理的部分:首先是显示组织表型的 4096 x 4096 区域,然后是 256 x 256 细胞组织斑块,最后是最小的 16 x 16 细胞特征。右图展示了 256 x 256 图像是如何由 256 个较小的 16 x 16 标记序列组成的,反过来,每个 256 x 256 图像又是如何成为 4096 x 4096 区域内 256 x 256 标记的更大的不连续序列的一部分。这种分层标记化方法可以处理和分析不同分辨率和比例的超大图像。

该模型由三个阶段的分层聚合组成,首先是自下而上地聚合各自 256x256 和 4096x4096 窗口中的 16x16 视觉标记,最终形成幻灯片级表示。HIPT 模型的主要组成部分可写如下:

1. 分层聚合: HIPT 在细胞、斑块和区域层面聚合视觉标记,形成幻灯片表征。这种分层方法是受自然语言处理中使用分层表示法的启发,在自然语言处理中,嵌入可以在不同层次上聚合,形成文档表示法。同样,在 WSI 的背景下,分层聚合允许模型捕捉不同粒度级别的信息,从单个细胞到更广泛的组织结构。

2. Transformer自注意力: 为了在聚合的每个阶段对视觉概念之间的重要依赖关系进行建模,HIPT 将 Transformer 自注意力调整为包络变换聚合层。这样,该模型就能捕捉视觉标记之间的复杂关系,并学习能编码图像中局部和全局上下文的表征。

3. 预训练和自我监督学习: HIPT 采用自我监督学习的方式对 33 种癌症类型的千兆像素 WSI 大数据集进行预训练。该模型利用两个层次的自我监督学习来学习高分辨率图像表征,并利用学生-教师知识提炼来对每个聚合层进行预训练,对大至 4096x4096 的区域进行自我监督学习。

4. 性能和应用: 研究结果表明,采用分层预训练的 HIPT 在幻灯片级任务上的表现优于目前最先进的方法。该模型的性能在包括癌症亚型和生存预测在内的 9 项幻灯片级任务上进行了评估,并显示其在捕捉组织微环境中更广泛的预后特征方面表现出色。

2

图中从左到右显示了三个聚合级别:

  1. 细胞级聚合: 单个细胞由 16 px tokens表示,然后使用 ViT256-16 模型将其聚合为片段级表示,再进行全局池化以获得单一矢量表示。
  2. 斑块级聚合: 使用专为 256 px 输入设计的更大 ViT 变体来处理 256 px 补丁,然后再次使用池化层将补丁级特征汇总为区域级表示。
  3. 区域级聚合: 最后,对 4096 px 的区域进行聚合,这一次使用的是将整个区域作为输入的 ViT,从而形成一个全局注意力汇集层,提供幻灯片级表示。

这一分层过程将问题分解为易于处理的部分,并关注从细胞到组织结构等不同层次的细节,从而使模型能够处理规模巨大的 WSI。

下面的脚本利用了专门用于高分辨率图像分析的视觉转换器(ViTs),并结合了几种先进的功能和技术:

1. 截断法线初始化: 这是一种用于初始化神经网络权重的技术,可避免与平均值产生较大偏差,从而确保早期训练阶段的稳定性。

2. Drop Path: 一种正则化方法,在训练过程中随机丢弃网络中的路径,通过模拟更薄的网络来提高泛化效果,类似于dropout,但针对的是残余连接。

3. 多层感知器(MLP)模块: 定义一个简单的双层 MLP,具有 GELU 激活函数和滤除功能,用于在转换器模块中处理特征。

4. 注意机制:采用可选偏置和缩放的自注意机制,这对捕捉输入数据中的全局依赖性至关重要。

5. Transformer模块: 将规范层、注意机制和 MLP 组合成一个内聚块,并可选择路径剔除进行正则化。

6. VisionTransformer4K:Vision Transformer 的专用版本,专为超高分辨率图像而设计,采用了位置嵌入插值等技术,以适应不同的图像尺寸,其结构也针对处理大规模图像进行了优化。

7. 实用功能: 包括用于截断法线权重初始化、下落路径模拟和参数计算的函数,以帮助进行模型设置和分析。

import argparse
import os
import sys
import datetime
import time
import math
import json
from pathlib import Path
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision import models as torchvision_models
import vision_transformer as vits
from vision_transformer import DINOHead
import math
from functools import partial
import torch
import torch.nn as nn
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.
    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)
    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)
        # Uni
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值