keras多任务学习multi-task learning

keras多任务学习背景

给定输入狗图片,希望得到类别和年龄这两个分类task的结果,模型主要包括

  1. 模型共享一个backbone,这里选择使用resnet,并且用在imagenet训练的结果作为pretrain model,加载模型,也就是这里的my_new_model
  2. 设计类别和年龄这两个task的分类器. 这里类别分类器为category一共有11类,采用softmax, 年龄分类器为age,一共有3类,都是采用全连接
  3. 整体的模型结构就是model,backbone提取的特征分别送进category和age输出不同task的分类概率。model.complie分别设置loss的种类(交叉熵损失,smooth-l1,MSEloss等),multi-task的loss权重,以及评价的metric
resnet_weights_path='resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
my_new_model = Sequential()
my_new_model.add(ResNet50(include_top=False, pooling='avg', weights=None))
my_new_model.add(Activation('relu'))
my_new_model.add(Dropout(0.5))
my_new_model.summary()
img1 = Input(shape=(224, 224, 3), name='img')
feature = my_new_model(img1)
category = Dense(11, activation='softmax',name='category_out1')(feature)
age = Dense(3,activation='softmax',name='age_out2')(feature)
model = Model(inputs=[img1], outputs=[category, age])
model.compile(optimizer='sgd',
              loss={
                  'category_out1': 'categorical_crossentropy',
                  'age_out2': 'categorical_crossentropy'
              },
              loss_weights={
                  'category_out1': 1.,
                  'age_out2': 1.
              },
              metrics=['accuracy'])

数据的generator,每次的iterator需要输出 图像和对应的task的label。这里每次__get_item__返回图片的tensor和类别的标签长度为11的one-hot,年龄长度为3的one-hot

import keras
import numpy as np
import cv2


class MultiTaskGenerator(keras.utils.Sequence):
    def __init__(self,img_files=None, labels=None,age_labels=None,batch_size=32,n_classes=11,shuffle=True,dim=(224,224,3)):
        cls_dict = {'bixiong': 0, 'chaiquan': 1, 'demu': 2, 'fadou': 3, 'guibing': 4, 'jinmao': 5, 'jiwawa': 6, 'keka': 7, 'labu': 8, 'xuenarui': 9, 'yueke': 10}
        self.dim = dim
        self.batch_size = batch_size
        self.img_files = img_files
        self.labels = labels
        self.age_labels = age_labels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.cls2id=cls_dict
        self.age2id={'old':0,'small':1,'teen':2}
        self.on_epoch_end()
        
    def __len__(self):
        return int(len(self.img_files)/self.batch_size)
    
    def on_epoch_end(self):
        self.indexes = np.arange(len(self.labels))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)
            
    def __getitem__(self,index):
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        X, y,age=self.__data_generation(indexes)
        return X,[y,age]
        
    def __data_generation(self,list_IDs_temp):
        X = np.empty((self.batch_size,*self.dim))
        y = np.empty((self.batch_size),dtype=int)
        age = np.empty((self.batch_size),dtype=int)
        for i, ID in enumerate(list_IDs_temp):
            img = cv2.imread(self.img_files[ID]).astype('float')
            label = self.labels[ID]
            age_ = self.age_labels[ID]
            img = cv2.resize(img,(224,224))
            img = img/255
            X[i,] = img
            y[i] = self.cls2id[label]
            age[i] = self.age2id[age_]

            
        return X,keras.utils.to_categorical(y,num_classes=11),keras.utils.to_categorical(age,3)
    

训练过程

filepath="multi_task-{epoch:02d}-{val_age_out2_acc:.2f}.hdf5"
tensorboard = TensorBoard(log_dir='./logs', histogram_freq=0,
                          write_graph=True, write_images=False)
checkpoint= ModelCheckpoint(filepath, monitor='val_age_out2_acc', verbose=1, save_best_only=True, mode='max')
model.fit_generator(generator=train_gen,
                   validation_data= valid_gen,
                           steps_per_epoch=40,
                    validation_steps = 4,
                           callbacks = [checkpoint,tensorboard],
                           epochs=40                      
                            )

 

  • 2
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值