def mhsa_block(input_layer, input_channel):
    """Single-head self-attention block over a 2-D feature map.

    Flattens the (W, H, C) feature map into a length-W*H sequence, applies
    scaled-dot-product-style attention (without the 1/sqrt(d) scaling) with a
    learned position encoding, and reshapes the result back to (W, H, C).

    Args:
        input_layer: 4-D Keras tensor of shape (batch, W, H, C).
        input_channel: number of channels C of ``input_layer``.

    Returns:
        Keras tensor of shape (batch, W, H, C) — same shape as the input.
    """
    # Spatial dimensions of the feature map (e.g., W, H = 25, 25).
    W, H = int(input_layer.shape[1]), int(input_layer.shape[2])
    # Sequence length after flattening the 2-D grid.
    # FIX: the original referenced an undefined name `WH` (NameError at call
    # time); it must be the product of the spatial dimensions.
    WH = W * H

    # 2-D to sequence: (W, H, C) -> (1, W*H, C), e.g., 25x25x512 -> 1x625x512.
    conv = Reshape((1, WH, input_channel))(input_layer)

    # Learned position encoding (1x1 conv), same shape as the sequence,
    # added element-wise: 1 x W*H x C.
    pos_encoding = Conv2D(input_channel, 1, activation='relu', padding='same',
                          kernel_initializer='he_normal')(conv)
    conv = Add()([conv, pos_encoding])

    # Query / Key / Value projections via 1x1 convs: each 1 x W*H x C.
    conv_q = Conv2D(input_channel, 1, activation='relu', padding='same',
                    kernel_initializer='he_normal')(conv)
    conv_k = Conv2D(input_channel, 1, activation='relu', padding='same',
                    kernel_initializer='he_normal')(conv)
    conv_v = Conv2D(input_channel, 1, activation='relu', padding='same',
                    kernel_initializer='he_normal')(conv)

    # Transpose the Key: (1, W*H, C) -> (1, C, W*H).
    conv_k = Permute(dims=(1, 3, 2))(conv_k)

    # Content-content attention scores: Q x K^T -> 1 x W*H x W*H.
    conv = Dot(axes=(3, 2))([conv_q, conv_k])
    conv = Reshape((1, WH, WH))(conv)

    # Softmax over the last axis -> attention weights, 1 x W*H x W*H.
    conv = Softmax()(conv)

    # Attention output: weights x V -> 1 x W*H x C.
    conv = Dot(axes=(3, 2))([conv, conv_v])

    # Sequence back to 2-D: (W, H, C).
    conv = Reshape((W, H, input_channel))(conv)
    return conv
Latest release