Image augmentation: zoom, blur, and other operations

https://github.com/codebox/image_augmentor/tree/master/ops

@zoom

from skimage import transform
import numpy as np
import re

PREFIX = 'zoom'
# Operation code format: zoom_<p1x>_<p1y>_<p2x>_<p2y>
# Each coordinate is an optionally negative integer. The pattern is anchored at
# both ends so that malformed codes (e.g. "zoom_1-2_3_4_5" or trailing junk)
# are rejected here instead of raising ValueError later in int().
REGEX = re.compile(
    r"^" + PREFIX +
    r"_(?P<p1x>-?[0-9]+)_(?P<p1y>-?[0-9]+)_(?P<p2x>-?[0-9]+)_(?P<p2y>-?[0-9]+)$"
)
# Value used to fill the part of the zoom window that lies outside the image.
PAD_VALUE = 0


class Zoom:
    """Crop an image to a rectangular window and resize it back to its original size.

    The window is given by two corners in pixel coordinates: (p1x, p1y) is the
    corner with the smaller coordinates and (p2x, p2y) the one with the larger
    coordinates. The window may extend beyond the image; the area outside the
    image is filled with PAD_VALUE.
    """

    def __init__(self, p1x, p1y, p2x, p2y):
        self.p1x = p1x
        self.p1y = p1y
        self.p2x = p2x
        self.p2y = p2y
        # Short identifier of this operation (prefix followed by the coordinates).
        self.code = PREFIX + '_'.join(str(v) for v in (p1x, p1y, p2x, p2y))

    def process(self, img):
        """Return ``img`` cropped/padded to the zoom window and resized to the input height and width.

        ``img`` must be a numpy array with rows as the first axis and columns
        as the second; a third axis, if present, is left unpadded.
        """
        h, w = img.shape[:2]

        # Part of the window that actually overlaps the image.
        crop_p1x = max(self.p1x, 0)
        crop_p1y = max(self.p1y, 0)
        crop_p2x = min(self.p2x, w)
        crop_p2y = min(self.p2y, h)
        cropped_img = img[crop_p1y:crop_p2y, crop_p1x:crop_p2x]

        # Amount by which the window sticks out of the image on each side.
        x_pad_before = -min(0, self.p1x)
        x_pad_after = max(0, self.p2x - w)
        y_pad_before = -min(0, self.p1y)
        y_pad_after = max(0, self.p2y - h)

        padding = [(y_pad_before, y_pad_after), (x_pad_before, x_pad_after)]
        if img.ndim == 3:
            padding.append((0, 0))  # colour images have an extra dimension

        padded_img = np.pad(cropped_img, padding, 'constant', constant_values=PAD_VALUE)
        return transform.resize(padded_img, (h, w))

    @staticmethod
    def match_code(code):
        """Return a Zoom parsed from ``code``, or None if ``code`` is not a zoom code."""
        match = REGEX.match(code)
        if not match:
            return None
        d = match.groupdict()
        return Zoom(int(d['p1x']), int(d['p1y']), int(d['p2x']), int(d['p2y']))

@translate

from skimage.transform import AffineTransform
from skimage import transform as tf
import re

CODE = 'trans'
# Operation code format: trans_<x>_<y>, where both offsets are optionally
# negative integers. The pattern is anchored at both ends so that malformed
# codes (e.g. "trans_1-0_5") are rejected instead of raising ValueError in int().
REGEX = re.compile(r"^" + CODE + r"_(?P<x_trans>-?[0-9]+)_(?P<y_trans>-?[0-9]+)$")


class Translate:
    """Shift an image by a fixed number of pixels along x and y."""

    def __init__(self, x_trans, y_trans):
        # Short identifier of this operation (prefix followed by the offsets).
        self.code = CODE + str(x_trans) + '_' + str(y_trans)
        self.x_trans = x_trans
        self.y_trans = y_trans

    def process(self, img):
        """Return ``img`` warped by an affine translation.

        The offsets are negated before being handed to ``warp``.
        NOTE(review): scikit-image documents the transform passed to ``warp``
        as an inverse (output-to-input) map, which is presumably why the
        signs are flipped - confirm against the installed version.
        """
        shift = AffineTransform(translation=(-self.x_trans, -self.y_trans))
        return tf.warp(img, shift)

    @staticmethod
    def match_code(code):
        """Return a Translate parsed from ``code``, or None if ``code`` is not a translate code."""
        match = REGEX.match(code)
        if not match:
            return None
        d = match.groupdict()
        return Translate(int(d['x_trans']), int(d['y_trans']))

@rotate

from skimage import transform
import re

PREFIX = 'rot'
# Operation code format: rot_<angle>, where angle is an optionally negative
# integer. The pattern is anchored at the end so that codes with trailing
# characters (e.g. "rot_45.5") are rejected rather than silently truncated.
REGEX = re.compile(r"^" + PREFIX + r"_(?P<angle>-?[0-9]+)$")


