DCGAN中train中使用的加载数据集的方法

  # 将celeba转化成numpy形式

def get_image():
    # images 放的是"../Dataset/celebA/"文件夹下的所有png文件名称
    # images: ../Dataset/celebA\000005.png
    images = glob.glob('../Dataset/celete/*.png')
    print("images:", images[0])
    image_tmp = []
    for i in range(len(images)):
        image = imread(images[i])
        image = cv2.resize(image, (width, high))
        '''
        # 可以用来测试自己的image 读取是否正常
        if i > 2:
            break
        print("image",image)
        '''
        image = (image / 255. - 0.5) * 2
        # 原来图像值的范围是0-255之间 现在image的数值变到-1到1之间
        # print("image", image)
        image_tmp.append(image)
    image_data = np.array(image_tmp)
    return image_data


# 加载生成图像的噪音
def load_sample():
    sample = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
    z_samples = []
    for i in range(batch_size):
            z_samples.append(sample[i, :])
    return z_samples


# 定义获取mini_batch数据
def get_random_batch(image_all):
    # 生成一个一维数组 并且每次打乱顺序取前batch_size个数据
    image_data = np.arange(image_all.shape[0])
    np.random.shuffle(image_data)
    image_data = image_data[:batch_size]
    x_batch = image_all[image_data, :, :, :]
    return x_batch

# 定义如何保存图像
def imsave_image(images):
    size = [8,8]
    if isinstance(images, list):
        images = np.array(images)
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[j * h: j * h + h, i * w: i * w + w, :] = image
    return img
# 定义线性激活函数
def lrelu(x, leak=0.2, name="lrelu"):
    return tf.maximum(x, leak*x)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值