tf9: PixelCNN

前一帖生成音乐,本帖生成图片。本文使用TensorFlow实现论文《Conditional Image Generation with PixelCNN Decoders》,它是基于PixelCNN架构的模型,最早出现在《Pixel Recurrent Neural Networks》一文。

使用的图片数据

我本想使用ImageNet做为图片来源,就像论文中使用的。ImageNet图像有现成的分类,抓取也容易,但是由于很多源都被防火墙屏蔽,下载速度堪忧。《OpenCV之使用Haar Cascade进行对象识别

我看到网上有很多爬妹纸图的Python脚本,额,我爬了几天几夜的妹纸图(特别暴露那种),额,我就想看看PixelCNN最后能生成什么鬼。

如果你懒的爬图片,可以使用我抓取的图片(分成两部分):

  • https://pan.baidu.com/s/1kVSA8z9 (密码: atqm)
  • https://pan.baidu.com/s/1ctbd9O (密码: kubu)
数据预处理

下载的图片分布在多个目录,把图片汇总到一个新目录:

[python]  view plain  copy
  1. import os  
  2.    
  3. old_dir = 'images'  
  4. new_dir = 'girls'  
  5. if not os.path.exists(new_dir):  
  6.     os.makedirs(new_dir)  
  7.    
  8. count = 0  
  9. for (dirpath, dirnames, filenames) in os.walk(old_dir):  
  10.     for filename in filenames:  
  11.         if filename.endswith('.jpg'):  
  12.             new_filename = str(count) + '.jpg'  
  13.             os.rename(os.sep.join([dirpath, filename]), os.sep.join([new_dir, new_filename]))  
  14.             print(os.sep.join([dirpath, filename]))  
  15.             count += 1  
  16. print("Total Picture: ", count)  

使用《open_nsfw: 基于Caffe的成人图片识别模型》剔除掉和妹子图不相关的图片,给open_nsfw输入要检测的图片,它会返回图片评级(0-1),等级越高,图片越黄越暴力。使用OpenCV应该也不难。

为了减小计算量,我把图像缩放为64×64像素:

[python]  view plain  copy
  1. import os  
  2. import cv2  
  3. import numpy as np  
  4.    
  5. image_dir = 'girls'  
  6. new_girl_dir = 'little_girls'  
  7. if not os.path.exists(new_girl_dir):  
  8.     os.makedirs(new_girl_dir)  
  9.    
  10. for img_file in os.listdir(image_dir):  
  11.     img_file_path = os.path.join(image_dir, img_file)  
  12.     img = cv2.imread(img_file_path)  
  13.     if img is None:  
  14.         print("image read fail")  
  15.         continue  
  16.     height, weight, channel = img.shape  
  17.     if height < 200 or weight < 200 or channel != 3:   
  18.         continue  
  19.     # 你也可以转为灰度图片(channel=1),加快训练速度  
  20.     # 把图片缩放为64x64  
  21.     img = cv2.resize(img, (6464))  
  22.     new_file = os.path.join(new_girl_dir, img_file)  
  23.     cv2.imwrite(new_file, img)  
  24.     print(new_file)  

去除重复图片:

[python]  view plain  copy
  1. import os  
  2. import cv2  
  3. import numpy as np  
  4.    
  5. # 判断两张图片是否完全一样(使用哈希应该要快很多)  
  6. def is_same_image(img_file1, img_file2):  
  7.     img1 = cv2.imread(img_file1)  
  8.     img2 = cv2.imread(img_file2)  
  9.     if img1 is None or img2 is None:  
  10.         return False  
  11.     if img1.shape == img2.shape and not (np.bitwise_xor(img1, img2).any()):  
  12.         return True  
  13.     else:  
  14.         return False  
  15.    
  16. # 去除重复图片  
  17. file_list = os.listdir('little_girls')  
  18. try:  
  19.     for img1 in file_list:  
  20.         print(len(file_list))  
  21.         for img2 in file_list:  
  22.             if img1 != img2:  
  23.                 if is_same_image('little_girls/'+img1, 'little_girls/'+img2) is True:  
  24.                     print(img1, img2)  
  25.                     os.remove('little_girls/'+img1)  
  26.         file_list.remove(img1)  
  27. except Exception as e:  
  28.     print(e)  

PixelCNN生成妹纸图完整代码

下面代码只实现了unconditional模型(无条件),没有实现conditional和autoencoder模型。详细信息,请参看论文。

