实战半监督模型mean_teacher(TensorFlow2)

        大家好,今天给大家带来TensorFlow2实现mean_teacher的思路及代码。

        写程序属于业余爱好,写的不好请见谅。

        思路步骤来自这里:

https://blog.csdn.net/hjimce/article/details/80551721?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522166987502116782395357346%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fall.%2522%257D&request_id=166987502116782395357346&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~first_rank_ecpm_v1~rank_v31_ecpm-2-80551721-null-null.142^v67^control,201^v3^control,213^v2^t3_esquery_v2&utm_term=%E5%8D%8A%E7%9B%91%E7%9D%A3mean&spm=1018.2226.3001.4187

        我会根据链接里“三、算法流程 “和“四、伪代码”来编写

        下面将展示做出来的效果:

这是标签照片和无标签照片(混合照片):

将混合照片放入训练好的模型里分类出来:

效果还是不错的,个别图片没有准确分类出来,可能是标签数据给少了或者训练次数少了。

下面按步骤用代码来实现:

1、准备标签数据

       首先你需要准备两组图片(这里我准备的胡桃和神子的图片),将它们放入一个文件夹里,从里面拿几张胡桃图片放入一个单独的文件夹(记为标签0),几张神子放入另一个单独文件夹(记为标签1)。这样我们便有了标签照片胡桃,和神子了.。

2、加载数据并强化数据

       先建立一个load_data.py文件,我们需要读取到胡桃标签数据和神子标签数据以及无标签数据(就是胡桃和神子混合的文件夹),完了吗?不,我们需要对图像数据增强,这样我们训练出来的模型才有更高的灵活性,识别准确性。代码如下:

import tensorflow as tf
import os


def get_image(dir):
    #将图像转为tensor
    image = tf.io.read_file(dir)
    image = tf.io.decode_image(image,channels=3)
    image = tf.broadcast_to(image,(200,200,3))
    return image

def load_label_data(dir):
    #读取标签数据
    label = os.path.join(dir, "label")
    label_1 = os.path.join(label,'hutao')
    label_2 = os.path.join(label,'shenzi')
    data1 = tf.data.Dataset.list_files(os.path.join(label_1,'*.jpg'))
    data2 = tf.data.Dataset.list_files(os.path.join(label_2,'*.jpg'))
    data1 = iter(data1.map(get_image).batch(1).repeat())
    data2 = iter(data2.map(get_image).batch(1).repeat())
    #return 给标签数据打上标签(胡桃为0标签,神子为1标签)
    return [data1,0],[data2,1]

def load_valid_data(dir):
    #读取无标签数据
    valid = os.path.join(dir,'valid')
    mix = os.path.join(valid,'mix')
    data = tf.data.Dataset.list_files(os.path.join(mix,'*.jpg'))
    data = iter(data.map(get_image).batch(1).repeat())
    #没有标签
    return [data]

def strengthen_data(img):
    #这里对图像数据进行增强
    img_generator = tf.keras.preprocessing.image.ImageDataGenerator(
        rescale=1/255.,
        rotation_range=45,
        width_shift_range=50,
        height_shift_range=50,
        horizontal_flip=True,
        zoom_range=0.5
    )
    img_generator.fit(img)
    return img_generator.flow(img,batch_size=1)[0]

3、构建网络模型

         先建立一个model.py文件,根据算法流程,我们需要一个network(学生模型和老师模型共用一个network结构),损失函数L1,损失函数L2,以及一个teacher模型更新方法(因为teacher更新并不是反向传播,而是根据student模型参数来更新的)。

我们用代码来理解吧:

import tensorflow as tf
from tensorflow.keras import Model,layers,Sequential,optimizers
import numpy as np

class network(Model):
    #定义网络结构
    def __init__(self):
        super(network, self).__init__()
        def conv_1(filters):
            return Sequential(
                [layers.Conv2D(filters,3,2,'same'),
                layers.BatchNormalization(),
                layers.Activation('relu')]
            )

        def conv_2(filters):
            return Sequential(
                [layers.Conv2D(filters,3,2,'same'),
                layers.BatchNormalization(),
                layers.Activation('relu'),
                layers.Conv2D(filters,1,1,'valid'),
                layers.BatchNormalization(),
                layers.Activation('relu')]
            )

        self.model = Sequential(
            [conv_1(32),
             conv_2(64),
             conv_2(128),
             conv_2(256),
             conv_2(512),
             layers.AvgPool2D(2),
             layers.Flatten(),
             layers.Dense(2)]
        )

    def call(self,inputs):
        out = self.model(inputs)
        return out

def soft_mse(pre_1,pre_2):
    #定义L2损失函数
    x1 = tf.nn.softmax(pre_1)
    x2 = tf.nn.softmax(pre_2)
    loss = tf.losses.MSE(x1,x2)
    return loss

