Multi-Label Classification with Keras (Two Network Architectures)

How can one network classify two independent labels of a single image at the same time?

Broadly, there are two approaches: build a single-output network in which each label is emitted as an attribute of that one output, or build a two-branch network with a separate output head per label.

1. Dataset Composition (fashion)

My dataset has 12 classes, 5,547 images in total. Six of the classes were downloaded ready-made from the web; the remaining six I collected with a web crawler and sorted myself.

The dataset carries two kinds of information: color (black, red, blue, white) and garment type (jeans, dress, shirt, shoes, bag). The class breakdown is as follows (the expected on-disk layout is sketched after the list):

Black dress: black_dress (333 images)
Black jeans: black_jeans (344 images)
Black shirt: black_shirt (436 images)
Black shoes: black_shoe (534 images)
Blue dress: blue_dress (386 images)
Blue jeans: blue_jeans (356 images)
Blue shirt: blue_shirt (369 images)
Red dress: red_dress (384 images)
Red shirt: red_shirt (332 images)
Red shoes: red_shoe (486 images)
White bag: white_bag (747 images)
White shoes: white_shoe (840 images)
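The loaders in Section 3 derive the labels by splitting each image's parent directory name on the underscore, so the images are assumed to live in per-class folders named color_type (a sketch; the file names themselves are arbitrary):

dataset/
    black_dress/
        0001.jpg ...
    black_jeans/
        0001.jpg ...
    ...
    white_shoe/
        0001.jpg ...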

2. Building the Network (Single Output)

2.1. A VGG-Like Architecture (SimpleNet)

# imports assumed for this standalone snippet (Keras 2.x)
from keras import backend as K
from keras.models import Sequential
from keras.layers import (Conv2D, Activation, BatchNormalization,
                          MaxPooling2D, Dropout, GlobalAveragePooling2D, Dense)

class SimpleNet(object):
    def __init__(self, input_shape, classes, finalAct="softmax"):
        # input_shape = (width, height, channels)
        self.input_shape = input_shape
        self.classes = classes
        self.finalAct = finalAct

        # BatchNormalization axis: -1 (last) for channels_last, 1 for channels_first
        chanDim = -1
        if K.image_data_format() == "channels_first":
            chanDim = 1
        self.chanDim = chanDim

    def build_model(self):
        model = Sequential()
        # CONV => RELU => POOL
        model.add(Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), padding="same", input_shape=self.input_shape))
        model.add(Activation("relu"))
        model.add(BatchNormalization(axis=self.chanDim))
        model.add(MaxPooling2D(pool_size=(3, 3)))
        model.add(Dropout(0.25))

        # (CONV => RELU) * 2 => POOL
        model.add(Conv2D(64, (3, 3), padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(axis=self.chanDim))
        model.add(Conv2D(64, (3, 3), padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(axis=self.chanDim))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))
            
        # (CONV => RELU) * 2 => POOL
        model.add(Conv2D(128, (3, 3), padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(axis=self.chanDim))
        model.add(Conv2D(128, (3, 3), padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(axis=self.chanDim))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))

        # (CONV => RELU) * 2 => POOL
        model.add(Conv2D(256, (3, 3), padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(axis=self.chanDim))
        model.add(Conv2D(256, (3, 3), padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(axis=self.chanDim))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))

        # use global average pooling instead of fc layer
        model.add(GlobalAveragePooling2D())
        model.add(Activation("relu"))
        model.add(BatchNormalization())
        model.add(Dropout(0.5))

        # classifier head (softmax or sigmoid, depending on finalAct)
        model.add(Dense(self.classes))
        model.add(Activation(self.finalAct))
        model.summary()

        return model
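A minimal usage sketch (assuming 96×96 RGB inputs and the 9 label tokens that the MultiLabelBinarizer in Section 3 produces, i.e. 4 colors plus 5 garment types):

net = SimpleNet(input_shape=(96, 96, 3), classes=9, finalAct="sigmoid")
model = net.build_model()  # prints the layer summary and returns the model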

Note: this architecture can only recognize the 12 classes listed above. If a color/type combination that is not in the training set shows up, such as a red bag, it will be misclassified.

For ordinary multi-class classification the most common final layer is softmax. Here, however, the labels are independent of one another, so each output is treated as its own binary classification problem and the sigmoid activation is used instead.

Accordingly, multi-label classification mostly uses the binary_crossentropy loss rather than the categorical_crossentropy loss that is usual in multi-class classification.
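Continuing the sketch above, a minimal compile step for the multi-label setup (the training script in Section 3 does the same):

from keras.optimizers import Adam

# sigmoid outputs + binary_crossentropy: every label token is scored
# as an independent yes/no decision
model.compile(loss="binary_crossentropy",
              optimizer=Adam(lr=1e-3),
              metrics=["accuracy"])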

2.2. A Multi-Branch Architecture (FashionNet)

In this architecture one branch recognizes the garment type and the other the color. The type branch can be more complex because it is essentially doing shape recognition, so its input is first converted to grayscale; color recognition is simpler, so the corresponding branch is kept small.

The advantage of this structure is that it can recognize combinations that never appear in the data, such as blue shoes or red bags, which the previous architecture cannot.

# imports assumed for this standalone snippet (Keras 2.x + TensorFlow)
import tensorflow as tf
from keras import backend as K
from keras.models import Model
from keras.layers import (Input, Lambda, Conv2D, Activation, BatchNormalization,
                          MaxPooling2D, Dropout, GlobalAveragePooling2D,
                          Flatten, Dense)

class FashionNet(object):
    def __init__(self, input_shape, category_classes, color_classes, finalAct="softmax"):
        # input_shape = (width, height, channels)
        self.input_shape = input_shape
        self.category_classes = category_classes
        self.color_classes = color_classes
        self.finalAct = finalAct

        # BatchNormalization axis: -1 (last) for channels_last, 1 for channels_first
        chanDim = -1
        if K.image_data_format() == "channels_first":
            chanDim = 1
        self.chanDim = chanDim

    def build_category_branch(self, inputs):
        # convert the 3-channel (RGB) input to grayscale
        x = Lambda(lambda c: tf.image.rgb_to_grayscale(c))(inputs)

        #Conv->ReLU->BN->Pool
        x = Conv2D(filters=32, kernel_size=(3,3), strides=(1,1), padding='same')(x)
        x = Activation('relu')(x)
        x = BatchNormalization(axis=self.chanDim)(x)
        x = MaxPooling2D(pool_size=(3,3))(x)

        #(CONV => RELU) * 2 => POOL
        x = Conv2D(64, (3, 3), padding="same")(x)
        x = Activation("relu")(x)
        x = BatchNormalization(axis=self.chanDim)(x)
        x = Conv2D(64, (3, 3), padding="same")(x)
        x = Activation("relu")(x)
        x = BatchNormalization(axis=self.chanDim)(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        x = Dropout(0.25)(x)

        # (CONV => RELU) * 2 => POOL
        x = Conv2D(128, (3, 3), padding="same")(x)
        x = Activation("relu")(x)
        x = BatchNormalization(axis=self.chanDim)(x)
        x = Conv2D(128, (3, 3), padding="same")(x)
        x = Activation("relu")(x)
        x = BatchNormalization(axis=self.chanDim)(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        x = Dropout(0.25)(x)

        # (CONV => RELU) * 2 => POOL
        x = Conv2D(256, (3, 3), padding="same")(x)
        x = Activation("relu")(x)
        x = BatchNormalization(axis=self.chanDim)(x)
        x = Conv2D(256, (3, 3), padding="same")(x)
        x = Activation("relu")(x)
        x = BatchNormalization(axis=self.chanDim)(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        x = Dropout(0.25)(x)

        # use global average pooling instead of fc layer
        x = GlobalAveragePooling2D()(x)
        x = Activation("relu")(x)
        x = BatchNormalization()(x)
        x = Dropout(0.5)(x)

        # classifier head (softmax or sigmoid, depending on finalAct)
        x = Dense(self.category_classes)(x)
        x = Activation(self.finalAct, name='category_output')(x)

        return x
	
    def build_color_branch(self, inputs):
        #Conv->ReLU->BN->Pool
        x = Conv2D(filters=16, kernel_size=(3,3), strides=(1,1), padding='same')(inputs)
        x = Activation('relu')(x)
        x = BatchNormalization(axis=self.chanDim)(x)
        x = MaxPooling2D(pool_size=(3,3))(x)

        #Conv->ReLU->BN->Pool*2
        x = Conv2D(filters=32, kernel_size=(3,3), strides=(1,1), padding='same')(x)
        x = Activation('relu')(x)
        x = BatchNormalization(axis=self.chanDim)(x)
        x = Conv2D(filters=32, kernel_size=(3,3), strides=(1,1), padding='same')(x)
        x = Activation('relu')(x)
        x = BatchNormalization(axis=self.chanDim)(x)
        x = MaxPooling2D(pool_size=(2,2))(x)
        x = Dropout(0.25)(x)

        #Conv->ReLU->BN->Pool*2
        x = Conv2D(filters=64, kernel_size=(3,3), strides=(1,1), padding='same')(x)
        x = Activation('relu')(x)
        x = BatchNormalization(axis=self.chanDim)(x)
        x = Conv2D(filters=64, kernel_size=(3,3), strides=(1,1), padding='same')(x)
        x = Activation('relu')(x)
        x = BatchNormalization(axis=self.chanDim)(x)
        x = MaxPooling2D(pool_size=(2,2))(x)
        x = Dropout(0.25)(x)

        x = Flatten()(x)
        x = Dense(128)(x)
        x = Activation('relu')(x)
        x = BatchNormalization()(x)
        x = Dropout(0.5)(x)
        x = Dense(self.color_classes)(x)
        x = Activation(self.finalAct, name='color_output')(x)
        return x 

    def build_model(self):
        input_shape = self.input_shape
        inputs = Input(shape=input_shape)
        category_branch = self.build_category_branch(inputs) 
        color_branch = self.build_color_branch(inputs) 

        model = Model(inputs=inputs, outputs=[category_branch, color_branch])
        model.summary()
        return model	
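A minimal usage sketch (assuming the 5 garment categories and 4 colors of the dataset above). Note that the Lambda layer closes over tf, which is why the test code in Section 4 passes custom_objects={'tf': tf} when reloading the saved model:

net = FashionNet(input_shape=(96, 96, 3), category_classes=5,
                 color_classes=4, finalAct="softmax")
model = net.build_model()
# the model exposes two named outputs: 'category_output' and 'color_output'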

3. Model Training

The training code separates the two approaches into the following functions:

# -*- coding: utf-8 -*-

# import the necessary packages
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
from keras.preprocessing.image import img_to_array
from sklearn.preprocessing import MultiLabelBinarizer,LabelBinarizer
from sklearn.model_selection import train_test_split
from cnn import SimpleNet
#from cnn import SmallerInceptionNet
from cnn import FashionNet
import matplotlib.pyplot as plt
from imutils import paths
import numpy as np
import argparse
import random
import pickle
import cv2
import os
from PIL import Image

# grab the image paths and randomly shuffle them
def load_data(data_dir, img_size):
    print("[INFO] loading images...")
    if not os.path.exists(data_dir):
        return None
    imagePaths = sorted(list(paths.list_images(data_dir)))
    random.seed(42)
    random.shuffle(imagePaths)

    datas = []
    labels = []
    for imagePath in imagePaths:
        image = cv2.imread(imagePath, cv2.IMREAD_UNCHANGED)
        if image is None:
           print(imagePath)
           continue
        # convert 8-bit (grayscale/palette) images to 24-bit BGR
        if len(image.shape)==2:
            with Image.open(imagePath) as img:
                rgb_img = img.convert('RGB')
                image = cv2.cvtColor(np.asarray(rgb_img), cv2.COLOR_RGB2BGR)
        elif len(image.shape)==3: 
            if image.shape[2]==4:
                image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
            elif image.shape[2]==1:
                image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

        image = cv2.resize(image, img_size)
        image = img_to_array(image)
        datas.append(image)

        label = imagePath.split(os.path.sep)[-2].split("_")
        labels.append(label)

    # scale the raw pixel intensities to the range [0, 1]
    datas = np.array(datas, dtype="float") / 255.0
    labels = np.array(labels)
    return datas, labels

def load_data_multilabels(data_dir, img_size):
    print("[INFO] loading images...")
    if not os.path.exists(data_dir):
        return None
    imagePaths = sorted(list(paths.list_images(data_dir)))
    random.seed(42)
    random.shuffle(imagePaths)

    datas = []
    category_labels = []
    color_labels = []
    for imagePath in imagePaths:
        image = cv2.imread(imagePath)
        if image is None:
           print(imagePath)
           continue
        if image.shape[2]==4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        image = cv2.resize(image, img_size)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = img_to_array(image)
        datas.append(image)

        (color_label, category_label) = imagePath.split(os.path.sep)[-2].split("_")
        category_labels.append(category_label)
        color_labels.append(color_label)

    # scale the raw pixel intensities to the range [0, 1]
    datas = np.array(datas, dtype="float") / 255.0
    category_labels = np.array(category_labels)
    color_labels = np.array(color_labels)
    return datas, category_labels, color_labels

# binarize the labels using scikit-learn's special multi-label
def binarize_multilabels_and_save(labels, path):
    mlb = MultiLabelBinarizer()
    labels = mlb.fit_transform(labels)
    print(labels[:6])
    print('labels shape:', labels.shape)
    for (i, label) in enumerate(mlb.classes_):
        print("{}. {}".format(i + 1, label))
    with open(path, "wb") as f:
        f.write(pickle.dumps(mlb))
    return labels, len(mlb.classes_)
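
# Illustration (a sketch) of what the multi-label binarization produces:
#   mlb.fit_transform([("black", "jeans"), ("red", "dress")])
#   mlb.classes_ -> ['black' 'dress' 'jeans' 'red']
#   output      -> [[1 0 1 0],
#                   [0 1 0 1]]
# On the full dataset, 4 colors plus 5 garment types yield 9 token classes.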
	
def binarize_labels_and_save(category_labels, color_labels, category_path, color_path):
    category_lb = LabelBinarizer()
    color_lb = LabelBinarizer()
    category_labels = category_lb.fit_transform(category_labels)
    color_labels = color_lb.fit_transform(color_labels)

    # loop over each of the possible class labels and show them
    for (i, label) in enumerate(category_lb.classes_):
        print("category {}. {}".format(i + 1, label))

    for (i, label) in enumerate(color_lb.classes_):
        print("color {}. {}".format(i + 1, label))

    with open(category_path, "wb") as f:
        f.write(pickle.dumps(category_lb))

    with open(color_path, "wb") as f:
        f.write(pickle.dumps(color_lb))
    return category_labels, color_labels, len(category_lb.classes_), len(color_lb.classes_)
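
# Illustration (a sketch): LabelBinarizer one-hot encodes each head separately:
#   color_lb.fit_transform(["black", "red", "blue"])
#   -> [[1 0 0], [0 0 1], [0 1 0]]   with classes_ ['black' 'blue' 'red']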
	
# model_type: 'SimpleNet' or 'SmallerInceptionNet'
def train_model(datas, labels, classes, finalAct='sigmoid', model_type='SimpleNet'):
    EPOCHS = 20
    INIT_LR = 1e-3
    BATCH_SIZE = 32
    INPUT_SHAPE = (96, 96, 3)
    (trainX, testX, trainY, testY) = train_test_split(datas, labels, test_size=0.2, random_state=42)
    if model_type == 'SimpleNet':
        simpleNet = SimpleNet(INPUT_SHAPE, classes, finalAct)
        model = simpleNet.build_model()
    else:
        # note: requires uncommenting the SmallerInceptionNet import above
        smallerInceptionNet = SmallerInceptionNet()
        model = smallerInceptionNet.build_model(INPUT_SHAPE, classes, finalAct)

    opt = Adam(lr=INIT_LR, decay=INIT_LR / EPOCHS)
    model.compile(loss="binary_crossentropy", optimizer=opt, metrics=["accuracy"])

    history = model.fit(trainX, trainY, batch_size=BATCH_SIZE,
                        epochs=EPOCHS, verbose=1,
                        validation_data=(testX, testY))

    if not os.path.exists('trained_mode'):
        os.mkdir('trained_mode')
    model.save('trained_mode/' + '{}.h5'.format(model_type))
    plot_loss_acc(history, EPOCHS)

def train_fashionnet_model(datas, category_labels, color_labels, category_classes, color_classes, finalAct='softmax'):
    EPOCHS = 30
    INIT_LR = 1e-3
    BATCH_SIZE = 32
    INPUT_SHAPE = (96, 96, 3)
    (trainX, testX, trainCategoryY, testCategoryY, trainColorY, testColorY) = train_test_split(datas, category_labels, color_labels, test_size=0.2, random_state=42)

    fashionNet = FashionNet(INPUT_SHAPE, category_classes=category_classes, 
                               color_classes=color_classes, finalAct=finalAct)
    model = fashionNet.build_model()
    losses = { 'category_output':'categorical_crossentropy', 'color_output':'categorical_crossentropy' }
    loss_weights = {'category_output':1.0, 'color_output':1.0}

    opt = Adam(lr=INIT_LR, decay=INIT_LR / EPOCHS)
    model.compile(optimizer=opt,loss=losses, loss_weights=loss_weights, metrics=["accuracy"])

    history = model.fit(trainX, {'category_output': trainCategoryY, 'color_output': trainColorY},
                        batch_size=BATCH_SIZE, epochs=EPOCHS,
                        verbose=1,
                        validation_data=(testX, {'category_output': testCategoryY, 'color_output': testColorY}))

    if not os.path.exists('trained_mode'):
        os.mkdir('trained_mode')
    model.save('trained_mode/' + '{}.h5'.format('FashionNet'))

    plot_fashionnet_loss_acc(history, EPOCHS)

def plot_loss_acc(history, EPOCHS):
    plt.style.use("ggplot")
    plt.figure()
    N = EPOCHS
    plt.plot(np.arange(0, N), history.history["loss"], label="train_loss")
    plt.plot(np.arange(0, N), history.history["val_loss"], label="val_loss")
    plt.plot(np.arange(0, N), history.history["acc"], label="train_acc")
    plt.plot(np.arange(0, N), history.history["val_acc"], label="val_acc")
    plt.title("Training Loss and Accuracy")
    plt.xlabel("Epoch #")
    plt.ylabel("Loss/Accuracy")
    plt.legend(loc="upper left")
    plt.savefig('plot_loss_acc.png')

def plot_fashionnet_loss_acc(history, EPOCHS):
    loss_names = ['loss', 'category_output_loss', 'color_output_loss']
    plt.style.use("ggplot")
    (fig, ax) = plt.subplots(3, 1, figsize=(13, 13))

    for (i, l) in enumerate(loss_names):
        title = 'Loss for {}'.format(l) if l != 'loss' else 'Total loss'
        ax[i].set_title(title)
        ax[i].set_xlabel('Epoch #')
        ax[i].set_ylabel('Loss')
        ax[i].plot(np.arange(0, EPOCHS), history.history[l], label=l)
        ax[i].plot(np.arange(0, EPOCHS), history.history["val_"+l], label="val_"+l)
        ax[i].legend()
    plt.savefig('plot_fashionnet_losses.png')
    plt.close()
    '''
    accuracy_names = ['category_output_acc', 'color_output_acc']
    plt.style.use("ggplot")
    (fig, ax) = plt.subplots(2, 1, figsize=(8, 8))
    for (i, l) in enumerate(accuracy_names):
        title = 'Accuracy for {}'.format(l)
        ax[i].set_title(title)
        ax[i].set_xlabel('Epoch #')
        ax[i].set_ylabel('Accuracy')
        ax[i].plot(np.arange(0, EPOCHS), history.history[l], label=l)
        ax[i].plot(np.arange(0, EPOCHS), history.history["val_"+l], label="val_"+l)
        ax[i].legend()
    plt.savefig('plot_fashionnet_accs.png')
    plt.close()
    '''
	
def main():
    data_dir = './dataset'
    img_size = (96, 96)
    label_dir = './labels'
    if not os.path.exists(label_dir):
        os.mkdir(label_dir)

    '''
    datas, labels = load_data(data_dir, img_size)
    labels, classes= binarize_multilabels_and_save(labels, os.path.join(label_dir, 'multi-label.pickle'))
    train_model(datas, labels, classes, finalAct='sigmoid', model_type='SimpleNet')

    '''
    datas, category_labels, color_labels = load_data_multilabels(data_dir, img_size)
    category_path = os.path.join(label_dir, 'category.pickle')
    color_path = os.path.join(label_dir, 'color.pickle')
    category_labels, color_labels, category_classes, color_classes = binarize_labels_and_save(category_labels, color_labels, category_path, color_path)
    train_fashionnet_model(datas, category_labels, color_labels, category_classes, color_classes, finalAct='softmax')
    
if __name__ == '__main__':
    main()

4. Test Code

# import the necessary packages
from keras.preprocessing.image import img_to_array
from keras.models import load_model
import numpy as np
import argparse
import imutils
import pickle
import cv2
import os
import tensorflow as tf

# load and pre-process a single image
# model_type: None (for SimpleNet) or 'FashionNet'
def load_image(img_path, model_type=None):
    image = cv2.imread(img_path)
    output = imutils.resize(image, width=400)
    if model_type == 'FashionNet':
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # pre-process the image for classification
    image = cv2.resize(image, (96, 96))
    image = image.astype("float") / 255.0
    image = img_to_array(image)
    image = np.expand_dims(image, axis=0)
    return image, output

def load_trained_model(img, model_path, labelbin_path):
    label_lb = pickle.loads(open(labelbin_path, "rb").read())
    model = load_model(model_path)
    proba = model.predict(img)[0]
    
    idxs = np.argsort(proba)[::-1][:2]
    label_1 = label_lb.classes_[idxs[0]]
    label_2 = label_lb.classes_[idxs[1]]
    
    proba_1 = proba[idxs[0]]
    proba_2 = proba[idxs[1]]
    
    result = (label_1, proba_1, label_2, proba_2)
    return result
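
# Toy illustration of the top-2 selection above (a sketch):
#   proba = np.array([0.05, 0.90, 0.30, 0.85])
#   np.argsort(proba)[::-1][:2] -> array([1, 3])   # most confident labels first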
    


# load the trained convolutional neural network 
def load_trained_fashionnet_model(img, model_path, categorybin_path, colorbin_path):
    category_lb = pickle.loads(open(categorybin_path, "rb").read())
    color_lb = pickle.loads(open(colorbin_path, "rb").read())
    
    model = load_model(model_path, custom_objects={'tf':tf})
    (category_proba, color_proba) = model.predict(img)

    category_idx = category_proba[0].argmax()
    color_idx = color_proba[0].argmax()
    category_label = category_lb.classes_[category_idx]
    color_label = color_lb.classes_[color_idx]
    
    category_proba = category_proba[0][category_idx]
    color_proba = color_proba[0][color_idx]
    result = (category_label, category_proba, color_label, color_proba)
    return result

def show_result(img, result):
    (label_1, proba_1, label_2, proba_2) = result
    text1 = "{}: {:.2f}%".format(label_1, proba_1*100)
    text2 = "{}: {:.2f}%".format(label_2, proba_2*100)
    
    cv2.putText(img, text1, (10, 25), 
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
    cv2.putText(img, text2, (10, 55), 
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

    # show the output image
    cv2.imshow("Output", img)
    cv2.waitKey(2000)
    cv2.destroyAllWindows()
    
def show_fashionnet_result(img, result):
    (category_label, category_proba, color_label, color_proba) = result
    category_text = "category: {}: {:.2f}%".format(category_label, category_proba*100)
    color_text = "color: {}: {:.2f}%".format(color_label, color_proba*100)

    cv2.putText(img, category_text, (10, 25), 
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
    cv2.putText(img, color_text, (10, 55), 
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

    # show the output image
    cv2.imshow("Output", img)
    cv2.waitKey(2000)
    cv2.destroyAllWindows()


if __name__=='__main__':
    test_dir = './examples'
    #model_type = 'FashionNet'
    model_type = None
    for img in os.listdir(test_dir):
        img_path = os.path.join(test_dir, img)
        if model_type is None:
            image,output = load_image(img_path)
            model_path = 'trained_mode/SimpleNet.h5'
            labelbin_path = './labels/multi-label.pickle'
            result = load_trained_model(image, model_path, labelbin_path)
            show_result(output, result)
        elif model_type == 'FashionNet':
            image, output = load_image(img_path, model_type)
            model_path = 'trained_mode/FashionNet.h5'
            categorybin_path = './labels/category.pickle'
            colorbin_path = './labels/color.pickle'
            
            result = load_trained_fashionnet_model(image, model_path, categorybin_path, colorbin_path)
            show_fashionnet_result(output, result)

5. Data and Complete Code

Code: https://github.com/zhangwei147258/fashion_mutil_label_classifier_keras

Data: https://pan.baidu.com/s/11LoY2H5shADwiQwPuhB6ng (extraction code: pg7d)