[python]  view plain  copy
  1. # -*- coding: utf-8 -*-  
  2.    
  3. import tensorflow as tf  
  4. import numpy as np  
  5. import os  
  6. import cv2  
  7.    
  8. # 如果使用mnist数据集,把MNIST设置为True  
  9. MNIST = False  
  10.    
  11. if MNIST == True:  
  12.     from tensorflow.examples.tutorials.mnist import input_data  
  13.     data = input_data.read_data_sets('/tmp/')  
  14.     image_height = 28  
  15.     image_width = 28  
  16.     image_channel = 1  
  17.    
  18.     batch_size = 128  
  19.     n_batches = data.train.num_examples // batch_size  
  20. else:  
  21.     picture_dir = 'little_girls'  
  22.     picture_list = []  
  23.     # 建议不要把图片一次加载到内存,为了节省内存,最好边加载边使用  
  24.     for (dirpath, dirnames, filenames) in os.walk(picture_dir):  
  25.         for filename in filenames:  
  26.             if filename.endswith('.jpg'):  
  27.                 picture_list.append(os.sep.join([dirpath, filename]))  
  28.    
  29.     print("图像总数: ", len(picture_list))  
  30.    
  31.     # 图像大小和Channel  
  32.     image_height = 64  
  33.     image_width = 64  
  34.     image_channel = 3  
  35.    
  36.     # 每次使用多少样本训练  
  37.     batch_size = 128  
  38.     n_batches = len(picture_list) // batch_size  
  39.    
  40.     #图片格式对应输入X  
  41.     img_data = []  
  42.     for img_file in picture_list:  
  43.         img_data.append(cv2.imread(img_file))  
  44.     img_data = np.array(img_data)  
  45.     img_data = img_data / 255.0  
  46.     #print(img_data.shape)   # (44112, 64, 64, 3)  
  47.    
  48.    
  49. X = tf.placeholder(tf.float32, shape=[None, image_height, image_width, image_channel])  
  50.    
  51. def gated_cnn(W_shape_, fan_in, gated=True, payload=None, mask=None, activation=True):  
  52.     W_shape = [W_shape_[0], W_shape_[1], fan_in.get_shape()[-1], W_shape_[2]]  
  53.     b_shape = W_shape_[2]  
  54.    
  55.     def get_weights(shape, name, mask=None):  
  56.         weights_initializer = tf.contrib.layers.xavier_initializer()  
  57.         W = tf.get_variable(name, shape, tf.float32, weights_initializer)  
  58.    
  59.         if mask:  
  60.             filter_mid_x = shape[0]//2  
  61.             filter_mid_y = shape[1]//2  
  62.             mask_filter = np.ones(shape, dtype=np.float32)  
  63.             mask_filter[filter_mid_x, filter_mid_y+1:, :, :] = 0.  
  64.             mask_filter[filter_mid_x+1:, :, :, :] = 0.  
  65.    
  66.             if mask == 'a':  
  67.                 mask_filter[filter_mid_x, filter_mid_y, :, :] = 0.  
  68.    
  69.             W *= mask_filter   
  70.         return W  
  71.    
  72.     if gated:  
  73.         W_f = get_weights(W_shape, "v_W", mask=mask)  
  74.         W_g = get_weights(W_shape, "h_W", mask=mask)  
  75.    
  76.         b_f = tf.get_variable("v_b", b_shape, tf.float32, tf.zeros_initializer)  
  77.         b_g = tf.get_variable("h_b", b_shape, tf.float32, tf.zeros_initializer)  
  78.    
  79.         conv_f = tf.nn.conv2d(fan_in, W_f, strides=[1,1,1,1], padding='SAME')  
  80.         conv_g = tf.nn.conv2d(fan_in, W_g, strides=[1,1,1,1], padding='SAME')  
  81.         if payload is not None:  
  82.             conv_f += payload  
  83.             conv_g += payload  
  84.    
  85.         fan_out = tf.mul(tf.tanh(conv_f + b_f), tf.sigmoid(conv_g + b_g))  
  86.     else:  
  87.         W = get_weights(W_shape, "W", mask=mask)  
  88.         b = tf.get_variable("b", b_shape, tf.float32, tf.zeros_initializer)  
  89.         conv = tf.nn.conv2d(fan_in, W, strides=[1,1,1,1], padding='SAME')  
  90.         if activation:   
  91.             fan_out = tf.nn.relu(tf.add(conv, b))  
  92.         else:  
  93.             fan_out = tf.add(conv, b)  
  94.    
  95.     return fan_out  
  96.    
  97. def pixel_cnn(layers=12, f_map=32):  
  98.     v_stack_in, h_stack_in = X, X  
  99.    
  100.     for i in range(layers):  
  101.         filter_size = 3 if i > 0 else 7  
  102.         mask = 'b' if i > 0 else 'a'  
  103.         residual = True if i > 0 else False  
  104.         i = str(i)  
  105.    
  106.         with tf.variable_scope("v_stack"+i):  
  107.             v_stack = gated_cnn([filter_size, filter_size, f_map], v_stack_in, mask=mask)  
  108.             v_stack_in = v_stack  
  109.    
  110.         with tf.variable_scope("v_stack_1"+i):  
  111.             v_stack_1 = gated_cnn([11, f_map], v_stack_in, gated=False, mask=mask)  
  112.    
  113.         with tf.variable_scope("h_stack"+i):  
  114.             h_stack = gated_cnn([1, filter_size, f_map], h_stack_in, payload=v_stack_1, mask=mask)  
  115.    
  116.         with tf.variable_scope("h_stack_1"+i):  
  117.             h_stack_1 = gated_cnn([11, f_map], h_stack, gated=False, mask=mask)  
  118.             if residual:  
  119.                 h_stack_1 += h_stack_in  
  120.             h_stack_in = h_stack_1  
  121.    
  122.     with tf.variable_scope("fc_1"):  
  123.         fc1 = gated_cnn([11, f_map], h_stack_in, gated=False, mask='b')  
  124.    
  125.     color = 256  
  126.     with tf.variable_scope("fc_2"):  
  127.         fc2 = gated_cnn([11, image_channel * color], fc1, gated=False, mask='b', activation=False)  
  128.         fc2 = tf.reshape(fc2, (-1, color))  
  129.    
  130.         return fc2  
  131.    
  132. def train_pixel_cnn():  
  133.     output = pixel_cnn()  
  134.    
  135.     loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(output, tf.cast(tf.reshape(X, [-1]), dtype=tf.int32)))  
  136.     trainer = tf.train.RMSPropOptimizer(1e-3)  
  137.     gradients = trainer.compute_gradients(loss)  
  138.     clipped_gradients = [(tf.clip_by_value(_[0], -11), _[1]) for _ in gradients]  
  139.     optimizer = trainer.apply_gradients(clipped_gradients)  
  140.    
  141.     with tf.Session() as sess:  
  142.         sess.run(tf.initialize_all_variables())  
  143.    
  144.         saver = tf.train.Saver(tf.trainable_variables())  
  145.    
  146.         for epoch in range(50):  
  147.             for batch in range(n_batches):  
  148.    
  149.                 if MNIST == True:  
  150.                     batch_X, _ = data.train.next_batch(batch_size)  
  151.                     batch_X = batch_X.reshape([batch_size, image_height, image_width, image_channel])  
  152.                 else:  
  153.                     batch_X = img_data[batch_size * batch : batch_size * (batch + 1)]  
  154.    
  155.                 _, cost = sess.run([optimizer, loss], feed_dict={X:batch_X})  
  156.                 print("epoch:", epoch, '  batch:', batch,'  cost:', cost)  
  157.             if epoch % 7 == 0:  
  158.                 saver.save(sess, "girl.ckpt", global_step=epoch)  
  159.    
  160. # 训练  
  161. train_pixel_cnn()  
  162.    
  163. def generate_girl():  
  164.     output = pixel_cnn()  
  165.    
  166.     predict = tf.reshape(tf.multinomial(tf.nn.softmax(output), num_samples=1, seed=100), tf.shape(X))  
  167.     #predict_argmax = tf.reshape(tf.argmax(tf.nn.softmax(output), dimension=tf.rank(output) - 1), tf.shape(X))  
  168.    
  169.     with tf.Session() as sess:   
  170.         sess.run(tf.initialize_all_variables())  
  171.    
  172.         saver = tf.train.Saver(tf.trainable_variables())  
  173.         saver.restore(sess, 'girl.ckpt-49')  
  174.    
  175.         pics = np.zeros((1*1, image_height, image_width, image_channel), dtype=np.float32)  
  176.    
  177.         for i in range(image_height):  
  178.             for j in range(image_width):  
  179.                 for k in range(image_channel):  
  180.                     next_pic = sess.run(predict, feed_dict={X:pics})  
  181.                     pics[:, i, j, k] = next_pic[:, i, j, k]  
  182.    
  183.         cv2.imwrite('girl.jpg', pics[0])  
  184.         print('生成妹子图: girl.jpg')  
  185.    
  186. # 生成图像  
  187. generate_girl()  

额,妹子图正在训练中…

补充练习:使用OpenCV提取图像中的脸,然后使用上面模型进行训练,看看能生成什么。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值