使用vgg16网络完成多标记分类任务代码(tensorflow框架)

github下载链接:

https://github.com/A-mockingbird/VGG16ForMultilabelClassification

# 1.读取多标记分类数据集,将数据集分割,存储为tfrecords格式

新建文件

ReadMultilabelDataset.py

import json
import os
import random
import numpy as np
import tensorflow as tf
from PIL import Image

slim = tf.contrib.slim

def get_multilabel_dataset_dict(imagedir, class_name, train_percentage=8):
    """
    读取数据集,数据集存储格式:不同标记的图像放在一个文件下面,例如有三个分类:dog,house,car
    既有dog又有house的图像文件名为:dog+house
    返回字典,存储测试集和训练集图像及其标签和文件名
    """
    rootdir = imagedir
    #获取全部子文件名(标记)
    category = [x[1] for x in os.walk(imagedir)][0]
    dataset = {}
    #遍历全部子文件
    for j, cat in enumerate(category):
        #获取标签,例如dog+house就变成[1, 1, 0]
        sub_label = get_label(class_name, cat)
        subdir = os.path.join(rootdir, cat)
        imagelist = os.listdir(subdir)
        number = len(imagelist)
        train_dataset = []
        test_dataset = []
        print('{}: {}'.format(cat, sub_label))
        for i, image in enumerate(imagelist):
        #遍历图像
            #随机分为训练集和测试集
            r = random.randint(0, number)
            if r < number / 10.0 *train_percentage:
                train_dataset.append(image)
            else:
                test_dataset.append(image)
        #存入字典中
        dataset[cat] = {
            'dir':subdir,
            'label':sub_label, 
            'train':train_dataset,
            'test':test_dataset
        }
    return dataset

def get_label(class_name, cat):   
    #标签转换,转换向量形式
    label = []
    cls = cat.split('+')
    for i, x in enumerate(class_name):
        if x in cls:
            label.append(1)
        else:
            label.append(0)
    return label 

def int64_list_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def create_tfrecord_example(label, imagefile, resize=None):
    #创建tfrecord的输入example
    #读取图像
    pil_image = Image.open(imagefile)
    #resize图像
    if resize != None:
        pil_image = pil_image.resize(resize)
    #将读取的图像转换为二进制格式
    bytes_image = pil_image.tobytes()
    #创建example(包含图像信息和标签信息)
    example = tf.train.Example(features=tf.train.Features(feature={
        'label': int64_list_feature(label), 
        'image': bytes_feature(bytes_image)
        #'format': bytes_feature('jpg')
  
  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值