def update_teacher(teacher_net,student_net,i):
    #定义teacher模型更新方法
    k = min(1 - 1 / (1 + i), 0.99)
    t_wei = np.array(teacher_net.get_weights(),dtype='object')
    s_wei = np.array(student_net.get_weights(),dtype='object')
    wei = k*s_wei+(1-k)*t_wei
    teacher_net.set_weights(wei)

def is_label_data(data):
    # 判断是否为标签数据
    if len(data) == 2:
        return tf.reshape(tf.one_hot(data[1],depth=2),(1,2))
    else:
        return None

def label_loss(pre_1,label):
    #定义L1损失函数,
    #因为没说无标签数据的L1loss,所以我把无标签的L1loss设为0
    if label == None:
        return 0
    else:
        return tf.losses.categorical_crossentropy(label,pre_1,from_logits=True)

4、按照流程训练模型

           建立一个train.py文件,在一个epoch中,我们随机选取标签数据和无标签数据,对一个数据分别增强,分别输入到student和teacher里面,得到L1loss和L2loss,相加为loss,用loss更新student模型,利用student模型来更新teacher模型。多训练几次epoch。

代码如下:

from model import *
import load_data
import random
import os

if not os.path.exists('wei'):
    #权重文件夹
    os.mkdir('wei')
dir = r'C:\Users\ASUS\Desktop\mean_teacher_data'
data1,data2 = load_data.load_label_data(dir)
data3 = load_data.load_valid_data(dir)
student_net = network()
student_net.build(input_shape=(None,200,200,3))
teacher_net = network()
teacher_net.build(input_shape=(None,200,200,3))
optimizer = optimizers.Adam(0.0005)
wei = student_net.get_weights()
teacher_net.set_weights(wei)
for i in range(1500):
    #随机从标签数据和无标签数据里选择数据
    data = random.choices([data1,data2,data3],weights=[1,1,2],k=1)[0]
    label = is_label_data(data)
    data_1 = next(data[0])
    data_2 = data_1
    mydata_2 = load_data.strengthen_data(data_2)
    mydata_1 = load_data.strengthen_data(data_1)
    with tf.GradientTape() as tape:
        pre_1 = student_net(mydata_1)
        loss_1 = label_loss(pre_1,label)
        pre_2 = teacher_net(mydata_2)
        loss_2 = soft_mse(pre_1,pre_2)
        loss = loss_1+loss_2
    grad = tape.gradient(loss,student_net.trainable_variables)
    #更新student模型
    optimizer.apply_gradients(zip(grad,student_net.trainable_variables))
    #更新teacher模型
    update_teacher(teacher_net,student_net,i)
#保存数据
student_net.save_weights('wei/stu.cfc')

到这里,我们已经完成了mean_teacher模型了,剩下的就是测试模型了。

5、测试

        收工阶段,建立一个test.py文件夹,我们把混合照片放入,训练好的student模型,看看是胡桃的概率大还是神子的概率大,胡桃的话就将照片放入胡桃文件夹,神子的话就放入神子文件夹。

代码如下:

import tensorflow as tf
import model
import os

if not os.path.exists(r'C:\Users\ASUS\Desktop\mean_teacher_result'):
    os.mkdir(r'C:\Users\ASUS\Desktop\mean_teacher_result')
    os.mkdir(r'C:\Users\ASUS\Desktop\mean_teacher_result\hutao')
    os.mkdir(r'C:\Users\ASUS\Desktop\mean_teacher_result\shenzi')

student_net = model.network()
student_net.build(input_shape=(None,200,200,3))
try:
    student_net.load_weights('wei/stu2.cfc')
    print('load success')
except:
    print('load failed')

img_path = r'C:\Users\ASUS\Desktop\mean_teacher_data\valid\mix'

def fl(dir,h):
    img = tf.io.read_file(dir)
    img = tf.io.decode_image(img,channels=3)
    my_img = tf.cast(img,dtype=tf.float32)/255.
    img = tf.expand_dims(my_img,axis=0)
    out = student_net(img)
    pre = tf.nn.softmax(out)
    if pre[0][0]>pre[0][1]:
        tf.keras.preprocessing.image.save_img(os.path.join(r'C:\Users\ASUS\Desktop\mean_teacher_result\hutao',f'{h}.jpg'),my_img)
    else:
        tf.keras.preprocessing.image.save_img(os.path.join(r'C:\Users\ASUS\Desktop\mean_teacher_result\shenzi',f'{h}.jpg'),my_img)

if __name__ == '__main__':
    data = tf.data.Dataset.list_files(os.path.join(img_path,'*.jpg'))
    h = 0
    for i in data:
        fl(i,h)
        h+=1

 理性思考,共同进步,感谢大家的观看!

  • 6
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 11
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值