SwinTransformer Model Conversion: from a PyTorch Model to Keras

This article describes how to convert a PyTorch implementation of Swin Transformer into a Keras version, walking step by step from the qkv computation and the attention matrix through to the FFN, and discussing the key details and technical challenges of the conversion.
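The mechanical core of any PyTorch-to-Keras conversion is copying each weight tensor into the matching Keras layer, keeping in mind that a PyTorch `nn.Linear` stores its weight as (out_features, in_features) while a Keras `Dense` kernel is (in_features, out_features). Below is a minimal sketch, assuming the PyTorch checkpoint has already been exported to a dict of NumPy arrays; the helper name `copy_qkv_weights` and the key prefix are placeholders, not names from the original code.

```python
import numpy as np

def copy_qkv_weights(pt_state_dict, keras_qkv, prefix="attn.qkv"):
    """Copy one PyTorch Linear (the qkv projection) into a built Keras Dense layer.

    pt_state_dict: dict of NumPy arrays exported from the PyTorch checkpoint
    keras_qkv:     the already-built tf.keras.layers.Dense layer (self.qkv below)
    prefix:        placeholder key prefix -- adjust to the real checkpoint names
    """
    w = pt_state_dict[prefix + ".weight"]  # PyTorch stores (out_features, in_features)
    b = pt_state_dict[prefix + ".bias"]    # (out_features,)
    # Keras Dense expects [kernel (in, out), bias (out,)], so transpose the weight.
    keras_qkv.set_weights([np.transpose(w), b])
```

With that step in mind, the excerpt below starts from the construction of the relative position index inside the window-attention layer.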

```python
import numpy as np
import tensorflow as tf

# Inside the window-attention layer: build the lookup index for the
# relative position bias table (imports shown for completeness).
coords_h = np.arange(self.window_size[0])
coords_w = np.arange(self.window_size[1])
coords = np.stack(np.meshgrid(coords_h, coords_w, indexing="ij"))  # [2, Mh, Mw]
coords_flatten = np.reshape(coords, [2, -1])  # [2, Mh*Mw]
# [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]
relative_coords = np.transpose(relative_coords, [1, 2, 0])  # [Mh*Mw, Mh*Mw, 2]
relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw]

self.relative_position_index = tf.Variable(tf.convert_to_tensor(relative_position_index),
                                           trainable=False,
                                           dtype=tf.int32,
                                           name="relative_position_index")
```
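As a quick sanity check, the same index construction can be run standalone; the snippet below is a self-contained illustration for a 2×2 window rather than part of the layer itself.

```python
import numpy as np

# Repeat the index construction above for window_size = (2, 2).
window_size = (2, 2)
coords = np.stack(np.meshgrid(np.arange(window_size[0]),
                              np.arange(window_size[1]), indexing="ij"))  # [2, 2, 2]
coords_flatten = np.reshape(coords, [2, -1])                              # [2, 4]
rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]             # [2, 4, 4]
rel = np.transpose(rel, [1, 2, 0])
rel[:, :, 0] += window_size[0] - 1
rel[:, :, 1] += window_size[1] - 1
rel[:, :, 0] *= 2 * window_size[1] - 1
print(rel.sum(-1))
# [[4 3 1 0]
#  [5 4 2 1]
#  [7 6 4 3]
#  [8 7 5 4]]
```

Every entry lies in [0, 8], i.e. it is a valid row index into a relative_position_bias_table with (2*Mh-1)*(2*Mw-1) = 9 rows, which is exactly how that table is sized.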

```python
def call(self, x, mask=None, training=None):
    """
    Args:
        x: input features with shape of (num_windows*B, Mh*Mw, C)
        mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        training: whether in training mode
    """
    # [batch_size*num_windows, Mh*Mw, total_embed_dim]
    B_, N, C = x.shape
    # qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim]
    qkv = self.qkv(x)
    # reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head]
    qkv = tf.reshape(qkv, [B_, N, 3, self.num_heads, C // self.num_heads])
    # transpose: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
    qkv = tf.transpose(qkv, [2, 0, 3, 1, 4])
    # [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
    q, k, v = qkv[0], qkv[1], qkv[2]

    # transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]
    # multiply: -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]
    attn = tf.matmul(a=q, b=k, transpose_b=True) * self.scale

    # relative_position_bias (gather + reshape): [Mh*Mw*Mh*Mw, nH] -> [Mh*Mw, Mh*Mw, nH]
    relative_position_bias = tf.gather(self.relative_position_bias_table,
                                       tf.reshape(self.relative_position_index, [-1]))
    relative_position_bias = tf.reshape(relative_position_bias,
                                        [self.window_size[0] * self.window_size[1],
                                         self.window_size[0] * self.window_size[1],
                                         -1])
    relative_position_bias = tf.transpose(relative_position_bias, [2, 0, 1])  # [nH, Mh*Mw, Mh*Mw]
    attn = attn + tf.expand_dims(relative_position_bias, 0)

    if mask is not None:
        # mask: [nW, Mh*Mw, Mh*Mw]
        nW = mask.shape[0]  # num_windows
        # attn (reshape): [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]
        # mask (expand_dims): [1, nW, 1, Mh*Mw, Mh*Mw]
        attn = tf.reshape(attn, [B_ // nW, nW, self.num_heads, N, N]) + tf.expand_dims(tf.expand_dims(mask, 1), 0)
        attn = tf.reshape(attn, [-1, self.num_heads, N, N])
```
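From this point a standard Swin window-attention block finishes with a softmax over the scores, attention dropout, a weighted sum with v, and the output projection. The sketch below follows that standard layout; the layer names self.attn_drop, self.proj, and self.proj_drop are assumptions rather than names taken from the excerpt above.

```python
    # Continuation sketch (standard Swin window attention, assumed layer names).
    attn = tf.nn.softmax(attn, axis=-1)
    attn = self.attn_drop(attn, training=training)

    # [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
    x = tf.matmul(attn, v)
    # -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]
    x = tf.transpose(x, [0, 2, 1, 3])
    # -> [batch_size*num_windows, Mh*Mw, total_embed_dim]
    x = tf.reshape(x, [B_, N, C])

    x = self.proj(x)
    x = self.proj_drop(x, training=training)
    return x
```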
