tensorflow多任务训练代码

70 篇文章 5 订阅
11 篇文章 1 订阅

单任务代码:由以下单任务组成两个相同的双任务

import cv2
#正则匹配使用:
import re
import os
#此库用于拷贝,删除,移动,复制以及解压缩
import shutil
import numpy as np
import h5py
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
#用于将普通标签转为独热向量
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
#如此读取图像,直接返回numpy.ndarray
#img=cv2.imdecode(np.fromfile("C:/Users/104005162/Desktop/企业微信截图_20220212102256.png",np.uint8),-1)
#print(img.shape)
##转换为bgr图片,注意此时是PNG图片,不能用矩阵直接转换!
##img=img[:,:,::-1]
##bgra
#img=cv2.cvtColor(img,cv2.COLOR_BGRA2RGB)
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import resnet
from tensorflow.keras.utils import plot_model
from tensorflow.keras.layers import Input,GlobalAvgPool2D,Dense,Dropout,Lambda,Conv2D
from tensorflow.keras.models import Model
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint,TensorBoard,EarlyStopping
def readPictureByPath(path):
    img=cv2.imdecode(np.fromfile(path,np.uint8),-1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    #return np.expand_dims(img,0)
    return img
import cv2
#正则匹配使用:
import re
import os
#此库用于拷贝,删除,移动,复制以及解压缩
import shutil
import numpy as np
import h5py
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
#用于将普通标签转为独热向量
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
#如此读取图像,直接返回numpy.ndarray
#img=cv2.imdecode(np.fromfile("C:/Users/104005162/Desktop/企业微信截图_20220212102256.png",np.uint8),-1)
#print(img.shape)
##转换为bgr图片,注意此时是PNG图片,不能用矩阵直接转换!
##img=img[:,:,::-1]
##bgra
#img=cv2.cvtColor(img,cv2.COLOR_BGRA2RGB)
def readPictureByPath(path):
    img=cv2.imdecode(np.fromfile(path,np.uint8),-1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    #return np.expand_dims(img,0)
    return img
#加载数据
xData=[]
yLabel=[]
with h5py.File("myTextData", 'r') as d:
    xData=np.array(d['x'])
    yLabel=np.array(d['label'])
    print(type(d['label']))

#以四个一样的任务进行演示:识别服装类型:
#输入进行额外命名
inputData=Input(shape=(28, 28, 3))
x=Conv2D(10,kernel_size=(5,5),activation='relu')(inputData)
x=GlobalAvgPool2D()(x)
output=Dense(4,activation="softmax")(x)
model=Model(inputData,output)
model.summary()
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss=tf.keras.losses.CategoricalCrossentropy(),metrics=tf.keras.metrics.categorical_accuracy)
history=model.fit(x=xData,y=yLabel,epochs=100,validation_split=0.2,batch_size=32)
#绘制代际曲线图
plt.plot(history.history['val_categorical_accuracy'])
plt.show()

转为多任务训练:

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import resnet
from tensorflow.keras.utils import plot_model
from tensorflow.keras.layers import Input,GlobalAvgPool2D,Dense,Dropout,Lambda,Conv2D
from tensorflow.keras.models import Model
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint,TensorBoard,EarlyStopping
def readPictureByPath(path):
    img=cv2.imdecode(np.fromfile(path,np.uint8),-1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    #return np.expand_dims(img,0)
    return img
import cv2
#正则匹配使用:
import re
import os
#此库用于拷贝,删除,移动,复制以及解压缩
import shutil
import numpy as np
import h5py
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
#用于将普通标签转为独热向量
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
#如此读取图像,直接返回numpy.ndarray
#img=cv2.imdecode(np.fromfile("C:/Users/104005162/Desktop/企业微信截图_20220212102256.png",np.uint8),-1)
#print(img.shape)
##转换为bgr图片,注意此时是PNG图片,不能用矩阵直接转换!
##img=img[:,:,::-1]
##bgra
#img=cv2.cvtColor(img,cv2.COLOR_BGRA2RGB)
def readPictureByPath(path):
    img=cv2.imdecode(np.fromfile(path,np.uint8),-1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    #return np.expand_dims(img,0)
    return img
#加载数据
xData=[]
yLabel=[]
with h5py.File("C:/Users/25360/Desktop/model/myTextData", 'r') as d:
    xData=np.array(d['x'])
    yLabel=np.array(d['label'])
#变更为双输出
#(2,157,4)
yLabel= np.array([yLabel, yLabel])
#(2,157,4)
yLabel=(yLabel[0],yLabel[1])
print(yLabel[0])


#以四个一样的任务进行演示:识别服装类型:
#输入进行额外命名

inputData=Input(shape=(28, 28, 3))
x=Conv2D(10,kernel_size=(5,5),activation='relu')(inputData)
x=GlobalAvgPool2D()(x)
output1=Dense(4,activation="softmax",name="outPut1")(x)
output2=Dense(4,activation="softmax",name="outPut2")(x)
model=Model(inputData,[output1,output2])
model.summary()
#统一一种损失函数:
#model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss=tf.keras.losses.CategoricalCrossentropy(),metrics=tf.keras.metrics.categorical_accuracy)
#分开计算的损失函数:转换成字典的形式,重点训练输入2
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss={"outPut1":tf.keras.losses.CategoricalCrossentropy(),
                                                              "outPut2":tf.keras.losses.CategoricalCrossentropy()},
              metrics={"outPut1":tf.keras.metrics.categorical_accuracy,"outPut2":tf.keras.metrics.categorical_accuracy},
              loss_weights={"outPut1":1,"outPut2":10})
history=model.fit(x=xData,y=yLabel,epochs=100,validation_split=0.2,batch_size=32)
#绘制代际曲线图
plt.plot(history.history['val_outPut1_categorical_accuracy'],c='red')
plt.plot(history.history['val_outPut2_categorical_accuracy'],c='blue')
plt.legend(["val_outPut1","val_outPut2"])
plt.show()
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

颢师傅

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值