@zoom
from skimage import transform
import numpy as np
import re
PREFIX = 'zoom'
REGEX = re.compile(r"^" + PREFIX + "_(?P<p1x>[-0-9]+)_(?P<p1y>[-0-9]+)_(?P<p2x>[-0-9]+)_(?P<p2y>[-0-9]+)")
# Value written into the parts of the zoom rectangle that lie outside the image.
PAD_VALUE = 0


class Zoom:
    """Crop the rectangle (p1x, p1y)-(p2x, p2y) and scale it back to the
    size of the input image.

    Coordinates may lie outside the image; the missing area is filled with
    PAD_VALUE before resizing.
    """

    def __init__(self, p1x, p1y, p2x, p2y):
        self.p1x = p1x
        self.p1y = p1y
        self.p2x = p2x
        self.p2y = p2y
        # Note: unlike the codes accepted by match_code, there is no
        # underscore between the prefix and the first number.
        self.code = PREFIX + str(p1x) + '_' + str(p1y) + '_' + str(p2x) + '_' + str(p2y)

    def process(self, img):
        """Return the zoomed copy of `img` (an array with a `.shape`).

        The result has the same height and width as `img`.
        """
        h, w = img.shape[:2]

        # Part of the requested rectangle that actually overlaps the image.
        crop_p1x = max(self.p1x, 0)
        crop_p1y = max(self.p1y, 0)
        crop_p2x = min(self.p2x, w)
        crop_p2y = min(self.p2y, h)
        cropped_img = img[crop_p1y:crop_p2y, crop_p1x:crop_p2x]

        # How far the requested rectangle sticks out on each side.
        x_pad_before = -min(0, self.p1x)
        x_pad_after = max(0, self.p2x - w)
        y_pad_before = -min(0, self.p1y)
        y_pad_after = max(0, self.p2y - h)

        padding = [(y_pad_before, y_pad_after), (x_pad_before, x_pad_after)]
        if len(img.shape) == 3:
            padding.append((0, 0))  # colour images have an extra dimension

        padded_img = np.pad(cropped_img, padding, 'constant',
                            constant_values=PAD_VALUE)
        return transform.resize(padded_img, (h, w))

    @staticmethod
    def match_code(code):
        """Return a Zoom for codes like 'zoom_0_0_20_20', else None.

        None is also returned when a coordinate matched the pattern but is
        not a valid integer (e.g. '1-2'), instead of raising ValueError.
        """
        match = REGEX.match(code)
        if not match:
            return None
        try:
            coords = [int(match.group(name))
                      for name in ('p1x', 'p1y', 'p2x', 'p2y')]
        except ValueError:
            return None
        return Zoom(*coords)
@translate
from skimage.transform import AffineTransform
from skimage import transform as tf
import re
CODE = 'trans'
REGEX = re.compile(r"^" + CODE + "_(?P<x_trans>[-0-9]+)_(?P<y_trans>[-0-9]+)")


class Translate:
    """Shift an image by a fixed (x_trans, y_trans) offset."""

    def __init__(self, x_trans, y_trans):
        # Note: unlike the codes accepted by match_code, there is no
        # underscore between the prefix and the first number.
        self.code = CODE + str(x_trans) + '_' + str(y_trans)
        self.x_trans = x_trans
        self.y_trans = y_trans

    def process(self, img):
        """Return `img` warped by an affine translation.

        The offsets are negated before being handed to tf.warp --
        presumably because warp expects the inverse mapping, so the content
        moves by (+x_trans, +y_trans); confirm against the skimage docs.
        """
        shift = AffineTransform(translation=(-self.x_trans, -self.y_trans))
        return tf.warp(img, shift)

    @staticmethod
    def match_code(code):
        """Return a Translate for codes like 'trans_20_-10', else None.

        None is also returned when an offset matched the pattern but is not
        a valid integer (e.g. '2-3'), instead of raising ValueError.
        """
        match = REGEX.match(code)
        if not match:
            return None
        try:
            x_trans = int(match.group('x_trans'))
            y_trans = int(match.group('y_trans'))
        except ValueError:
            return None
        return Translate(x_trans, y_trans)
@rotate
from skimage import transform
import re
PREFIX = 'rot'
REGEX = re.compile(r"^" + PREFIX + "_(?P<angle>-?[0-9]+)")