class Rotate:
    """Rotate an image by a whole number of degrees."""

    def __init__(self, angle):
        self.angle = angle
        # Short identifier of this operation (prefix followed by the angle).
        self.code = PREFIX + str(angle)

    def process(self, img):
        """Return ``img`` rotated; the negated angle is passed to ``transform.rotate``."""
        return transform.rotate(img, -self.angle)

    @staticmethod
    def match_code(code):
        """Return a Rotate parsed from ``code``, or None if ``code`` is not a rotate code."""
        match = REGEX.match(code)
        if not match:
            return None
        return Rotate(int(match.group('angle')))

@noise

from skimage.util import random_noise
import re

CODE = 'noise'
# Operation code format: noise_<var>, where var is a non-negative decimal
# number such as "5", "0.01", "5." or ".5". The pattern only admits strings
# that float() can parse and is anchored at the end, so inputs like
# "noise_1.2.3" or "noise_." are rejected instead of raising ValueError.
REGEX = re.compile(r"^" + CODE + r"_(?P<var>[0-9]+(?:\.[0-9]*)?|\.[0-9]+)$")


class Noise:
    """Add Gaussian noise with a given variance to an image."""

    def __init__(self, var):
        # Short identifier of this operation (prefix followed by the variance).
        self.code = CODE + str(var)
        self.var = var

    def process(self, img):
        """Return ``img`` with Gaussian noise of variance ``self.var`` applied."""
        return random_noise(img, mode='gaussian', var=self.var)

    @staticmethod
    def match_code(code):
        """Return a Noise parsed from ``code``, or None if ``code`` is not a noise code."""
        match = REGEX.match(code)
        if not match:
            return None
        return Noise(float(match.group('var')))

@blur

from skimage.filters import gaussian
from skimage.exposure import rescale_intensity
import re

CODE = 'blur'
# Operation code format: blur_<sigma>, where sigma is a non-negative decimal
# number such as "2", "1.5", "2." or ".5". The pattern only admits strings
# that float() can parse and is anchored at the end, so inputs like
# "blur_1.2.3" or "blur_." are rejected instead of raising ValueError.
REGEX = re.compile(r"^" + CODE + r"_(?P<sigma>[0-9]+(?:\.[0-9]*)?|\.[0-9]+)$")


class Blur:
    """Apply a Gaussian blur with a given sigma to an image."""

    def __init__(self, sigma):
        # Short identifier of this operation (prefix followed by sigma).
        self.code = CODE + str(sigma)
        self.sigma = sigma

    def process(self, img):
        """Return ``img`` blurred with a Gaussian filter and intensity-rescaled."""
        # A 3-dimensional array is treated as a colour image.
        # NOTE(review): newer scikit-image releases appear to replace the
        # `multichannel` keyword with `channel_axis` - confirm against the
        # installed version before upgrading.
        is_colour = len(img.shape) == 3
        blurred = gaussian(img, sigma=self.sigma, multichannel=is_colour)
        return rescale_intensity(blurred)

    @staticmethod
    def match_code(code):
        """Return a Blur parsed from ``code``, or None if ``code`` is not a blur code."""
        match = REGEX.match(code)
        if not match:
            return None
        return Blur(float(match.group('sigma')))

@fliph

import numpy as np

# Operation code that selects the horizontal flip.
CODE = 'fliph'


class FlipH:
    """Mirror an image left-to-right."""

    def __init__(self):
        # Short identifier of this operation.
        self.code = CODE

    def process(self, img):
        """Return ``img`` with the order of its columns reversed."""
        flipped = np.fliplr(img)
        return flipped

    @staticmethod
    def match_code(code):
        """Return a FlipH when ``code`` equals the flip code exactly, otherwise None."""
        return FlipH() if code == CODE else None

@flipv

import numpy as np

# Operation code that selects the vertical flip.
CODE = 'flipv'


class FlipV:
    """Mirror an image top-to-bottom."""

    def __init__(self):
        # Short identifier of this operation.
        self.code = CODE

    def process(self, img):
        """Return ``img`` with the order of its rows reversed."""
        flipped = np.flipud(img)
        return flipped

    @staticmethod
    def match_code(code):
        """Return a FlipV when ``code`` equals the flip code exactly, otherwise None."""
        return FlipV() if code == CODE else None

@counter

from multiprocessing.dummy import Lock

class Counter:
    """Lock-protected tallies of processed, failed and skipped files."""

    def __init__(self):
        self.lock = Lock()
        self._processed = 0
        self._error = 0
        self._skipped_no_match = 0
        self._skipped_augmented = 0

    def _bump(self, attr_name):
        # Add one to the named tally while holding the lock.
        with self.lock:
            setattr(self, attr_name, getattr(self, attr_name) + 1)

    def processed(self):
        """Record one processed file."""
        self._bump('_processed')

    def error(self):
        """Record one error."""
        self._bump('_error')

    def skipped_no_match(self):
        """Record one file skipped because it did not match."""
        self._bump('_skipped_no_match')

    def skipped_augmented(self):
        """Record one file skipped because it is already augmented."""
        self._bump('_skipped_augmented')

    def get(self):
        """Return a dict snapshot of all four tallies, taken under the lock."""
        with self.lock:
            snapshot = {
                'processed': self._processed,
                'error': self._error,
                'skipped_no_match': self._skipped_no_match,
                'skipped_augmented': self._skipped_augmented,
            }
        return snapshot

