Swin-UNET -- TensorFlow

Model download address (GitHub):

GitHub - yingkaisha/keras-unet-collection: The Tensorflow, Keras implementation of U-net, V-net, U-net++, UNET 3+, Attention U-net, R2U-net, ResUnet-a, U^2-Net, TransUNET, and Swin-UNET with optional ImageNet-trained backbones.

File list:

_model_swin_unet_2d.py -- main model workflow

layer_utils.py

transformer_layers.py

activations.py

swin_unet_2d:

Main parameters referenced by the code:

input_size: the size/shape of network input, e.g., `(128, 128, 3)`.
filter_num_begin: number of channels in the first downsampling block; 
                  it is also the number of embedding dimensions. Default: 96
n_labels: number of output labels.
depth: the depth of Swin-UNET, e.g., depth=4 means three down/upsampling levels and a bottom level.
stack_num_down: number of Swin Transformer blocks per downsampling level/block. 
stack_num_up: number of Swin Transformer blocks (after concatenation) per upsampling level/block.
name: prefix of the created keras model and its layers.
def swin_unet_2d(input_size, filter_num_begin, n_labels, depth, stack_num_down, stack_num_up, 
                      patch_size, num_heads, window_size, num_mlp, output_activation='Softmax', shift_window=True, name='swin_unet'):   
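
As a minimal usage sketch (the call follows the keras-unet-collection README style; the hyperparameter values here are illustrative, not prescriptive):

from keras_unet_collection import models

# Swin-UNET for 3-class segmentation of 128x128 RGB images.
# num_heads and window_size take one entry per depth level.
model = models.swin_unet_2d((128, 128, 3), filter_num_begin=64, n_labels=3, depth=4,
                            stack_num_down=2, stack_num_up=2,
                            patch_size=(4, 4), num_heads=[4, 8, 8, 8],
                            window_size=[4, 2, 2, 2], num_mlp=512,
                            output_activation='Softmax', shift_window=True,
                            name='swin_unet')
model.compile(loss='categorical_crossentropy', optimizer='adam')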

swin_unet_2d_base:

def swin_unet_2d_base(input_tensor, filter_num_begin, depth, stack_num_down, stack_num_up, patch_size, num_heads, window_size, num_mlp, shift_window=True, name='swin_unet'):
    # Compute the number of patches to be embedded
    input_size = input_tensor.shape.as_list()[1:]
    num_patch_x = input_size[0]//patch_size[0]  # e.g., 224//4 = 56
    num_patch_y = input_size[1]//patch_size[1]
    
    # Number of embedding dimensions
    embed_dim = filter_num_begin  # e.g., 96
    
    depth_ = depth
    
    X_skip = []

    X = input_tensor  # e.g., (224, 224, 3)
    
    # Patch extraction
    X = patch_extract(patch_size)(X)

    # Embed patches to tokens
    X = patch_embedding(num_patch_x*num_patch_y, embed_dim)(X)
    
    # The first Swin Transformer stack
    X = swin_transformer_stack(X, stack_num=stack_num_down, 
                               embed_dim=embed_dim, num_patch=(num_patch_x, num_patch_y), 
                               num_heads=num_heads[0], window_size=window_size[0], num_mlp=num_mlp, 
                               shift_window=shift_window, name='{}_swin_down0'.format(name))
    X_skip.append(X)
    
    # Downsampling blocks
    for i in range(depth_-1):
        
        # Patch merging
        X = patch_merging((num_patch_x, num_patch_y), embed_dim=embed_dim, name='down{}'.format(i))(X)
        
        # update token shape info
        embed_dim = embed_dim*2
        num_patch_x = num_patch_x//2
        num_patch_y = num_patch_y//2
        
        # Swin Transformer stacks
        X = swin_transformer_stack(X, stack_num=stack_num_down, 
                                   embed_dim=embed_dim, num_patch=(num_patch_x, num_patch_y), 
                                   num_heads=num_heads[i+1], window_size=window_size[i+1], num_mlp=num_mlp, 
                                   shift_window=shift_window, name='{}_swin_down{}'.format(name, i+1))
        
        # Store tensors for concat
        X_skip.append(X)
        
    # reverse indexing encoded tensors and hyperparams
    X_skip = X_skip[::-1]
    num_heads = num_heads[::-1]
    window_size = window_size[::-1]
    
    # upsampling begins at the deepest available tensor
    X = X_skip[0]
    
    # other tensors are preserved for concatenation
    X_decode = X_skip[1:]
    
    depth_decode = len(X_decode)
    
    for i in range(depth_decode):
        
        # Patch expanding
        X = patch_expanding(num_patch=(num_patch_x, num_patch_y),
                            embed_dim=embed_dim, upsample_rate=2, return_vector=True, name='{}_swin_up{}'.format(name, i))(X)
        

        # update token shape info
        embed_dim = embed_dim//2
        num_patch_x = num_patch_x*2
        num_patch_y = num_patch_y*2
        
        # Concatenation and linear projection
        X = concatenate([X, X_decode[i]], axis=-1, name='{}_concat_{}'.format(name, i))
        X = Dense(embed_dim, use_bias=False, name='{}_concat_linear_proj_{}'.format(name, i))(X)
        
        # Swin Transformer stacks
        X = swin_transformer_stack(X, stack_num=stack_num_up, 
                           embed_dim=embed_dim, num_patch=(num_patch_x, num_patch_y), 
                           num_heads=num_heads[i], window_size=window_size[i], num_mlp=num_mlp, 
                           shift_window=shift_window, name='{}_swin_up{}'.format(name, i))
        
    # The last expanding layer; it produces full-size feature maps based on the patch size
    # !!! <--- "patch_size[0]" is used; it assumes patch_size = (size, size)
    X = patch_expanding(num_patch=(num_patch_x, num_patch_y),
                        embed_dim=embed_dim, upsample_rate=patch_size[0], return_vector=False)(X)
    
    return X
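
To make the shape bookkeeping concrete, the following standalone sketch (not part of the library) traces the token grid and embedding dimension through the encoder and decoder, assuming a 224x224x3 input, patch_size=(4, 4), filter_num_begin=96, and depth=4:

# Trace token shapes through swin_unet_2d_base for a 224x224x3 input.
input_size = (224, 224, 3)
patch_size = (4, 4)
embed_dim = 96   # filter_num_begin
depth = 4

num_patch_x = input_size[0] // patch_size[0]  # 56
num_patch_y = input_size[1] // patch_size[1]  # 56

# Encoder: each patch_merging halves the grid and doubles the channels
for i in range(depth):
    print('down{}: tokens = {}x{} = {}, embed_dim = {}'.format(
          i, num_patch_x, num_patch_y, num_patch_x*num_patch_y, embed_dim))
    if i < depth - 1:
        num_patch_x //= 2; num_patch_y //= 2; embed_dim *= 2

# Decoder: each patch_expanding doubles the grid and halves the channels
for i in range(depth - 1):
    num_patch_x *= 2; num_patch_y *= 2; embed_dim //= 2
    print('up{}: tokens = {}x{} = {}, embed_dim = {}'.format(
          i, num_patch_x, num_patch_y, num_patch_x*num_patch_y, embed_dim))

# down0: 56x56 = 3136 tokens, embed_dim 96; down3: 7x7 = 49 tokens, embed_dim 768.
# The final patch_expanding then upsamples by patch_size[0] = 4,
# restoring full-resolution 224x224 feature maps.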

Here, patch extraction is `X = patch_extract(patch_size)(X)`, with patch_size = (4, 4).

Parameter notes:

stack_num_down: number of Swin Transformer blocks per downsampling level/block. Default: 2

stack_num_up: number of Swin Transformer blocks (after concatenation) per upsampling level/block. Default: 2

depth: depth of Swin-UNET; depth=4 means three down/upsampling levels plus the bottom level. Default: 4

Patch extraction pulls multiple patches from a single image; each patch is 4x4 and the sliding stride is 4x4:

import tensorflow as tf
from tensorflow.image import extract_patches
from tensorflow.keras.layers import Layer

class patch_extract(Layer):
    def __init__(self, patch_size, **kwargs):
        super(patch_extract, self).__init__(**kwargs)
        self.patch_size = patch_size
        self.patch_size_x = patch_size[0]
        self.patch_size_y = patch_size[1]
    
    def call(self, images):
        
        batch_size = tf.shape(images)[0]
        # tf.image.extract_patches slides a patch-sized window with a patch-sized stride
        patches = extract_patches(images=images,
                                  sizes=(1, self.patch_size_x, self.patch_size_y, 1),
                                  strides=(1, self.patch_size_x, self.patch_size_y, 1),
                                  rates=(1, 1, 1, 1), padding='VALID',)
        # patches.shape = (num_sample, patch_num, patch_num, patch_size_x*patch_size_y*channel)
        
        patch_dim = patches.shape[-1]
        patch_num = patches.shape[1]
        # flatten the 2D patch grid into a 1D token sequence
        patches = tf.reshape(patches, (batch_size, patch_num*patch_num, patch_dim))
        # patches.shape = (num_sample, patch_num*patch_num, patch_size_x*patch_size_y*channel)
        
        return patches

Notes on patches.shape:

num_sample: the batch size, i.e., the number of input images.

channel: the number of channels of the input image (e.g., 3 for RGB); each flattened patch therefore has patch_size_x*patch_size_y*channel values.
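
A quick shape check of patch_extract (a standalone sketch; it assumes the class above is in scope):

import tensorflow as tf

images = tf.random.normal((2, 224, 224, 3))  # num_sample = 2, channel = 3

patches = patch_extract((4, 4))(images)
print(patches.shape)  # (2, 3136, 48): 56*56 = 3136 patches, each 4*4*3 = 48 values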

def swin_transformer_stack(X, stack_num, embed_dim, num_patch, num_heads, window_size, num_mlp, shift_window=True, name=''):
    '''
    Stacked Swin Transformers that share the same token size.
    
    Alternating W-MSA and shifted-window (SW-MSA) blocks are configured if `shift_window=True`, W-MSA only otherwise.
    *Dropout is turned off.
    '''
    # Turn-off dropouts
    mlp_drop_rate = 0 # Dropout after each MLP layer
    attn_drop_rate = 0 # Dropout after Swin-Attention
    proj_drop_rate = 0 # Dropout at the end of each Swin-Attention block, i.e., after linear projections
    drop_path_rate = 0 # Drop-path within skip-connections
    
    qkv_bias = True # Add a learnable bias to the query/key/value projections
    qk_scale = None # None: re-scale the query by the per-head embedding dimension; a float gives a user-specified scaling factor
    
    if shift_window:
        shift_size = window_size // 2
    else:
        shift_size = 0
    
    for i in range(stack_num):
    
        if i % 2 == 0:
            shift_size_temp = 0
        else:
            shift_size_temp = shift_size

        X = SwinTransformerBlock(dim=embed_dim, num_patch=num_patch, num_heads=num_heads, 
                                 window_size=window_size, shift_size=shift_size_temp, num_mlp=num_mlp, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 mlp_drop=mlp_drop_rate, attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, drop_path_prob=drop_path_rate, 
                                 name='{}_block{}'.format(name, i))(X)
    return X

The boolean shift_window selects between W-MSA and SW-MSA: the stack alternates the two when shift_window=True and uses W-MSA only otherwise.

SwinTransformerBlock() is called in a loop; even-indexed blocks get shift_size=0 (W-MSA) and odd-indexed blocks get shift_size=window_size//2 (SW-MSA), as the sketch below illustrates.
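
A quick sketch of that alternation (mirroring the loop in swin_transformer_stack; window_size=7 is just an illustrative value):

window_size = 7
shift_size = window_size // 2  # 3

for i in range(4):
    shift = 0 if i % 2 == 0 else shift_size
    print('block {}: shift_size = {} -> {}'.format(i, shift, 'W-MSA' if shift == 0 else 'SW-MSA'))

# block 0: shift_size = 0 -> W-MSA
# block 1: shift_size = 3 -> SW-MSA
# block 2: shift_size = 0 -> W-MSA
# block 3: shift_size = 3 -> SW-MSA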

More details to follow.