class Rotate:
    """Rotate an image by a whole number of degrees."""

    def __init__(self, angle):
        self.angle = angle
        # Prefix immediately followed by the angle, e.g. 'rot90'.
        self.code = PREFIX + str(angle)

    def process(self, img):
        """Return `img` rotated via transform.rotate by the negated angle.

        NOTE(review): the negation presumably makes positive angles turn
        clockwise -- confirm against the skimage rotate documentation.
        """
        return transform.rotate(img, -self.angle)

    @staticmethod
    def match_code(code):
        """Return a Rotate for codes like 'rot_90' or 'rot_-45', else None."""
        found = REGEX.match(code)
        if not found:
            return None
        return Rotate(int(found.group('angle')))
@noise
from skimage.util import random_noise
import re
CODE = 'noise'
REGEX = re.compile(r"^" + CODE + "_(?P<var>[.0-9]+)")


class Noise:
    """Add gaussian noise of a given variance to an image."""

    def __init__(self, var):
        # Prefix immediately followed by the variance, e.g. 'noise0.01'.
        self.code = CODE + str(var)
        self.var = var

    def process(self, img):
        """Return `img` with gaussian noise of variance `self.var` applied."""
        return random_noise(img, mode='gaussian', var=self.var)

    @staticmethod
    def match_code(code):
        """Return a Noise for codes like 'noise_0.01', else None.

        None is also returned when the variance matched the pattern but is
        not a valid float (e.g. '1.2.3'), instead of raising ValueError.
        """
        match = REGEX.match(code)
        if not match:
            return None
        try:
            var = float(match.group('var'))
        except ValueError:
            return None
        return Noise(var)
@blur
from skimage.filters import gaussian
from skimage.exposure import rescale_intensity
import re
CODE = 'blur'
REGEX = re.compile(r"^" + CODE + "_(?P<sigma>[.0-9]+)")


class Blur:
    """Blur an image with a gaussian filter of a given sigma."""

    def __init__(self, sigma):
        # Prefix immediately followed by sigma, e.g. 'blur1.5'.
        self.code = CODE + str(sigma)
        self.sigma = sigma

    def process(self, img):
        """Return `img` gaussian-filtered and then intensity-rescaled."""
        # A third array dimension is treated as the colour channels.
        is_colour = len(img.shape) == 3
        # NOTE(review): newer scikit-image releases replace `multichannel`
        # with `channel_axis` -- confirm against the installed version.
        blurred = gaussian(img, sigma=self.sigma, multichannel=is_colour)
        return rescale_intensity(blurred)

    @staticmethod
    def match_code(code):
        """Return a Blur for codes like 'blur_1.5', else None.

        None is also returned when sigma matched the pattern but is not a
        valid float (e.g. '1.2.3'), instead of raising ValueError.
        """
        match = REGEX.match(code)
        if not match:
            return None
        try:
            sigma = float(match.group('sigma'))
        except ValueError:
            return None
        return Blur(sigma)
@fliph
import numpy as np
CODE = 'fliph'


class FlipH:
    """Mirror an image left-to-right."""

    def __init__(self):
        self.code = CODE

    def process(self, img):
        """Return `img` flipped with np.fliplr."""
        return np.fliplr(img)

    @staticmethod
    def match_code(code):
        """Return a FlipH when `code` is exactly 'fliph', else None."""
        return FlipH() if code == CODE else None
@flipv
import numpy as np
CODE = 'flipv'


class FlipV:
    """Mirror an image top-to-bottom."""

    def __init__(self):
        self.code = CODE

    def process(self, img):
        """Return `img` flipped with np.flipud."""
        return np.flipud(img)

    @staticmethod
    def match_code(code):
        """Return a FlipV when `code` is exactly 'flipv', else None."""
        return FlipV() if code == CODE else None
@counter
from multiprocessing.dummy import Lock
class Counter:
    """Lock-protected tallies of what happened to each file.

    Every increment and the snapshot in get() run while holding self.lock.
    """

    def __init__(self):
        self.lock = Lock()
        self._processed = 0
        self._error = 0
        self._skipped_no_match = 0
        self._skipped_augmented = 0

    def _bump(self, attr):
        # Add one to the named tally while holding the lock.
        with self.lock:
            setattr(self, attr, getattr(self, attr) + 1)

    def processed(self):
        """Record one processed file."""
        self._bump('_processed')

    def error(self):
        """Record one failure."""
        self._bump('_error')

    def skipped_no_match(self):
        """Record one file skipped as a non-match."""
        self._bump('_skipped_no_match')

    def skipped_augmented(self):
        """Record one file skipped as already augmented."""
        self._bump('_skipped_augmented')

    def get(self):
        """Return a dict snapshot of all four tallies."""
        with self.lock:
            snapshot = {
                'processed': self._processed,
                'error': self._error,
                'skipped_no_match': self._skipped_no_match,
                'skipped_augmented': self._skipped_augmented,
            }
        return snapshot
