使用batch_size加载celeba数据集的时候,怎么保证前(batch_size//2)图像是男性 后(batch_size//2)是女性

前言:论文中需要加载数据集的时候 保证样本前面都是男性,后面都是女性 OK 功能是实现了 但是代码有点乱

 

自己懂就可以了 有问题的话 再问吧

import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
from PIL import Image

from scipy.misc import imread, imsave, imresize

import cv2
import glob
from tqdm import tqdm
import math
batch_size = 100
z_dim = 100

# 加载测试时所需要的要本
def load_sample():
    # 测试之后数据y_samples前50个为[1,0] 后50个为[0,1]
    #images = glob.glob('../abc_gan_data/celeb_use_in_test//*.jpg')
    images = glob.glob('../Dataset/celebA/*.png')

    read_path = '../Dataset/celebA/'

    image_name_test = []
    for i in range(len(images)):
        tmp = images[i][18:]
        image_name_test.append(tmp)
    print("已经进入到load_sample函数之后  样本的image_name是:",image_name_test[0:10])

    tags = {}

    # 暂时存放在这个地方
    X_all = []
    Y_all = []

    with open('../Dataset/text_data/list_attr_celeba.txt', 'r') as f:
        lines = f.readlines()
        all_tags = lines[1].strip('\n').split()
        real_image_number = 0
        inverse_image_number = 0
        # print("确认一下all_tags的类型和具体的属性值",type(all_tags),all_tags)
        for i in range(2, len(lines)):
            line = lines[i].strip('\n').split()
            if int(line[all_tags.index(target) + 1]) == 1 and real_image_number < (batch_size//2) and line[0] in image_name_test:
                real_image_number = real_image_number +1
                tags[line[0]] = [1, 0]
                image_path = read_path + line[0]
                #print("image_pathimage_pathimage_pathimage_pathimage_path",image_path)
                image = imread(image_path)
                h = image.shape[0]
                w = image.shape[1]

                if h > w:
                    image = image[h // 2 - w // 2: h // 2 + w // 2, :, :]
                else:
                    image = image[:, w // 2 - h // 2: w // 2 + h // 2, :]

                image = cv2.resize(image, (64, 64))
                # image = (image / 255. - 0.5) * 2
                image = (image / 255.)
                X_all.append(image)

                image_name = line[0]
                Y_all.append(tags[image_name])
            elif int(line[all_tags.index(target) + 1]) == -1 and inverse_image_number < (batch_size//2) \
                    and line[0] in image_name_test  and real_image_number>= (batch_size//2):
                tags[line[0]] = [0, 1]
                inverse_image_number = inverse_image_number + 1
                image_path = read_path + line[0]
                print("image_pathimage_pathimage_pathimage_pathimage_path",image_path)
                image = imread(image_path)
                h = image.shape[0]
                w = image.shape[1]

                if h > w:
                    image = image[h // 2 - w // 2: h // 2 + w // 2, :, :]
                else:
                    image = image[:, w // 2 - h // 2: w // 2 + h // 2, :]

                image = cv2.resize(image, (64, 64))
                # image = (image / 255. - 0.5) * 2
                image = (image / 255.)
                X_all.append(image)
                image_name = line[0]
                Y_all.append(tags[image_name])
            elif inverse_image_number > (batch_size//2) and real_image_number> (batch_size//2):
                #print("加载样本的时候 最后的形状是:",inverse_image_number, real_image_number)
                break
            else:
                #print("加载样本的时候 最后的形状是:",inverse_image_number, real_image_number)
                continue
        print("加载样本的时候 最后的形状是:",inverse_image_number, real_image_number)

    print("加载数据的长度", len(images), len(tags))
    X_all = np.array(X_all)
    print("在load_sample 数据集之中最后返回的数据集的形状是:", X_all.shape)
    Y_all = np.array(Y_all)
    print("在load_sample 数据集之中最后返回的数据集Y_all的形状是:", Y_all.shape)
    return X_all, Y_all


# 定义一个批量画图操作
def montage(images):
    if isinstance(images, list):
        images = np.array(images)
    img_h = images.shape[1]
    img_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(images.shape[0])))
    if len(images.shape) == 4 and images.shape[3] == 3:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1, 3)) * 0.5
    elif len(images.shape) == 4 and images.shape[3] == 1:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1, 1)) * 0.5
    elif len(images.shape) == 3:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1)) * 0.5
    else:
        raise ValueError('Could not parse image shape of {}'.format(images.shape))

    for i in range(n_plots):
        for j in range(n_plots):
            this_filter = i * n_plots + j
            if this_filter < images.shape[0]:
                this_img = images[this_filter]
                m[1 + i + i * img_h:1 + i + (i + 1) * img_h,
                1 + j + j * img_w:1 + j + (j + 1) * img_w] = this_img
    return m

z_samples, y_samples = load_sample()

test_samples = (test_samples + 1) / 2
print("test_samples 的形状", test_samples.shape)
imgs = [img[:, :, :] for img in test_samples]
test_samples = montage(imgs)
plt.axis('off')
imsave(os.path.join(sample_dir, 'real_sample_%5d.jpg' % i), test_samples)
                         
                    
                
               
               

 

 

加载图像的结果:

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值