像 UNETR 这种用于 3D 医学图像分割的模型使用了 Swin Transformer。它不同于单纯的 multi-head attention:还包含 relative coords(相对位置偏置)、mask(移位窗口的注意力掩码)和 patch merging(下采样)。一直以来我都不知道怎么在 TensorFlow 2.x 里实现这些。今天参考了 PyTorch 的 3D 实现和 TensorFlow 2.x 的 2D 实现,完成了 TensorFlow 2.x 的 3D 版本。
训练了一下,确实是可以的。
help_functions.py
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import tensorflow as tf
from tensorflow.keras import layers
def windowPartition(inputs, window_size):
    """Split a 5-D volume (B, D, H, W, C) into non-overlapping cubic windows.

    Returns a tensor of shape
    (B * num_windows, window_size, window_size, window_size, C).
    Assumes each spatial dimension is divisible by ``window_size``.
    """
    _, depth, height, width, channels = inputs.shape
    winsZ = depth // window_size
    winsY = height // window_size
    winsX = width // window_size
    # Factor every spatial axis into (window_count, window_size).
    blocks = tf.reshape(
        inputs,
        shape=(-1, winsZ, window_size, winsY, window_size, winsX, window_size, channels),
    )
    # Gather the three window-count axes first, the three in-window axes after.
    blocks = tf.transpose(blocks, (0, 1, 3, 5, 2, 4, 6, 7))
    # Collapse batch and window-count axes into one leading dimension.
    return tf.reshape(blocks, shape=(-1, window_size, window_size, window_size, channels))
def windowReverse(windows, window_size, depth, height, width, channles):
    """Inverse of ``windowPartition``: reassemble windows into (B, D, H, W, C)."""
    winsZ = depth // window_size
    winsY = height // window_size
    winsX = width // window_size
    volume = tf.reshape(
        windows,
        shape=(-1, winsZ, winsY, winsX, window_size, window_size, window_size, channles),
    )
    # Interleave window-count axes with in-window axes to restore spatial order.
    volume = tf.transpose(volume, perm=(0, 1, 4, 2, 5, 3, 6, 7))
    return tf.reshape(volume, shape=(-1, depth, height, width, channles))
class PatchEmbedding(layers.Layer):
    """Linear patch projection plus a learned absolute position embedding."""

    def __init__(self, num_patch, embed_dim, **kwargs):
        super(PatchEmbedding, self).__init__(**kwargs)
        self.numPatch = num_patch
        self.proj = layers.Dense(embed_dim)
        self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)

    def call(self, patch):
        # One position id per patch: 0, 1, ..., numPatch - 1.
        positions = tf.range(self.numPatch)
        return self.proj(patch) + self.pos_embed(positions)
class PatchExtract(layers.Layer):
    """Extract non-overlapping 3-D patches and flatten each to a vector.

    Input:  (B, D, H, W, C) volume whose spatial dims are divisible by
            the corresponding patch size.
    Output: (B, num_patches, patch_dim) with
            patch_dim = patchSizeZ * patchSizeY * patchSizeX * C.

    Fix: the original called ``tf.image.extract_patches``, which only
    supports 4-D (image) inputs, passed a 4-element ``rates`` for a
    5-element ``sizes``, and counted patches as ``patches.shape[1] ** 3``.
    This version performs the equivalent 3-D extraction with
    reshape/transpose, mirroring ``windowPartition``.
    """

    def __init__(self, patch_size: list, **kwargs):
        super(PatchExtract, self).__init__(**kwargs)
        self.patchSizeZ = patch_size[0]
        self.patchSizeY = patch_size[1]
        self.patchSizeX = patch_size[2]

    def get_config(self):
        # Serialize constructor args so the layer survives model (de)serialization.
        config = super(PatchExtract, self).get_config()
        config.update(
            {"patch_size": [self.patchSizeZ, self.patchSizeY, self.patchSizeX]}
        )
        return config

    def call(self, images):
        batchSize = tf.shape(images)[0]
        _, depth, height, width, channels = images.shape
        numZ = depth // self.patchSizeZ
        numY = height // self.patchSizeY
        numX = width // self.patchSizeX
        # (B, numZ, pZ, numY, pY, numX, pX, C)
        x = tf.reshape(
            images,
            (-1, numZ, self.patchSizeZ, numY, self.patchSizeY, numX, self.patchSizeX, channels),
        )
        # Patch-count axes first, in-patch axes after.
        x = tf.transpose(x, (0, 1, 3, 5, 2, 4, 6, 7))
        patchDim = self.patchSizeZ * self.patchSizeY * self.patchSizeX * channels
        return tf.reshape(x, (batchSize, numZ * numY * numX, patchDim))