@main.py
import sys, os, re, traceback
from os.path import isfile
from multiprocessing.dummy import Pool, cpu_count
from counter import Counter
from ops.rotate import Rotate
from ops.fliph import FlipH
from ops.flipv import FlipV
from ops.zoom import Zoom
from ops.blur import Blur
from ops.noise import Noise
from ops.translate import Translate
from skimage.io import imread, imsave
# File-name extensions (matched case-sensitively) treated as images.
EXTENSIONS = ['png', 'jpg', 'jpeg', 'bmp']

# Leave one CPU free, but always run at least one worker.
WORKER_COUNT = max(cpu_count() - 1, 1)

# Each operation code is tried against these classes in this order.
OPERATIONS = [Rotate, FlipH, FlipV, Translate, Noise, Zoom, Blur]

# Augmented files will have names matching the regex below, eg
#     original__rot90__crop1__flipv.jpg
AUGMENTED_FILE_REGEX = re.compile('^.*(__.+)+\\.[^\\.]+$')

# One alternative per entry in EXTENSIONS: any name ending in '.<ext>'.
EXTENSION_REGEX = re.compile('|'.join('.*\\.' + ext + '$' for ext in EXTENSIONS))

# Both are assigned in the __main__ section before any work is queued.
thread_pool = None
counter = None
def build_augmented_file_name(original_name, ops):
    """Return `original_name` with '__<op.code>' inserted before the
    extension for every operation in `ops`, in order.
    """
    stem, extension = os.path.splitext(original_name)
    suffix = ''.join('__' + op.code for op in ops)
    return stem + suffix + extension
def work(d, f, op_lists):
    """Write one augmented copy of image `f` in directory `d` per op list.

    An output that already exists on disk is skipped. The source image is
    re-read for every op list so each chain starts from the original.
    After all op lists succeed the file is counted as processed. Any
    exception is printed to stdout and counted as an error, so the final
    summary reflects failures; nothing is re-raised.
    """
    try:
        in_path = os.path.join(d, f)
        for op_list in op_lists:
            out_path = os.path.join(d, build_augmented_file_name(f, op_list))
            if isfile(out_path):
                continue
            img = imread(in_path)
            for op in op_list:
                img = op.process(img)
            imsave(out_path, img)
        counter.processed()
    except Exception:
        # Print first so the traceback survives even if counting fails.
        traceback.print_exc(file=sys.stdout)
        counter.error()
def process(dir, file, op_lists):
    """Queue work(dir, file, op_lists) on the shared thread pool.

    The async result object is discarded.
    """
    job_args = (dir, file, op_lists)
    thread_pool.apply_async(work, job_args)
# Command-line entry point:
#   <script> <image directory> <operation> (<operation> ...)
# Each <operation> argument is a comma-separated chain of operation codes;
# one augmented file is produced per argument for every eligible image.
# print is called with a single parenthesised argument so the output is the
# same whether it is parsed as a Python 2 statement or a Python 3 call.
if __name__ == '__main__':
    if len(sys.argv) < 3:
        print('Usage: {} <image directory> <operation> (<operation> ...)'.format(sys.argv[0]))
        sys.exit(1)

    image_dir = sys.argv[1]
    if not os.path.isdir(image_dir):
        print('Invalid image directory: {}'.format(image_dir))
        sys.exit(2)

    # Turn every command-line argument into a list of operation instances.
    op_codes = sys.argv[2:]
    op_lists = []
    for op_code_list in op_codes:
        op_list = []
        for op_code in op_code_list.split(','):
            op = None
            # The first class that recognises the code wins.
            for op_class in OPERATIONS:
                op = op_class.match_code(op_code)
                if op:
                    op_list.append(op)
                    break

            if not op:
                print('Unknown operation {}'.format(op_code))
                sys.exit(3)
        op_lists.append(op_list)

    counter = Counter()
    thread_pool = Pool(WORKER_COUNT)
    print('Thread pool initialised with {} worker{}'.format(WORKER_COUNT, '' if WORKER_COUNT == 1 else 's'))

    for dir_name, _, file_names in os.walk(image_dir):
        print('Processing {}...'.format(dir_name))

        for file_name in file_names:
            if EXTENSION_REGEX.match(file_name):
                if AUGMENTED_FILE_REGEX.match(file_name):
                    # Never augment a file that is itself an augmentation.
                    counter.skipped_augmented()
                else:
                    process(dir_name, file_name, op_lists)
            else:
                counter.skipped_no_match()

    print("Waiting for workers to complete...")
    thread_pool.close()
    thread_pool.join()

    print(counter.get())
@.gitignore
*.pyc
*.iml
.idea
notes.txt