图片风控NSFW(not suit for work)-2 基于tf2模型微调

直接使用yahoo开源的模型open_nsfw,不能满足业务需求,需要对模型进行重新训练。本篇主要是对模型进行训练 .
(在上篇博客已经讲述了怎么将原始模型转换为tensorflow2模型)

思路

1 将开源雅虎nsfw模型转换为 tensorflow2,见tensorflow2模型重构
2 准备训练样本,正负样本 (比例4:1~1:4之间)
3 数据增强
4 模型训练
5 模型保存
6 模型部署 (java部署)

1 数据准备

训练数据格式如下,其中positiveSapmle为正样本,negetiveSample目录中为负样本,
“”"
|-path
  |-positiveSapmle
  |-negetiveSample
“”"

# path为样本目录,labelName为positiveSapmle,或者negetiveSample
data_dirs=[]
data_labels=[]
path="./sample/"
def get_data_paths_lables(path,labelName):
    path_label=os.path.join(path,labelName)
    label= 1 if labelName=="positiveSapmle" else 0  
    path_list=os.listdir(path_label)
    data_dirs.extend( [os.path.join(path,labelName,name) for name in path_list])
    data_labels.extend([label for i in path_list])

2 模型训练

  • 1 加载模型,使用getModel ,在上篇博客已经实现了怎么tf2加载yahoo开源模型。

  • 2 数据加载
    格式如下
    |-path
      |-positiveSapmle
      |-negetiveSample

  • 3 数据增强
    翻转,旋转,增加对比度等

  • 4 模型训练

import os 
from sklearn.model_selection import train_test_split
import tensorflow as tf
from nsfwmodel import getModel
from image_utils import create_tensorflow_image_loader,__tf_jpeg_process
from  image_utils import  create_yahoo_image_loader

def load_image(input_type=1,image_loader="yahoo"):
    if input_type == 1:
        print('TENSOR...')
        if image_loader == IMAGE_LOADER_TENSORFLOW:
            print('IMAGE_LOADER_TENSORFLOW...')
            fn_load_image = create_tensorflow_image_loader()
        else:
            print('create_yahoo_image_loader')
            fn_load_image = create_yahoo_image_loader()
    elif input_type == 2:
        print('BASE64_JPEG...')
        import base64
        fn_load_image = lambda filename: np.array([base64.urlsafe_b64encode(open(filename, "rb").read())])
    return fn_load_image

def imageToTensor(inputs,input_type=1):
    if input_type == 1:
        input_tensor = inputs
    elif input_type == 2:
        from image_utils import load_base64_tensor
        input_tensor = load_base64_tensor(inputs)
    else:
        raise ValueError("invalid input type '{}'".format(input_type))
    return input_tensor



# 1 数据准备
input_type=1
image_loader= "tensorflow"
IMAGE_LOADER_TENSORFLOW = "tensorflow"
IMAGE_LOADER_YAHOO = "yahoo"
# 图片加载器
fn_load_image=load_image(input_type,image_loader)

#训练样本及标签加载
get_data_paths_lables(path,"negetiveSample")  
get_data_paths_lables(path,"positiveSapmle")    
train_data_dirs,test_data_dirs,train_data_labels,test_data_labels=train_test_split(data_dirs,data_labels,test_size=0.2)
 
 
# 数据处理(数据增强,数据变化)
def load_preprosess_image(path, label):
    image = tf.io.read_file(path)
    image = __tf_jpeg_process(image)
    image=imageToTensor(image, input_type)
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_brightness(image, 0.3)
    image = tf.image.random_contrast(image, 0, 1)
    image = tf.cast(image, tf.float32)

    label = tf.reshape(label, [1])
    return image, label

# 模型训练
BATCH_SIZE=32
train_image_ds = tf.data.Dataset.from_tensor_slices((train_data_dirs, train_data_labels)).map(load_preprosess_image)
train_dataset=train_image_ds.shuffle(10000).batch(BATCH_SIZE)
test_image_ds = tf.data.Dataset.from_tensor_slices((test_data_dirs, test_data_labels)).map(load_preprosess_image)
test_dataset=train_image_ds.batch(BATCH_SIZE)
model = getModel()
model.summary()
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),loss=tf.keras.losses.SparseCategoricalCrossentropy()
             ,metrics=["acc"])
model.fit(train_dataset,epochs=50,validation_data=test_dataset,)

3 模型预测及保存

# 模型预测
model.predict(fn_load_image("./images/ALqhFyWOTw004_1.jpg"))
# 返回:array([[0.97107196, 0.02892802]], dtype=float32)

# 模型保存
model.save_weights("./model/nsfw_finetune.weight")

4 保存为原始的.npy格式 (非必要)

原始开源模型参数:open_nsfw-weights.npy

import numpy as np
npweight_tf1=np.load("./open_nsfw-weights.npy", allow_pickle=True,encoding="latin1").item()
name_dict=dict(zip(['variance', 'scale', 'offset', 'mean','weights', 'biases'],["moving_variance",'gamma','beta','moving_mean','kernel','bias']))
auto_to_manu_dict=dict(zip(["moving_variance",'gamma','beta','moving_mean','kernel','bias'],['variance', 'scale', 'offset', 'mean','weights', 'biases']))
mweight=model.weights
targweight={}
for nwkey,nwvalue in list(npweight_tf1.items()):
    # numpy  value keys 
    nwkeys=  [ name_dict[j] for j in nwvalue.keys()]
    for tfweigh in mweight:  # 模型weight 
        targweight[nwkey]=targweight.get(nwkey,{})
        if  nwkey in tfweigh.name: # numpy  key 在 模型中            
            for jj in nwkeys:
                if jj in tfweigh.name:
                    targweight[nwkey][auto_to_manu_dict[jj]]=tfweigh.numpy()
targweight
np.save("./data/open_nsfw-weights_new.npy",targweight)
  • 7
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
使用:网络需要在图像和输出概率(评分0-1)之间过滤不适合工作的图片。评分<0.2表示图像具有较高概率是安全的。评分>0.8表明极有可能是不适合工作(NSFW)图像。我们建议开发者根据用例和图像类型的不同选择合适的阈值。根据使用情况、定义以及公差的不同会产生误差。理想情况下,开发人员应该创建一个评价集,根据“什么是安全的”对他们的应用程序进行定义,然后适合ROC曲线选择一个合适的阈值。结果可以通过微调你的数据/ uscase /定义NSFW模型的改进。我们不提供任何结果的准确性保证。使用者适度地结合机器学习解决方案将有助于提高性能。模型描述:我们将不适合工作的图片NSFW)作为数据集中的积极对象,适合工作的图片作为消极对象来进行训练。所有这些被训练得图片都被打上了特定的标签。所以由于数据本身的原因,我们无法发布数据集或者其他信息。我们用非常不错的名字叫“CaffeOnSpark”的架构给“Hadoop”带来深度学习算法,并且使用Spark集群来进行模型训练的实验。在此非常感谢 CaffeOnSpark 团队。深度模型算法首先在 ImageNet 上生成了1000种数据集,之后我们调整不适合工作(NSFW)的数据集比例。我们使用了50 1by2的残差网络生成网络模型模型通过 pynetbuilder 工具以及复制残余网络的方法会产生50层网络(每层网络只有一半的过滤器)。你可以从这里获取到更多关于模型产生的信息。更深的网络或者具有更多过滤器的网络通常会更精确。我们使用剩余(residual)网络结构来训练模型,这样可以提供恰到好处的精确度,同样模型在运行以及内存上都能保持轻量级。 标签:opennsfw

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值