学习谷歌的tensorflow一直为没有足够的样本而苦恼。
最近了解到，可以通过把图像旋转（倾斜）不同的角度来生成新的样本，从而扩充样本库。
该方法可以缓解学习深度学习时没有足够多样本的苦恼。
############################################################################################
#!/usr/bin/python2.7
# -*- coding: utf-8 -*-
#Author : zhaoqinghui
#Date : 2016.5.11
#Function: add image
##########################################################################################
import tensorflow as tf
import numpy as np
import math
import cv2
import sys
import os
from scipy import ndimage
import random
###########################################################################################
# User-adjustable settings.
###########################################################################################
# Index file opened by read_traing_list(); the path is relative to the
# current working directory.
training_index = "./traini.txt"
# Index file that distort_image() writes for the images it generates.
newlabel_index = "./newlabel.txt"
# NOTE(review): the two values below are not referenced by any active code
# visible here. Presumably the number of classes and a cap on the number of
# images -- confirm before relying on them.
classnum = 36
maxImageNum = 360
#############################################################################################
def read_traing_list(index_path=None):
    """Read the training index file and return image paths and labels.

    Every non-blank line of the index file must hold an image path
    followed by its label, separated by whitespace, for example
    ``./img/a.png 3``. Fields after the second one are ignored.

    Args:
        index_path: Path of the index file to read. When omitted, the
            module-level ``training_index`` setting is used.

    Returns:
        A tuple ``(train_image_dir, train_label_dir)`` of two parallel
        lists of strings: the image paths and the matching labels.

    Raises:
        ValueError: If a non-blank line has fewer than two fields.
    """
    if index_path is None:
        index_path = training_index

    train_image_dir = []
    train_label_dir = []
    # The with-block closes the file even when a line fails to parse.
    with open(index_path) as reader:
        for line_no, line in enumerate(reader, 1):
            fields = line.split()
            if not fields:
                # Blank lines carry no sample; skip them instead of crashing.
                continue
            if len(fields) < 2:
                raise ValueError(
                    "%s:%d: expected '<image path> <label>', got %r"
                    % (index_path, line_no, line)
                )
            train_image_dir.append(fields[0])
            # split() already removed the line terminator, so the label is
            # kept whole even when the last line has no trailing newline.
            train_label_dir.append(fields[1])
    return train_image_dir, train_label_dir
def _save_rotated(image, min_angle, max_angle, out_path):
    """Rotate *image* by a random whole angle, resize to 28x28 and save it."""
    rotated = ndimage.rotate(image, random.randint(min_angle, max_angle))
    cv2.imwrite(out_path, cv2.resize(rotated, (28, 28)))


def distort_image():
    """Write two randomly rotated 28x28 copies of every training image.

    Each image listed by ``read_traing_list()`` is loaded as grayscale and
    rotated twice: once by a random angle in [7, 30] degrees and once by a
    random angle in [330, 355] degrees (the same as tilting 5 to 30 degrees
    the opposite way). The results are resized to 28x28 and saved next to
    the original as ``<name>_1.png`` and ``<name>_2.png``. For every pair,
    two ``<path> <label>`` lines are written to the file named by the
    module-level ``newlabel_index`` setting.

    Images that OpenCV cannot load are reported on stderr and skipped.
    """
    train_image_dir, train_label_dir = read_traing_list()
    # The with-block closes the label file even if an image fails midway.
    with open(newlabel_index, "w") as label_writer:
        for image_path, label in zip(train_image_dir, train_label_dir):
            image_path = str(image_path)
            image_tmp = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
            if image_tmp is None:
                # cv2.imread signals a missing/unreadable file with None
                # rather than raising; rotating None would crash.
                sys.stderr.write("skip unreadable image: %s\n" % image_path)
                continue

            # splitext copes with extensions of any length (".jpeg", none).
            base_path = os.path.splitext(image_path)[0]
            rotate_image_path = base_path + "_1.png"
            rotate_image_path2 = base_path + "_2.png"

            print(rotate_image_path)
            _save_rotated(image_tmp, 7, 30, rotate_image_path)
            _save_rotated(image_tmp, 330, 355, rotate_image_path2)

            label_writer.write(rotate_image_path + " " + str(label) + "\n")
            label_writer.write(rotate_image_path2 + " " + str(label) + "\n")
    print("done")
# Run the augmentation only when executed as a script, not when imported.
if __name__ == "__main__":
    distort_image()