结构分析
PatchEmbed
读取输入x,进行一次二维卷积:embed_dim = 96, patch_size = 4
filters=embed_dim, kernel_size=patch_size, strides=patch_size, padding='SAME'
进行一次reshape:[B, H, W, C] -> [B, H*W, C]
进行一次layer_norm(1e-6)
Dropout(略)
BasicLayer
(2, 2, 6, 2)4个stage
create_mask
SwinTransformerBlock
2 2 6 2个数
一次layer_norm(1e-6)
一次reshape->[B, H, W, C]
偶数次(第2、4、6...次)进行 cyclic shift;奇数次 attn_mask = None
一次window_partition
一次reshape->[nW*B, Mh*Mw, C]
一次WindowAttention
计算relative_position_index
qkv—dense
一次reshape + transpose
矩阵乘一次
一次gather + reshape + transpose
一次attn = attn + expand_dims
若mask存在,则再reshape两次
一次softmax
一次矩阵乘 + transpose + reshape
一次dense
一次reshape -> [nW*B, Mh, Mw, C]
一次window_reverse
一次reverse cyclic shift(依然是偶数次执行)
一次reshape->[B, H * W, C]
一次FFN
y1 = layer_norm(1e-6)
x = x + mlp(y1)
MLP:两层dense,一层gelu
downsample(PatchMerging)
除了最后一层都有
进行一次layer_norm(1e-6)
reduce_mean(axis = 1)
进行一次1000输出的dense
使用SPDZ改写要点
很多tf的函数库中没有,需要使用SPDZ语法自行编写。
tf.transpose
通过临时数组将相应类型Tensor的位置进行替换。
tf.reshape
先使用count(-1)进行判断,以确保新形状中的-1数量小于等于1。再通过reduce方法计算x_elements,最后
new_shape[new_shape.index(-1)] = x_elements // abs(reduce(mul, new_shape))
gelu函数
def gelu(x):
return x*ml.approx_sigmoid(x*1.702)
#ml是spdz库中的机器学习模块
layer_norm层
使用基础运算实现:求和、开根号、求均值等。
未完待续