@main.py

import sys, os, re, traceback
from os.path import isfile
from multiprocessing.dummy import Pool, cpu_count
from counter import Counter
from ops.rotate import Rotate
from ops.fliph import FlipH
from ops.flipv import FlipV
from ops.zoom import Zoom
from ops.blur import Blur
from ops.noise import Noise
from ops.translate import Translate
from skimage.io import imread, imsave

# File name extensions (matched case-sensitively) that are treated as images.
EXTENSIONS = ['png', 'jpg', 'jpeg', 'bmp']
# Leave one CPU free for the main thread, but always run at least one worker.
WORKER_COUNT = max(1, cpu_count() - 1)
# Operation classes; each is asked in this order to parse an operation code.
OPERATIONS = [Rotate, FlipH, FlipV, Translate, Noise, Zoom, Blur]

'''
Augmented files will have names matching the regex below, eg

    original__rot90__crop1__flipv.jpg

'''
AUGMENTED_FILE_REGEX = re.compile('^.*(__.+)+\\.[^\\.]+$')
# One alternative per supported extension; matches names ending in ".<ext>".
EXTENSION_REGEX = re.compile('|'.join('.*\\.' + ext + '$' for ext in EXTENSIONS))

# Shared state; both are assigned by the script entry point before any work is queued.
thread_pool = None
counter = None

def build_augmented_file_name(original_name, ops):
    """Return ``original_name`` with ``__<op.code>`` inserted before the extension for each op.

    For example, a file ``cat.jpg`` with ops whose codes are ``rot90`` and
    ``flipv`` becomes ``cat__rot90__flipv.jpg``.
    """
    stem, extension = os.path.splitext(original_name)
    suffix = ''.join('__' + op.code for op in ops)
    return stem + suffix + extension

def work(d, f, op_lists):
    """Apply every operation list to image ``f`` in directory ``d`` and save the results.

    For each list in ``op_lists`` the output file name is derived from ``f``
    and the operation codes; outputs that already exist are left untouched.
    Otherwise the source image is read, each operation is applied in order
    and the result is written next to the source file.

    This runs on a worker thread, so failures are not propagated: the
    traceback is printed to stdout and the failure is recorded on the shared
    counter. A file that completes without error is recorded as processed.
    """
    try:
        in_path = os.path.join(d, f)
        for op_list in op_lists:
            out_path = os.path.join(d, build_augmented_file_name(f, op_list))
            if isfile(out_path):
                continue
            img = imread(in_path)
            for op in op_list:
                img = op.process(img)
            imsave(out_path, img)
    except Exception:
        # Print first so the traceback is not lost if counting itself fails.
        traceback.print_exc(file=sys.stdout)
        counter.error()
    else:
        counter.processed()

def process(dir, file, op_lists):
    """Queue ``work(dir, file, op_lists)`` on the shared thread pool without waiting for it."""
    job_args = (dir, file, op_lists)
    thread_pool.apply_async(work, job_args)

if __name__ == '__main__':
    # Usage: <script> <image directory> <operation> (<operation> ...)
    # Each <operation> argument is a comma-separated list of operation codes
    # that are applied in sequence to produce one augmented output per image.
    # print(...) with a single argument behaves the same on Python 2 and 3.
    if len(sys.argv) < 3:
        print('Usage: {} <image directory> <operation> (<operation> ...)'.format(sys.argv[0]))
        sys.exit(1)

    image_dir = sys.argv[1]
    if not os.path.isdir(image_dir):
        print('Invalid image directory: {}'.format(image_dir))
        sys.exit(2)

    # Parse every command-line argument into a list of operation instances.
    op_codes = sys.argv[2:]
    op_lists = []
    for op_code_list in op_codes:
        op_list = []
        for op_code in op_code_list.split(','):
            op = None
            # The first operation class that recognises the code wins.
            for op_class in OPERATIONS:
                op = op_class.match_code(op_code)
                if op:
                    break

            if not op:
                print('Unknown operation {}'.format(op_code))
                sys.exit(3)
            op_list.append(op)
        op_lists.append(op_list)

    counter = Counter()
    thread_pool = Pool(WORKER_COUNT)
    print('Thread pool initialised with {} worker{}'.format(WORKER_COUNT, '' if WORKER_COUNT == 1 else 's'))

    # Walk the directory tree and queue every original (non-augmented) image.
    for dir_name, _, file_names in os.walk(image_dir):
        print('Processing {}...'.format(dir_name))

        for file_name in file_names:
            if not EXTENSION_REGEX.match(file_name):
                counter.skipped_no_match()
            elif AUGMENTED_FILE_REGEX.match(file_name):
                counter.skipped_augmented()
            else:
                process(dir_name, file_name, op_lists)

    print("Waiting for workers to complete...")
    thread_pool.close()
    thread_pool.join()

    print(counter.get())

@.gitignore

*.pyc
*.iml
.idea
notes.txt
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值