class PatchMerging(layers.Layer):
    """3-D Swin patch merging: halve each spatial axis, project 8C -> 2C.

    Takes tokens of shape (B, D*H*W, C), gathers the 8 voxels of every
    2x2x2 neighborhood into the channel axis (8C), then linearly projects
    to 2C — the 3-D analogue of 2-D Swin's 2x2 -> 4C -> 2C merging.

    Fix: the original sampled octants (0,1,0) and (0,0,1) twice (x5/x6
    duplicated x2/x3) and never sampled (1,1,0) or (0,1,1), silently
    dropping a quarter of the voxels. All 8 octants are now covered.
    """

    def __init__(self, num_patch: list, embed_dim, **kwargs):
        super(PatchMerging, self).__init__(**kwargs)
        self.numPatch = num_patch          # [depth, height, width] in tokens
        self.embedDim = embed_dim
        # No bias, matching the reference Swin implementation.
        self.linearTrans = layers.Dense(2 * self.embedDim, use_bias=False)

    def get_config(self):
        config = super(PatchMerging, self).get_config()
        config.update(
            {
                "num_patch": self.numPatch,
                "embed_dim": self.embedDim,
            }
        )
        return config

    def call(self, x):
        depth, height, width = self.numPatch
        _, _, C = x.get_shape().as_list()
        x = tf.reshape(x, shape=(-1, depth, height, width, C))
        # All 8 octants of each 2x2x2 block (z, y, x parities).
        x0 = x[:, 0::2, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, 0::2, :]
        x3 = x[:, 0::2, 0::2, 1::2, :]
        x4 = x[:, 1::2, 1::2, 0::2, :]
        x5 = x[:, 1::2, 0::2, 1::2, :]
        x6 = x[:, 0::2, 1::2, 1::2, :]
        x7 = x[:, 1::2, 1::2, 1::2, :]
        x = tf.concat((x0, x1, x2, x3, x4, x5, x6, x7), axis=-1)
        x = tf.reshape(x, shape=(-1, (depth // 2) * (height // 2) * (width // 2), 8 * C))
        return self.linearTrans(x)
net_layers.py
import warnings
warnings.filterwarnings("ignore")
import os
import tensorflow_addons.layers as tfalayers
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, activations
import numpy as np
# 自定义函数
from nets.help_functions import windowPartition,windowReverse
class DropPath(layers.Layer):
    """Stochastic depth: randomly zero entire samples and rescale survivors.

    Per sample, the whole residual branch is dropped with probability
    ``drop_prob``; surviving samples are divided by the keep probability so
    the expected value is unchanged.

    Fixes vs. the original:
      * only active while training — the original also dropped paths at
        inference, making predictions stochastic;
      * a ``drop_prob`` of None or 0.0 is an explicit no-op instead of
        relying on arithmetic with None to fail.
    """

    def __init__(self, drop_prob=None, **kwargs):
        super(DropPath, self).__init__(**kwargs)
        self.dropProb = drop_prob

    def get_config(self):
        config = super(DropPath, self).get_config()
        config.update(
            {
                "drop_prob": self.dropProb
            }
        )
        return config

    def call(self, x, training=None):
        # Identity at inference time or when dropping is disabled.
        if not training or not self.dropProb:
            return x
        keep_prob = 1 - self.dropProb
        input_shape = tf.shape(x)
        batch_size = input_shape[0]
        rank = x.shape.rank
        # One Bernoulli draw per sample, broadcast over all other axes.
        shape = (batch_size,) + (1,) * (rank - 1)
        # floor(keep_prob + U[0,1)) is 1 with probability keep_prob, else 0.
        random_tensor = keep_prob + tf.random.uniform(shape, dtype=x.dtype)
        path_mask = tf.floor(random_tensor)
        return tf.math.divide(x, keep_prob) * path_mask
class WindowAttention(layers.Layer):
    """Multi-head self-attention over 3-D windows with relative position bias.

    Operates on flattened windows of shape
    (num_windows * B, window_volume, dim) where
    window_volume = windowSize[0] * windowSize[1] * windowSize[2].
    A learned bias indexed by the relative (dz, dy, dx) offset between every
    pair of voxels inside a window is added to the attention logits, as in
    Swin Transformer (extended here to three axes).
    """

    def __init__(self, dim, window_size: list, num_heads, qkv_bias=True,
                 dropout_rate=0.0, **kwargs):
        super(WindowAttention, self).__init__(**kwargs)
        self.dim = dim
        self.windowSize = window_size          # (Wd, Wh, Ww)
        self.numHeads = num_heads
        self.qkvBias = qkv_bias
        self.dropoutRate = dropout_rate
        # 1 / sqrt(head_dim) scaling applied to queries.
        self.scale = (self.dim // self.numHeads) ** (-0.5)
        # Single projection producing q, k and v concatenated.
        self.qkv = layers.Dense(self.dim * 3, use_bias=self.qkvBias)
        # NOTE(review): the same Dropout layer is reused for attention weights
        # and for the output projection — confirm this sharing is intended.
        self.dropout = layers.Dropout(self.dropoutRate)
        self.proj = layers.Dense(self.dim)

    def get_config(self):
        config = super(WindowAttention, self).get_config()
        config.update(
            {
                "dim": self.dim,
                "window_size": self.windowSize,
                "num_heads": self.numHeads,
                "qkv_bias": self.qkvBias,
                "dropout_rate": self.dropoutRate
            }
        )
        return config

    def build(self, input_shape):
        # One bias per distinct relative offset: offsets along axis i range over
        # 2 * windowSize[i] - 1 values.
        numWindowElements = (2 * self.windowSize[0] - 1) * (2 * self.windowSize[1] - 1) * (2 * self.windowSize[2] - 1)
        self.relativePositionBiasTable = self.add_weight(
            shape=(numWindowElements, self.numHeads),
            initializer=tf.initializers.Zeros(),
            trainable=True, name='relative_position_bias_table')
        coordsD = np.arange(self.windowSize[0])
        coordsH = np.arange(self.windowSize[1])
        coordsW = np.arange(self.windowSize[2])
        coordsMatrix = np.meshgrid(coordsD, coordsH, coordsW, indexing="ij")
        # -> 3, Wd, Wh, Ww
        coords = np.stack(coordsMatrix)
        # -> 3, window_volume
        coordsFlatten = np.reshape(coords, newshape=(3, -1))
        # Pairwise coordinate differences -> 3, window_volume, window_volume
        relativeCoords = coordsFlatten[:, :, None] - coordsFlatten[:, None, :]
        # -> window_volume, window_volume, 3
        relativeCoords = np.transpose(relativeCoords, axes=(1, 2, 0))
        # Shift each axis to start at 0, then mix the three axes into a single
        # flat index (row-major over the (2W-1)-sized offset grid).
        relativeCoords[:, :, 0] += self.windowSize[0] - 1
        relativeCoords[:, :, 1] += self.windowSize[1] - 1
        relativeCoords[:, :, 2] += self.windowSize[2] - 1
        relativeCoords[:, :, 0] *= (2 * self.windowSize[1] - 1) * (2 * self.windowSize[2] - 1)
        relativeCoords[:, :, 1] *= (2 * self.windowSize[2] - 1)
        relativePositionIndex = relativeCoords.sum(-1)
        # Constant lookup table; stored as a non-trainable variable so it is
        # saved with the layer.
        self.relativePositionIndex = tf.Variable(initial_value=tf.convert_to_tensor(relativePositionIndex),
                                                 trainable=False, name='relative_position_index')

    def call(self, x, mask=None):
        """Attend within each window.

        x:    (num_windows * B, window_volume, dim)
        mask: optional (num_windows, window_volume, window_volume) additive
              mask (e.g. -100 at disallowed pairs) from shifted windows.
        """
        _, size, channles = x.shape
        headDim = channles // self.numHeads
        # -> _, size, channels * 3
        qkvX = self.qkv(x)
        qkvX = tf.reshape(qkvX, shape=(-1, size, 3, self.numHeads, headDim))
        qkvX = tf.transpose(qkvX, perm=(2, 0, 3, 1, 4))
        # Each -> _, numHeads, window_volume, headDim
        q, k, v = qkvX[0], qkvX[1], qkvX[2]
        q = q * self.scale
        k = tf.transpose(k, perm=(0, 1, 3, 2))
        # -> _, numHeads, window_volume, window_volume
        attn = q @ k
        numWindowElements = self.windowSize[0] * self.windowSize[1] * self.windowSize[2]
        relativePositionIndexFlat = tf.reshape(self.relativePositionIndex, shape=(-1,))
        # Gather a bias per (query, key) pair from the learned table.
        relativePositionBias = tf.gather(self.relativePositionBiasTable,
                                         relativePositionIndexFlat)
        # -> window_volume, window_volume, numHeads
        relativePositionBias = tf.reshape(relativePositionBias,
                                          shape=(numWindowElements, numWindowElements, -1))
        # -> numHeads, window_volume, window_volume
        relativePositionBias = tf.transpose(relativePositionBias, perm=(2, 0, 1))
        attn = attn + tf.expand_dims(relativePositionBias, axis=0)
        if mask is not None:
            # nW = number of windows; un-flatten the batch axis so the mask
            # broadcasts over batch and heads, then re-flatten.
            nW = mask.get_shape()[0]
            maskFloat = tf.cast(tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0), tf.float32)
            attn = (tf.reshape(attn, shape=(-1, nW, self.numHeads, size, size)) + maskFloat)
            attn = tf.reshape(attn, shape=(-1, self.numHeads, size, size))
            attn = activations.softmax(attn, axis=-1)
        else:
            attn = activations.softmax(attn, axis=-1)
        attn = self.dropout(attn)
        # -> _, numHeads, window_volume, headDim
        qkvX = attn @ v
        # -> _, window_volume, numHeads, headDim
        qkvX = tf.transpose(qkvX, perm=(0, 2, 1, 3))
        # Merge heads back -> _, window_volume, numHeads * headDim
        qkvX = tf.reshape(qkvX, shape=(-1, size, channles))
        qkvX = self.proj(qkvX)
        qkvX = self.dropout(qkvX)
        return qkvX
class SwinTransformer3D(layers.Layer):
    """One 3-D Swin Transformer block: (shifted-)window attention + MLP.

    Input/output tokens have shape (B, D*H*W, dim) with (D, H, W) given by
    ``num_patch``. When ``shift_size`` > 0 the volume is cyclically shifted
    before windowing and an additive attention mask keeps voxels from
    attending across the wrap-around seam.

    Fixes vs. the original:
      * the mask boundary slices used ``slice(-windowSize, -windowSize)``
        (an empty slice) and ``slice(-windowSize, None)``; the correct
        partition is ``(0, -windowSize), (-windowSize, -shiftSize),
        (-shiftSize, None)`` as in the reference Swin implementation;
      * the window-size clamp for small inputs ran *after* the
        WindowAttention layer was built, so attention kept the unclamped
        window size. The clamp now runs first.
    """

    def __init__(self, dim, num_patch: list, num_heads, window_size, shift_size,
                 num_mlp, qkv_bias=True, dropout_rate=0.0, **kwargs):
        super(SwinTransformer3D, self).__init__(**kwargs)
        self.dim = dim
        self.numPatch = num_patch            # [depth, height, width] in tokens
        self.numHeads = num_heads
        self.windowSize = window_size
        self.shiftSize = shift_size
        self.numMlp = num_mlp
        self.qkvBias = qkv_bias
        self.dropoutRate = dropout_rate
        # If the volume is smaller than one window, window the whole volume
        # and disable shifting. Must happen BEFORE building the attention
        # layer so it sees the final window size.
        if min(self.numPatch) < self.windowSize:
            self.shiftSize = 0
            self.windowSize = min(self.numPatch)
        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttention(dim=self.dim,
                                    window_size=(self.windowSize, self.windowSize, self.windowSize),
                                    num_heads=self.numHeads,
                                    qkv_bias=self.qkvBias,
                                    dropout_rate=self.dropoutRate)
        self.dropPath = DropPath(drop_prob=self.dropoutRate)
        self.norm2 = layers.LayerNormalization(epsilon=1e-5)
        self.mlp = keras.Sequential(
            [
                layers.Dense(self.numMlp),
                layers.Activation(keras.activations.gelu),
                layers.Dropout(self.dropoutRate),
                layers.Dense(self.dim),
                layers.Dropout(self.dropoutRate),
            ]
        )

    def build(self, input_shape):
        if self.shiftSize == 0:
            # Unshifted windows never cross the volume boundary: no mask.
            self.attnMask = None
        else:
            depth, height, width = self.numPatch
            # Three regions per axis; voxels get a distinct region id per
            # (d, h, w) region triple, and attention is masked between
            # different ids inside the same window.
            dSlices = (
                slice(0, -self.windowSize),
                slice(-self.windowSize, -self.shiftSize),
                slice(-self.shiftSize, None)
            )
            hSlices = (
                slice(0, -self.windowSize),
                slice(-self.windowSize, -self.shiftSize),
                slice(-self.shiftSize, None)
            )
            wSlices = (
                slice(0, -self.windowSize),
                slice(-self.windowSize, -self.shiftSize),
                slice(-self.shiftSize, None)
            )
            maskArray = np.zeros((1, depth, height, width, 1))
            count = 0
            for d in dSlices:
                for h in hSlices:
                    for w in wSlices:
                        maskArray[:, d, h, w, :] = count
                        count += 1
            maskArray = tf.convert_to_tensor(maskArray)
            # Window the region-id volume the same way the features will be.
            maskWindows = windowPartition(maskArray, self.windowSize)
            maskWindows = tf.reshape(maskWindows,
                                     shape=[-1, self.windowSize * self.windowSize * self.windowSize])
            # Pairwise id differences: nonzero -> cross-region pair -> -100
            # (large negative logit, ~zero after softmax).
            attnMask = tf.expand_dims(maskWindows, axis=1) - tf.expand_dims(maskWindows, axis=2)
            attnMask = tf.where(attnMask != 0, -100.0, attnMask)
            attnMask = tf.where(attnMask == 0, 0.0, attnMask)
            self.attnMask = tf.Variable(initial_value=attnMask, trainable=False, name='attn_mask')

    def get_config(self):
        config = super(SwinTransformer3D, self).get_config()
        config.update(
            {
                "dim": self.dim,
                "num_patch": self.numPatch,
                "num_heads": self.numHeads,
                "window_size": self.windowSize,
                "shift_size": self.shiftSize,
                "num_mlp": self.numMlp,
                "qkv_bias": self.qkvBias,
                "dropout_rate": self.dropoutRate
            }
        )
        return config

    def call(self, x):
        depth, height, width = self.numPatch
        _, numPatchesBefore, channels = x.shape
        xSkip = x
        # --- window attention branch (pre-norm) ---
        x = self.norm1(x)
        x = tf.reshape(x, shape=(-1, depth, height, width, channels))
        if (self.shiftSize > 0):
            # Cyclic shift so shifted windows tile the volume.
            shiftedX = tf.roll(x, shift=[-self.shiftSize, -self.shiftSize, -self.shiftSize],
                               axis=[1, 2, 3])
        else:
            shiftedX = x
        xWindows = windowPartition(shiftedX, self.windowSize)
        xWindows = tf.reshape(xWindows,
                              shape=(-1, self.windowSize * self.windowSize * self.windowSize, channels))
        attnWindows = self.attn(xWindows, mask=self.attnMask)
        attnWindows = tf.reshape(attnWindows,
                                 shape=(-1, self.windowSize, self.windowSize, self.windowSize, channels))
        shiftedX = windowReverse(attnWindows, self.windowSize, depth, height, width, channels)
        if self.shiftSize > 0:
            # Undo the cyclic shift.
            x = tf.roll(shiftedX,
                        shift=[self.shiftSize, self.shiftSize, self.shiftSize],
                        axis=[1, 2, 3])
        else:
            x = shiftedX
        x = tf.reshape(x, shape=(-1, depth * height * width, channels))
        x = self.dropPath(x)
        x = xSkip + x
        # --- MLP branch (pre-norm) ---
        xSkip = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.dropPath(x)
        x = xSkip + x
        return x
class PatchEmebdding(layers.Layer):
    """Conv3D patch embedding with a learned positional embedding.

    Projects a (B, D, H, W, C) volume into (B, num_patches, embed_dim)
    tokens via a strided Conv3D (kernel == stride == patch size, i.e.
    non-overlapping patches), then adds a trainable position embedding.
    (Class-name spelling is kept as-is to preserve the public interface.)
    """

    def __init__(self, patch_size: list, embed_dim: int, **kwargs):
        super(PatchEmebdding, self).__init__(**kwargs)
        self.patchSize = patch_size
        self.embedDim = embed_dim
        self.proj = layers.Conv3D(embed_dim,
                                  kernel_size=patch_size,
                                  strides=patch_size)

    def build(self, input_shape):
        _, d, h, w, c = input_shape
        # Number of non-overlapping patches determines the embedding length.
        self.nPatches = (d // self.patchSize[0]) * (h // self.patchSize[1]) * (w // self.patchSize[2])
        self.posEmbedding = self.add_weight(name="pos_embedding",
                                            shape=[1, self.nPatches, self.embedDim],
                                            dtype="float32",
                                            initializer="random_normal",
                                            trainable=True)

    def get_config(self):
        config = super(PatchEmebdding, self).get_config()
        config.update(
            {
                "patch_size": self.patchSize,
                "embed_dim": self.embedDim,
            }
        )
        return config

    def call(self, x):
        _, d, h, w, c = x.shape.as_list()
        assert (d % self.patchSize[0] == 0 and h % self.patchSize[1] == 0 and w % self.patchSize[2] == 0), f'input.shape应该可以被patchSize整除'
        projected = self.proj(x)
        _, d1, h1, w1, c1 = projected.shape.as_list()
        # Flatten the spatial grid into a token sequence.
        tokens = tf.reshape(projected, shape=(-1, d1 * h1 * w1, c1))
        return tokens + self.posEmbedding
class ResBlock(layers.Layer):
    """3-D residual block: two Conv3D + InstanceNorm + LeakyReLU stages.

    When the input channel count differs from ``filters``, the skip path is
    matched with a 1x1x1 Conv3D followed by instance normalization.
    """

    def __init__(self, filters, kernel_size=3, stride=1, padding='same', **kwargs):
        super(ResBlock, self).__init__(**kwargs)
        self.filters = filters
        self.kernelSize = kernel_size
        self.stride = stride
        self.padding = padding
        self.act = layers.LeakyReLU(alpha=0.01)
        self.norm1 = tfalayers.InstanceNormalization()
        self.norm2 = tfalayers.InstanceNormalization()
        self.norm3 = tfalayers.InstanceNormalization()
        self.conv1 = layers.Conv3D(self.filters, self.kernelSize, self.stride, padding=self.padding)
        self.conv2 = layers.Conv3D(self.filters, self.kernelSize, self.stride, padding=self.padding)
        # 1x1x1 projection for the skip path when channels change.
        self.conv3 = layers.Conv3D(self.filters, 1)

    def get_config(self):
        config = super(ResBlock, self).get_config()
        config.update(
            {
                "filters": self.filters,
                "kernel_size": self.kernelSize,
                "stride": self.stride,
                "padding": self.padding,
            }
        )
        return config

    def call(self, x):
        shortcut = x
        out = self.act(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))
        if shortcut.shape[-1] != out.shape[-1]:
            # Project the skip path so the channel counts match.
            shortcut = self.norm3(self.conv3(shortcut))
        return self.act(out + shortcut)
if __name__ == '__main__':
    # Manual smoke tests kept for reference; uncomment to sanity-check shapes.
    # w = WindowAttention(24, [7, 7, 7], 3)
    # x = tf.ones(shape=(125, 343, 24))
    # x1 = w(x)
    # print(x1.shape)
    # s = SwinTransformer(24,[32,32,32],3,4,3,4,)
    # x = tf.ones(shape=(1,32*32*32,24))
    # x1 = s(x)
    # merge = PatchMerging([32,32,32],24)
    # x2 = merge(x1)
    # print(x2.shape)
    pass