利用TensorFlow定义Dataset类

step1: 生成标签文件label.csv

step2: 设计Dataset类的成员

step3: 类的初始化函数

step4: 记录的解析

  1. 图像增强
  2. 图像处理

step5: code

step6: show

import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from common.slim.preprocessing.inception_preprocessing import preprocess_for_train, apply_with_random_selector

class MyDataset(object):
    """Queue-based TF1 input pipeline yielding batches of [No, image, labels].

    Data flow: read_one_sample --> read_batch_samples --> __getitem__.
    The graph and a session with running queue-runner threads are created in
    ``__init__``. Call ``close_session`` when done, or use the instance as a
    context manager (``with MyDataset(df) as ds: ...``).

    NOTE(review): relies on module-level settings defined elsewhere in the
    file (IMAGES_DIR, DATASET, IMAGE_SIZE, IMAGE_CHANNELS) -- confirm they are
    set before instantiating.
    """

    def __init__(self, df, img_aug=False, shuffle=False, is_test=False,
                 batch_size=16, num_threads=4, batch_join=False, num_epochs=None):
        """Build the input graph and open a session that feeds it.

        Args:
            df: DataFrame with a ``No`` column (int64 sample id) and an ``Id``
                column (image file name without ``.png``). Unless ``is_test``,
                every column from position 2 onward is read as an int32 label.
            img_aug: apply random rotation/crop/flip/color augmentation.
            shuffle: shuffle samples in the per-sample producer.
            is_test: read from the ``test`` image directory and emit no labels.
            batch_size: samples per batch.
            num_threads: enqueue threads (or pipeline copies with batch_join).
            batch_join: use ``tf.train.batch_join`` instead of ``tf.train.batch``.
            num_epochs: forwarded to ``slice_input_producer``; ``None`` cycles
                through the data forever.
        """
        self.df, self.img_aug, self.shuffle, self.is_test = df, img_aug, shuffle, is_test
        self.batch_size, self.num_threads, self.batch_join = batch_size, num_threads, batch_join
        self.num_epochs = num_epochs
        self.img_dir = os.path.join(IMAGES_DIR, 'test' if is_test else 'train')

        # Per-channel normalization statistics for 4-channel images; only the
        # last IMAGE_CHANNELS entries are used, matching read_img's slice.
        if DATASET == 'grby':
            # for GRBY
            self.mean = np.array([0.05062989, 0.07459677, 0.05089854, 0.07629084])
            self.std = np.array([0.11019831, 0.14376172, 0.14938936, 0.14281237])
        else:
            # for GRBY_landmarks
            self.mean = np.array([0.06306252, 0.10562667, 0.07949321, 0.10858839])
            self.std = np.array([0.80647908, 0.82780679, 0.82680373, 0.8309615])

        self.graph = tf.Graph()
        with self.graph.as_default():
            # Shuffle at the per-sample producer so the batch queue can stay small.
            self.__one_sample = self.read_one_sample(df, img_aug, shuffle, is_test, num_epochs)
            self.__batch_samples = self.read_batch_samples(self.__one_sample, batch_size, num_threads,
                                                           batch_join=batch_join)
            self.open_session()

    def __len__(self):
        """Number of batches in one pass over ``df`` (last batch may be smaller)."""
        return int(np.ceil(self.df.shape[0] / self.batch_size))

    def __getitem__(self, index):
        """Dequeue and return the next batch as a list of numpy arrays.

        ``index`` is ignored: every call pulls the next batch from the queue.
        With a finite ``num_epochs`` the session raises
        ``tf.errors.OutOfRangeError`` once the data is exhausted.
        """
        return self.sess.run(self.__batch_samples)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always stop the queue threads and release the session.
        self.close_session()
        return False

    def open_session(self):
        """Create the session, initialize variables and start queue runners."""
        config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(graph=self.graph, config=config)
        # Local variables must be initialized too: slice_input_producer keeps
        # its epoch counter in one when num_epochs is set.
        self.sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        self.coord = tf.train.Coordinator()
        self.threads = tf.train.start_queue_runners(sess=self.sess, coord=self.coord)

    def close_session(self):
        """Stop the queue-runner threads, wait for them and close the session."""
        self.coord.request_stop()
        self.coord.join(self.threads)
        self.sess.close()

    @staticmethod
    def _label_matrix(df):
        """Return the label columns (position 2 onward) as an int32 2-D array.

        ``df.values`` on the whole frame would be an object array because of
        the string ``Id`` column, so the label columns are taken on their own
        and cast explicitly.
        """
        return np.asarray(df.iloc[:, 2:].values, dtype=np.int32)

    def read_one_sample(self, df, img_aug, shuffle, is_test, num_epochs=None):
        """Build the graph that reads one record.

        Returns:
            List of tensors ``[No (int64), image (float32, C x H x W)]``, plus
            ``labels (int32, n_labels)`` when not ``is_test``.
        """
        # step1: producer that emits one row's metadata per dequeue.
        tensor_id = tf.constant(value=df['No'].values, dtype=tf.int64)
        tensor_img_name = tf.constant(value=df['Id'].values, dtype=tf.string)
        tensor_list = [tensor_id, tensor_img_name]
        labels = None if is_test else self._label_matrix(df)
        if labels is not None:
            tensor_list.append(tf.constant(value=labels, dtype=tf.int32))
        record = tf.train.slice_input_producer(tensor_list, shuffle=shuffle, capacity=df.shape[0],
                                              num_epochs=num_epochs)

        # step2: read and preprocess the image.
        img_path = tf.string_join([self.img_dir, record[1] + '.png'], separator='/')
        img = self.read_img(img_path)
        if img_aug:
            img = self.img_aug_func(img, IMAGE_SIZE[0], IMAGE_SIZE[1])
        else:
            if img.dtype != tf.float32:
                img = tf.image.convert_image_dtype(img, dtype=tf.float32)
            # method=0 is ResizeMethod.BILINEAR in TF1.
            img = tf.image.resize_images(img, size=IMAGE_SIZE, method=0)
        img = self.standardization(img)
        img = tf.transpose(img, (2, 0, 1))  # HWC -> CHW (channels first)

        # step3: replace the file name by the image and fix static shapes.
        img.set_shape([IMAGE_CHANNELS, *IMAGE_SIZE])
        record[1] = img
        if labels is not None:
            # Label width comes from the data instead of a hard-coded count.
            record[-1].set_shape([labels.shape[1]])
        return record

    def read_img(self, img_path):
        """Decode a 512x512 PNG as 4 channels and keep the last IMAGE_CHANNELS."""
        raw = tf.read_file(img_path)
        img = tf.image.decode_png(raw, channels=4)
        img.set_shape([512, 512, 4])
        # Keep the trailing channels so they line up with mean/std[-IMAGE_CHANNELS:].
        img = tf.slice(img, begin=[0, 0, 4 - IMAGE_CHANNELS], size=[512, 512, IMAGE_CHANNELS])
        return img

    def img_aug_func(self, img, height, width):
        """Randomly augment one HWC image and resize it to (height, width).

        Added here:
            step0: rotate by a random multiple of 90 degrees.
        Done by slim's ``preprocess_for_train`` (fast_mode, per the original notes):
            step1: sample a random crop box
            step2: resize to the target size
            step3: random flip
            step4: color change:
                (1) random_contrast: per channel x' = (x - mean) * k + mean
                    (works for RGB and RGBA images)
                (2) random_brightness: add or subtract one constant to every
                    pixel of every channel (RGB and RGBA images)
        """
        img = apply_with_random_selector(img, lambda x, k: tf.image.rot90(x, k), num_cases=4)
        img = preprocess_for_train(img, height, width, bbox=None,
                                   fast_mode=True,
                                   scope=None,
                                   add_image_summaries=False, img_channels=IMAGE_CHANNELS)
        return img

    def standardization(self, img):
        """Normalize an image with the per-channel mean/std.

        Non-float32 images are first converted with ``convert_image_dtype``,
        which rescales integer pixels to [0, 1].
        """
        if img.dtype != tf.float32:
            img = tf.image.convert_image_dtype(img, dtype=tf.float32)
        # Explicit float32 so the op does not depend on implicit numpy casting.
        mean = self.mean[-IMAGE_CHANNELS:].astype(np.float32)
        std = self.std[-IMAGE_CHANNELS:].astype(np.float32)
        return (img - mean) / std

    def read_batch_samples(self, one_sample, batch_size, num_threads, batch_join):
        """Batch single samples; the final batch of an epoch may be smaller."""
        capacity = 500 + batch_size * 3
        if not batch_join:
            return tf.train.batch(tensors=one_sample,
                                  batch_size=batch_size, num_threads=num_threads,
                                  capacity=capacity, allow_smaller_final_batch=True)
        # batch_join: num_threads copies of the same sample pipeline.
        return tf.train.batch_join(tensors_list=[one_sample] * num_threads,
                                   batch_size=batch_size, capacity=capacity,
                                   allow_smaller_final_batch=True)

    def visualize(self, n_rows=4, n_cols=4):
        """Plot the next batch as an n_rows x n_cols grid titled by sample ``No``.

        Shows the last three channels of each standardized image, so colors
        are only indicative. Grid slots beyond the batch length stay empty.
        """
        batch = self.__getitem__(0)
        ids, imgs = batch[0], batch[1]
        plt.figure(figsize=(n_cols * 4, n_rows * 4))  # figsize is (width, height)
        for k in range(min(n_rows * n_cols, len(imgs))):
            plt.subplot(n_rows, n_cols, k + 1)
            plt.imshow(np.transpose(imgs[k], (1, 2, 0))[:, :, -3:])  # CHW -> HWC
            plt.title(ids[k])
        plt.tight_layout()
        plt.show()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值