3D图像的swin transformer的tensorflow2.x实现

Swin UNETR 这类模型用于 3D 医学图像的分割,其编码器使用了 Swin Transformer。与单纯的 multi-head attention 不同,它还包含相对位置编码(relativeCoords)、移位窗口注意力掩码(mask)和 patch merging。一直以来我都不知道怎么在 TensorFlow 2.x 里实现这些,今天参考了 PyTorch 的 3D 实现和 TensorFlow 2.x 的 2D 实现,完成了 TensorFlow 2.x 的 3D 实现。

训练了一下,确实可以正常工作。

help_functions.py

import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf
from tensorflow.keras import  layers

def windowPartition(inputs,window_size):
    """Cut a (B, D, H, W, C) volume into non-overlapping cubic windows.

    Args:
        inputs: 5D tensor with static spatial dims, each assumed divisible by
            ``window_size``.
        window_size: edge length of the cubic window.

    Returns:
        Tensor of shape (B * numWindows, window_size, window_size, window_size, C).
        Windows are ordered by batch, then z-window, y-window, x-window.
    """
    ws = window_size
    _, d, h, w, c = inputs.shape
    grid = (d // ws, h // ws, w // ws)

    # Split every spatial axis into (window index, offset inside the window).
    blocked = tf.reshape(inputs,
                         shape=(-1, grid[0], ws, grid[1], ws, grid[2], ws, c))
    # Bring the three window-index axes in front of the in-window offsets.
    blocked = tf.transpose(blocked, (0, 1, 3, 5, 2, 4, 6, 7))
    return tf.reshape(blocked, shape=(-1, ws, ws, ws, c))

def windowReverse(windows,window_size,depth,height,width,channles):
    """Inverse of ``windowPartition``: stitch cubic windows back into a volume.

    Args:
        windows: tensor of (B * numWindows, ws, ws, ws, C) windows, ordered as
            produced by ``windowPartition``.
        window_size: edge length ``ws`` of each window.
        depth, height, width: spatial size of the reassembled volume.
        channles: channel count ``C``.

    Returns:
        Tensor of shape (B, depth, height, width, channles).
    """
    ws = window_size
    nz, ny, nx = depth // ws, height // ws, width // ws

    grid = tf.reshape(windows, shape=(-1, nz, ny, nx, ws, ws, ws, channles))
    # Interleave each window index with its in-window offset:
    # (B, nz, ws, ny, ws, nx, ws, C).
    grid = tf.transpose(grid, perm=(0, 1, 4, 2, 5, 3, 6, 7))
    return tf.reshape(grid, shape=(-1, depth, height, width, channles))

class PatchEmbedding(layers.Layer):
    """Linearly project flattened patches and add a learned position embedding.

    Args:
        num_patch: number of patch tokens per sample; the position embedding
            has one row per token index ``0 .. num_patch - 1``.
        embed_dim: size of the output embedding.
    """
    def __init__(self,num_patch,embed_dim,**kwargs):
        super(PatchEmbedding,self).__init__(**kwargs)

        self.numPatch = num_patch
        self.embedDim = embed_dim
        self.proj = layers.Dense(embed_dim)
        self.pos_embed = layers.Embedding(input_dim=num_patch,
                                          output_dim=embed_dim)

    def get_config(self):
        # Needed (like the other layers in this file) so models containing this
        # layer can be serialized and rebuilt via from_config.
        config = super(PatchEmbedding,self).get_config()
        config.update(
            {
                "num_patch":self.numPatch,
                "embed_dim":self.embedDim,
            }
        )
        return config

    def call(self,patch):
        # assumes patch is (batch, num_patch, patch_dim) — TODO confirm at call site;
        # the (num_patch, embed_dim) position table broadcasts over the batch axis.
        pos = tf.range(start=0,limit=self.numPatch,delta=1)
        return self.proj(patch) + self.pos_embed(pos)

class PatchExtract(layers.Layer):
    """Split a 5D volume into non-overlapping patches and flatten them into tokens.

    Input: (batch, D, H, W, C).
    Output: (batch, num_patches, pz * py * px * C), where
    num_patches = (D // pz) * (H // py) * (W // px).

    Args:
        patch_size: (pz, py, px) patch extent along depth, height, width.
    """
    def __init__(self,patch_size:list,**kwargs):
        super(PatchExtract,self).__init__(**kwargs)

        self.patchSizeZ = patch_size[0]
        self.patchSizeY = patch_size[1]
        self.patchSizeX = patch_size[2]

    def get_config(self):
        config = super(PatchExtract,self).get_config()
        config.update(
            {
                "patch_size":[self.patchSizeZ,self.patchSizeY,self.patchSizeX],
            }
        )
        return config

    def call(self,images):
        batchSize = tf.shape(images)[0]
        # tf.image.extract_patches only handles 4D (2D image) tensors; the
        # volumetric counterpart for 5D inputs is tf.extract_volume_patches.
        window = (1,self.patchSizeZ,self.patchSizeY,self.patchSizeX,1)
        patches = tf.extract_volume_patches(images,
                                            ksizes=window,
                                            strides=window,
                                            padding="VALID")
        # -> b, D//pz, H//py, W//px, pz*py*px*c
        patchDim = patches.shape[-1]
        # Flatten the (possibly non-cubic) patch grid into one token axis.
        return tf.reshape(patches,(batchSize,-1,patchDim))


class PatchMerging(layers.Layer):
    """Halve the (D, H, W) token grid by merging each 2x2x2 neighbourhood.

    The 8 tokens of every 2x2x2 cell are concatenated (8*C channels) and
    linearly projected to ``2 * embed_dim`` channels with a bias-free Dense.

    Args:
        num_patch: (D, H, W) grid the incoming token sequence is reshaped to;
            every entry must be even.
        embed_dim: base embedding size; the output has ``2 * embed_dim``
            channels (presumably equal to the input channel count C — TODO confirm).
    """
    def __init__(self,num_patch:list,embed_dim,**kwargs):
        super(PatchMerging,self).__init__(**kwargs)

        self.numPatch = num_patch
        self.embedDim = embed_dim

        self.linearTrans = layers.Dense(2*self.embedDim,use_bias=False)

    def get_config(self):
        config = super(PatchMerging,self).get_config()
        config.update(
            {
                "num_patch":self.numPatch,
                "embed_dim":self.embedDim,
            }
        )
        return config

    def call(self,x):
        depth,height,width = self.numPatch
        if depth % 2 or height % 2 or width % 2:
            raise ValueError(
                f"PatchMerging needs an even patch grid, got num_patch={tuple(self.numPatch)}")
        _,_,C = x.get_shape().as_list()
        # (B, D*H*W, C) -> (B, D, H, W, C)
        x = tf.reshape(x,shape=(-1,depth,height,width,C))
        # Gather all 8 distinct corners of each 2x2x2 cell; offsets are (dz, dy, dx).
        x0 = x[:, 0::2, 0::2, 0::2, :]  # (0, 0, 0)
        x1 = x[:, 1::2, 0::2, 0::2, :]  # (1, 0, 0)
        x2 = x[:, 0::2, 1::2, 0::2, :]  # (0, 1, 0)
        x3 = x[:, 0::2, 0::2, 1::2, :]  # (0, 0, 1)
        x4 = x[:, 1::2, 0::2, 1::2, :]  # (1, 0, 1)
        x5 = x[:, 1::2, 1::2, 0::2, :]  # (1, 1, 0)
        x6 = x[:, 0::2, 1::2, 1::2, :]  # (0, 1, 1)
        x7 = x[:, 1::2, 1::2, 1::2, :]  # (1, 1, 1)
        x = tf.concat((x0,x1,x2,x3,x4,x5,x6,x7),axis=-1)
        # -> (B, (D/2)*(H/2)*(W/2), 8*C)
        x = tf.reshape(x,shape=(-1,(depth//2)*(height//2)*(width//2),8*C))
        return self.linearTrans(x)

net_layers.py

import warnings

warnings.filterwarnings("ignore")
import os
import tensorflow_addons.layers as tfalayers

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, activations
import numpy as np

# 自定义函数
from nets.help_functions import  windowPartition,windowReverse


class DropPath(layers.Layer):
    """Stochastic depth ("drop path") for residual branches.

    During training, each sample in the batch is either zeroed entirely
    (probability ``drop_prob``) or kept and rescaled by ``1 / (1 - drop_prob)``.
    At inference, or when ``drop_prob`` is ``None``/0, the input is returned
    unchanged.

    Args:
        drop_prob: probability of dropping a sample's branch.
    """
    def __init__(self, drop_prob=None, **kwargs):
        super(DropPath, self).__init__(**kwargs)
        self.dropProb = drop_prob

    def get_config(self):
        config = super(DropPath,self).get_config()
        config.update(
            {
                "drop_prob":self.dropProb
            }
        )
        return config

    def call(self, x, training=None):
        # Stochastic depth is a training-time regulariser only: applying it at
        # inference would randomly zero residual branches of predictions.
        # `not self.dropProb` also covers the None default (1 - None would raise).
        if not training or not self.dropProb:
            return x
        keepProb = 1 - self.dropProb
        batch_size = tf.shape(x)[0]
        # One draw per sample, broadcast over every non-batch axis.
        shape = (batch_size,) + (1,) * (x.shape.rank - 1)
        random_tensor = keepProb + tf.random.uniform(shape, dtype=x.dtype)
        path_mask = tf.floor(random_tensor)  # 1 with prob keepProb, else 0
        return tf.math.divide(x, keepProb) * path_mask

class WindowAttention(layers.Layer):
    """Multi-head self-attention within one 3D window plus a learned relative
    position bias (3D variant of Swin Transformer window attention).

    Args:
        dim: channel size of the input tokens. NOTE(review): presumably must be
            divisible by num_heads, otherwise the qkv reshape in call fails.
        window_size: (Wd, Wh, Ww) window extent; sizes the relative position
            bias table and index.
        num_heads: number of attention heads.
        qkv_bias: whether the fused q/k/v projection has a bias.
        dropout_rate: rate of the single Dropout layer, which is applied both
            to the attention probabilities and to the projected output.
    """
    def __init__(self, dim, window_size: list, num_heads, qkv_bias=True,
                 dropout_rate=0.0, **kwargs):
        super(WindowAttention, self).__init__(**kwargs)

        self.dim = dim
        self.windowSize = window_size
        self.numHeads = num_heads
        self.qkvBias = qkv_bias
        self.dropoutRate = dropout_rate

        # 1/sqrt(head_dim) scaling applied to the queries.
        self.scale = (self.dim // self.numHeads) ** (-0.5)
        # Single projection producing q, k and v concatenated on the last axis.
        self.qkv = layers.Dense(self.dim * 3, use_bias=self.qkvBias)
        self.dropout = layers.Dropout(self.dropoutRate)
        self.proj = layers.Dense(self.dim)

    def get_config(self):
        config = super(WindowAttention,self).get_config()
        config.update(
            {
                "dim":self.dim,
                "window_size":self.windowSize,
                "num_heads":self.numHeads,
                "qkv_bias":self.qkvBias,
                "dropout_rate":self.dropoutRate
            }
        )
        return config

    def build(self, input_shape):
        # One table row per possible 3D offset between two tokens of a window
        # ((2*W-1) values per axis) and one column per head; zero-initialised.
        numWindowElements = (2 * self.windowSize[0] - 1) * (2 * self.windowSize[1] - 1) * (2 * self.windowSize[2] - 1)
        self.relativePositionBiasTable = self.add_weight(
            shape=(numWindowElements, self.numHeads),
            initializer=tf.initializers.Zeros(),
            trainable=True,name='relative_position_bias_table')
        coordsD = np.arange(self.windowSize[0])
        coordsH = np.arange(self.windowSize[1])
        coordsW = np.arange(self.windowSize[2])
        coordsMatrix = np.meshgrid(coordsD, coordsH, coordsW, indexing="ij")
        # -> 3,windowSize,windowSize,windowSize
        coords = np.stack(coordsMatrix)
        # -> 3,(windowSize*windowSize*windowSize)
        coordsFlatten = np.reshape(coords, newshape=(3, -1))
        # Pairwise coordinate differences between every two tokens of a window.
        # -> 3,(windowSize*windowSize*windowSize),(windowSize*windowSize*windowSize)
        relativeCoords = coordsFlatten[:, :, None] - coordsFlatten[:, None, :]
        # -> (windowSize*windowSize*windowSize),(windowSize*windowSize*windowSize),3
        relativeCoords = np.transpose(relativeCoords, axes=(1, 2, 0))

        # Shift each axis' offset into [0, 2*W-2], then combine the three
        # axes mixed-radix style (d slowest, w fastest) into one flat row index
        # of relativePositionBiasTable.
        relativeCoords[:, :, 0] += self.windowSize[0] - 1
        relativeCoords[:, :, 1] += self.windowSize[1] - 1
        relativeCoords[:, :, 2] += self.windowSize[2] - 1
        relativeCoords[:, :, 0] *= (2 * self.windowSize[1] - 1) * (2 * self.windowSize[2] - 1)
        relativeCoords[:, :, 1] *= (2 * self.windowSize[2] - 1)

        relativePositionIndex = relativeCoords.sum(-1)
        # Stored as a non-trainable variable (constant lookup index).
        self.relativePositionIndex = tf.Variable(initial_value=tf.convert_to_tensor(relativePositionIndex),
                                                 trainable=False,name='relative_position_index')

    def call(self, x, mask=None):
        """Attend within each window.

        Args:
            x: (numWindows*batch, size, channels) tokens; NOTE(review): size is
                assumed to equal Wd*Wh*Ww, since the bias is reshaped to that.
            mask: optional additive mask broadcastable to (nW, size, size),
                where nW = mask.shape[0] windows per sample; it is added to
                the attention logits before the softmax.

        Returns:
            Tensor with the same shape as x.
        """
        _, size, channles = x.shape
        headDim = channles // self.numHeads
        # -> _,size,channles*3
        qkvX = self.qkv(x)

        qkvX = tf.reshape(qkvX, shape=(-1, size, 3, self.numHeads, headDim))
        # -> 3,_,numHeads,size,headDim
        qkvX = tf.transpose(qkvX, perm=(2, 0, 3, 1, 4))
        # -> _,numHeads,(windowSize*windowSize*windowSize),headDim
        q, k, v = qkvX[0], qkvX[1], qkvX[2]
        q = q * self.scale
        k = tf.transpose(k, perm=(0, 1, 3, 2))
        # -> _,numHeads,(windowSize*windowSize*windowSize),(windowSize*windowSize*windowSize)
        attn = q @ k
        numWindowElements = self.windowSize[0] * self.windowSize[1] * self.windowSize[2]
        # Look up the per-head bias for every (query, key) token pair.
        relativePositionIndexFlat = tf.reshape(self.relativePositionIndex, shape=(-1,))
        relativePositionBias = tf.gather(self.relativePositionBiasTable,
                                         relativePositionIndexFlat)
        # -> (windowSize*windowSize*windowSize),(windowSize*windowSize*windowSize),numHeads
        relativePositionBias = tf.reshape(relativePositionBias,
                                          shape=(numWindowElements, numWindowElements, -1))
        # -> numHeads,(windowSize*windowSize*windowSize),(windowSize*windowSize*windowSize)
        relativePositionBias = tf.transpose(relativePositionBias, perm=(2, 0, 1))

        attn = attn + tf.expand_dims(relativePositionBias, axis=0)

        if mask is not None:
            # numWindows
            nW = mask.get_shape()[0]
            # mask -> (1, nW, 1, size, size) so it broadcasts over batch and heads.
            maskFloat = tf.cast(tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0), tf.float32)
            attn = (tf.reshape(attn, shape=(-1, nW, self.numHeads, size, size)) + maskFloat)
            attn = tf.reshape(attn, shape=(-1, self.numHeads, size, size))
            attn = activations.softmax(attn, axis=-1)
        else:
            attn = activations.softmax(attn, axis=-1)
        attn = self.dropout(attn)

        # -> _,numHeads,(windowSize*windowSize*windowSize),headDim
        qkvX = attn @ v
        # -> _,(windowSize*windowSize*windowSize),numHeads,headDim
        qkvX = tf.transpose(qkvX, perm=(0, 2, 1, 3))
        # -> _,(windowSize*windowSize*windowSize),(numHeads*headDim)
        qkvX = tf.reshape(qkvX, shape=(-1, size, channles))
        qkvX = self.proj(qkvX)
        # Same Dropout layer as used on the attention probabilities above.
        qkvX = self.dropout(qkvX)
        return qkvX

class SwinTransformer3D(layers.Layer):
    """One 3D Swin Transformer block: (shifted-)window attention followed by an MLP.

    Tokens of shape (batch, D*H*W, dim) are reshaped to the (D, H, W) grid given
    by ``num_patch``, optionally cyclically shifted by ``shift_size``, split into
    cubic windows, attended within each window and merged back. When shifted, an
    attention mask stops tokens from attending across the wrap-around seam.

    Args:
        dim: token channel size.
        num_patch: (D, H, W) grid the token sequence is reshaped to.
        num_heads: number of attention heads.
        window_size: cubic window edge length (clamped to min(num_patch)).
        shift_size: cyclic shift for shifted-window attention; 0 disables it.
            Must satisfy 0 <= shift_size < window_size.
        num_mlp: hidden width of the MLP.
        qkv_bias: whether the qkv projection has a bias.
        dropout_rate: used for attention/MLP dropout and as the drop-path rate.

    Raises:
        ValueError: if shift_size is not in [0, window_size).
    """
    def __init__(self,dim,num_patch:list,num_heads,window_size,shift_size,
                 num_mlp,qkv_bias=True,dropout_rate=0.0,**kwargs):
        super(SwinTransformer3D,self).__init__(**kwargs)

        self.dim = dim
        self.numPatch = num_patch
        self.numHeads = num_heads
        self.windowSize = window_size
        self.shiftSize = shift_size
        self.numMlp = num_mlp
        self.qkvBias = qkv_bias
        self.dropoutRate = dropout_rate

        # If one window already spans the grid along its smallest axis, use a
        # single window and disable shifting (as in the reference Swin block).
        # This MUST run before WindowAttention is created: its relative position
        # table/index are sized from the window size it is constructed with.
        if min(self.numPatch) <= self.windowSize:
            self.shiftSize = 0
            self.windowSize = min(self.numPatch)
        if not 0 <= self.shiftSize < self.windowSize:
            raise ValueError(
                f"shift_size must satisfy 0 <= shift_size < window_size, "
                f"got shift_size={self.shiftSize}, window_size={self.windowSize}")

        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttention(dim=self.dim,
                                    window_size=(self.windowSize,self.windowSize,self.windowSize),
                                    num_heads=self.numHeads,
                                    qkv_bias=self.qkvBias,
                                    dropout_rate=self.dropoutRate)
        self.dropPath = DropPath(drop_prob=self.dropoutRate)
        self.norm2 = layers.LayerNormalization(epsilon=1e-5)

        self.mlp = keras.Sequential(
            [
                layers.Dense(self.numMlp),
                layers.Activation(keras.activations.gelu),
                layers.Dropout(self.dropoutRate),
                layers.Dense(self.dim),
                layers.Dropout(self.dropoutRate),
            ]
        )

    def build(self,input_shape):
        if self.shiftSize == 0:
            self.attnMask = None
            return

        depth,height,width = self.numPatch
        ws, ss = self.windowSize, self.shiftSize
        # After rolling by -shift, along each axis: [0:-ws] is untouched,
        # [-ws:-ss] is the original tail inside the last window, and [-ss:] is
        # content wrapped around from the start. Label every 3D region.
        regionSlices = (
            slice(0, -ws),
            slice(-ws, -ss),
            slice(-ss, None),
        )
        maskArray = np.zeros((1,depth,height,width,1))
        count = 0
        for d in regionSlices:
            for h in regionSlices:
                for w in regionSlices:
                    maskArray[:,d,h,w,:] = count
                    count += 1

        # maskArray to windows: (nW, ws**3) region labels per window
        maskWindows = windowPartition(tf.convert_to_tensor(maskArray),ws)
        maskWindows = tf.reshape(maskWindows,shape=[-1,ws*ws*ws])
        # Token pairs from different regions get -100 (≈ -inf after softmax).
        labelDiff = tf.expand_dims(maskWindows,axis=1)-tf.expand_dims(maskWindows,axis=2)
        attnMask = tf.where(labelDiff != 0,
                            tf.constant(-100.0, dtype=labelDiff.dtype),
                            tf.zeros_like(labelDiff))
        # Kept as a non-trainable variable for checkpoint compatibility. NOTE:
        # restoring weights saved before this mask fix will overwrite it with
        # the old (incorrect) mask.
        self.attnMask = tf.Variable(initial_value=attnMask,trainable=False,name='attn_mask')

    def get_config(self):
        config = super(SwinTransformer3D,self).get_config()
        config.update(
            {
                "dim":self.dim,
                "num_patch":self.numPatch,
                "num_heads":self.numHeads,
                "window_size":self.windowSize,
                "shift_size":self.shiftSize,
                "num_mlp":self.numMlp,
                "qkv_bias":self.qkvBias,
                "dropout_rate":self.dropoutRate
            }
        )
        return config

    def call(self,x):
        depth,height,width = self.numPatch
        _,numPatchesBefore,channels = x.shape
        xSkip = x
        x = self.norm1(x)
        x = tf.reshape(x,shape=(-1,depth,height,width,channels))
        if self.shiftSize > 0:
            shiftedX = tf.roll(x,shift=[-self.shiftSize,-self.shiftSize,-self.shiftSize],
                               axis=[1,2,3])
        else:
            shiftedX = x

        # Window attention on (B*nW, ws**3, C) token groups.
        xWindows = windowPartition(shiftedX,self.windowSize)
        xWindows = tf.reshape(xWindows,shape=(-1,self.windowSize*self.windowSize*self.windowSize,channels))
        attnWindows = self.attn(xWindows,mask=self.attnMask)

        attnWindows = tf.reshape(attnWindows,shape=(-1,self.windowSize,self.windowSize,self.windowSize,channels))
        shiftedX = windowReverse(attnWindows,self.windowSize,depth,height,width,channels)
        # Undo the cyclic shift.
        if self.shiftSize > 0:
            x = tf.roll(shiftedX,
                        shift=[self.shiftSize,self.shiftSize,self.shiftSize],
                        axis=[1,2,3])
        else:
            x = shiftedX

        # Residual 1: attention branch.
        x = tf.reshape(x,shape=(-1,depth*height*width,channels))
        x = self.dropPath(x)
        x = xSkip + x
        # Residual 2: MLP branch.
        xSkip = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.dropPath(x)
        x = xSkip + x
        return x

class PatchEmebdding(layers.Layer):
    """Patchify a volume with a strided Conv3D and add a learned position embedding.

    Input: (batch, D, H, W, C) with D, H, W divisible by ``patch_size``.
    Output: (batch, (D//pd)*(H//ph)*(W//pw), embed_dim).

    Args:
        patch_size: (pd, ph, pw) patch extent; used as both kernel size and stride.
        embed_dim: number of output channels per token.

    Raises:
        ValueError: if a statically known spatial dim is not divisible by its
            patch size.
    """
    def __init__(self,patch_size:list,embed_dim:int,**kwargs):
        super(PatchEmebdding,self).__init__(**kwargs)

        self.patchSize = patch_size
        self.embedDim = embed_dim
        self.proj = layers.Conv3D(embed_dim,
                                  kernel_size=patch_size,
                                  strides=patch_size)

    def build(self,input_shape):
        _, d, h, w, c = input_shape
        # One learned position vector per patch token, broadcast over the batch.
        self.nPatches = (d//self.patchSize[0])*(h//self.patchSize[1])*(w//self.patchSize[2])
        self.posEmbedding = self.add_weight(name="pos_embedding",
                                            shape=[1,self.nPatches,self.embedDim],
                                            dtype="float32",
                                            initializer="random_normal",
                                            trainable=True)

    def get_config(self):
        config = super(PatchEmebdding,self).get_config()
        config.update(
            {
                "patch_size":self.patchSize,
                "embed_dim":self.embedDim,
            }
        )
        return config

    def call(self,x):
        _,d,h,w,c = x.shape.as_list()
        spatial = (d, h, w)
        # Validate explicitly (assert statements vanish under `python -O`);
        # unknown (None) dims cannot be checked statically and are skipped.
        if any(dim is not None and dim % p != 0 for dim, p in zip(spatial, self.patchSize)):
            raise ValueError(
                f'input.shape应该可以被patchSize整除: got spatial shape {spatial}, '
                f'patch_size {tuple(self.patchSize)}')
        x = self.proj(x)
        _,d1,h1,w1,c1 = x.shape.as_list()
        # (B, d1, h1, w1, C) -> (B, d1*h1*w1, C) token sequence
        x = tf.reshape(x,shape=(-1,d1*h1*w1,c1))
        return x + self.posEmbedding

class ResBlock(layers.Layer):
    """Residual block: conv-IN-LeakyReLU-conv-IN plus a (projected) skip, then LeakyReLU.

    Args:
        filters: output channels.
        kernel_size: kernel size of the two main-path convolutions.
        stride: stride of the first convolution (downsampling factor).
        padding: padding of the main-path convolutions.
    """
    def __init__(self,filters,kernel_size=3,stride=1,padding='same',**kwargs):
        super(ResBlock,self).__init__(**kwargs)
        self.filters = filters
        self.kernelSize = kernel_size
        self.stride = stride
        self.padding = padding

        self.act = layers.LeakyReLU(alpha=0.01)
        self.norm1 = tfalayers.InstanceNormalization()
        self.norm2 = tfalayers.InstanceNormalization()
        self.norm3 = tfalayers.InstanceNormalization()

        # Only the first conv downsamples; striding both would shrink the main
        # path by stride**2 and the residual addition could never match.
        self.conv1 = layers.Conv3D(self.filters, self.kernelSize, self.stride, padding=self.padding)
        self.conv2 = layers.Conv3D(self.filters, self.kernelSize, 1, padding=self.padding)
        # 1x1x1 skip projection carrying the same stride as conv1 so its output
        # matches the main path spatially.
        self.conv3 = layers.Conv3D(self.filters, 1, self.stride)

        strides = self.stride if isinstance(self.stride, (list, tuple)) else (self.stride,)
        self._downsample = any(s != 1 for s in strides)

    def get_config(self):
        config = super(ResBlock,self).get_config()
        config.update(
            {
                "filters":self.filters,
                "kernel_size":self.kernelSize,
                "stride":self.stride,
                "padding":self.padding,
            }
        )
        return config

    def call(self,x):
        residual = x
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.norm2(x)
        # Project the skip path whenever channels or spatial size change.
        if self._downsample or residual.shape[-1] != x.shape[-1]:
            residual = self.conv3(residual)
            residual = self.norm3(residual)
        x = x + residual
        return self.act(x)



if __name__ == '__main__':
    # Manual shape smoke tests, left disabled; uncomment to run.
    # w = WindowAttention(24, [7, 7, 7], 3)
    # x = tf.ones(shape=(125, 343, 24))
    # x1 = w(x)
    # print(x1.shape)

    # NOTE: the block class in this module is SwinTransformer3D, and
    # PatchMerging lives in nets.help_functions (not imported here).
    # s = SwinTransformer3D(24,[32,32,32],3,4,3,4,)
    # x = tf.ones(shape=(1,32*32*32,24))
    # x1 = s(x)
    # merge = PatchMerging([32,32,32],24)
    # x2 = merge(x1)
    # print(x2.shape)

    pass
  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值