图像采集数据集整理和扩充方案(含代码)

数据集概述

  1. 对于多类别图片数据集采集,在采集完成之后,采集设备可能是不同的手机或者摄像头,因此获取的最后图片数据集尺寸大小,分辨率可能差异性很大,对之后的模型训练有很大的影响;因此我们需要将数据集中的图片,重新处理至同一分辨率和大小。
  2. 如果处理完成数据集数量还不满足要求,在数量差异较小的时候可以通过图片扩增的手段增加数据集。

常用图像扩增所采用的变换

  1. 一定程度内的随机旋转、平移、缩放、裁剪、填充、左右翻转;
  2. 对图像中的像素添加噪声扰动。常见的有椒盐噪声、高斯白噪声;
  3. 颜色变换。在图像的RGB颜色空间上添加增量;
  4. 改变图片的亮度、清晰度、对比度、锐度等;
  5. 除此之外,还有采样算法SMTE,生成对抗网络GAN等都可以进行图像扩充;
    ------声明:摘选自《百面机器学习》,不错的书,推荐------

代码

Python >= 3.6,Opencv2

  1. 尺寸一致:
    将指定文件夹下的所有图片resize至同一尺寸,并保存在指定文件夹下
# -*-coding = utf-8 -*-
"""
流程:
1.读取指定文件夹所有文件(必须都是图片)
2.进行resize,并存储在指定文件夹下
修改值:
path_read: 需要进行修改的图片存储的文件夹
path_write: 修改后的图片存储的文件夹,必须为空,会对图片重新编号00000-09999
target_size:[x, y] 修改后文件的尺寸
"""
import os
import cv2


if __name__ == "__main__":
    path_read = "D:/pic_old/"
    path_write = "D:/pic_new/"
    target_size = [512, 512]
    image_list = [x for x in os.listdir(path_read)]
    for num, img in enumerate(image_list):
        print(num, img)
        image = cv2.imread(path_read+img, cv2.IMREAD_COLOR)
        # print(path_read+"/"+img)
        new_image = cv2.resize(image, (target_size[0], target_size[1]), interpolation=cv2.INTER_CUBIC)
        image_dir = path_write+str(num).zfill(5)+'.jpg'
        cv2.imwrite(image_dir, new_image)
  1. 图片变换:
    流程:
    (1)需要修改的参数:
    path_read: 读取原始数据集图片的位置;
    path_write: 图片扩增后存放的位置;
    picture_size: 图片之后存储的尺寸;
    enhance_hum: 需要通过扩增手段增加的图片数量
    (2)扩增手段:
    Image_flip:
    #翻转图片;随机旋转翻转方向,垂直/水平/垂直+水平
    Image_traslation:
    #平移图片,随机选择平移方向,指定平移像素100(可以修改),抽取原始图片像素点填补平移后空白区域;
    Image_rotate:
    #旋转图片,随机从rotate_angle列表中抽取旋转角度
    Image_noise:
    #添加噪声,随机选择高斯噪声或椒盐噪声;且高斯噪声的方差与椒盐噪声的比例都是随机抽取;
# -*-coding = utf-8 -*-
"""
1. Image_flip:翻转图片
2. Image_traslation:平移图片
3. Image_rotate:旋转图片
4. Image_noise:添加噪声
"""
import os
import cv2
import numpy as np
from random import choice
import random

def Image_flip(img):
    """
    :param img:原始图片矩阵
    :return: 0-垂直; 1-水平; -1-垂直&水平
    """
    if img is None:
        return
    paras = [0, 1, -1]
    img_new = cv2.flip(img, choice(paras))
    return img_new

def Image_traslation(img):
    """
    :param img: 原始图片矩阵
    :return: [1, 0, 100]-宽右移100像素; [0, 1, 100]-高下移100像素
    """
    paras_wide = [[1, 0, 100], [1, 0, -100]]
    paras_height = [[0, 1, 100], [0, 1, -100]]
    rows, cols = img.shape[:2]
    img_shift = np.float32([choice(paras_wide), choice(paras_height)])
    border_value = tuple(int(x) for x in choice(choice(img)))
    img_new = cv2.warpAffine(img, img_shift, (cols, rows), borderValue=border_value)
    return img_new

def Image_rotate(img):
    """
    :param img:原始图片矩阵
    :return:旋转中心,旋转角度,缩放比例
    """
    rows, cols = img.shape[:2]
    rotate_core = (cols/2, rows/2)
    rotate_angle = [60, -60, 45, -45, 90, -90, 210, 240, -210, -240]
    paras = cv2.getRotationMatrix2D(rotate_core, choice(rotate_angle), 1)
    border_value = tuple(int(x) for x in choice(choice(img)))
    img_new = cv2.warpAffine(img, paras, (cols, rows), borderValue=border_value)
    return img_new

def Image_noise(img):
    """
    :param img:原始图片矩阵
    :return: 0-高斯噪声,1-椒盐噪声
    """
    paras = [0, 1]
    gaussian_class = choice(paras)
    noise_ratio = [0.05, 0.06, 0.08]
    if gaussian_class == 1:
        output = np.zeros(img.shape, np.uint8)
        prob = choice(noise_ratio)
        thres = 1 - prob
        #print('prob', prob)
        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                rdn = random.random()
                if rdn < prob:
                    output[i][j] = 0
                elif rdn > thres:
                    output[i][j] = 255
                else:
                    output[i][j] = img[i][j]
        return output
    else:
        mean = 0
        var=choice([0.001, 0.002, 0.003])
        #print('var', var)
        img = np.array(img/255, dtype=float)
        noise = np.random.normal(mean, var**0.5, img.shape)
        out = img + noise
        if out.min() < 0:
            low_clip = -1
        else:
            low_clip = 0
        out = np.clip(out, low_clip, 1.0)
        out = np.uint8(out*255)
        return out

if __name__ == "__main__":
    """
    path_read: 读取原始数据集图片的位置;
    path_write:图片扩增后存放的位置;
    picture_size:图片之后存储的尺寸;
    enhance_hum: 需要通过扩增手段增加的图片数量
    """
    path_read = "D:/pic_old/"
    path_write = "D:/pic_new/"
    enhance_num = 500
    image_list = [x for x in os.listdir(path_read)]
    existed_img = len(image_list)
    while enhance_num > 0:
        img = choice(image_list)
        image = cv2.imread(path_read+img, cv2.IMREAD_COLOR)
        algorithm = [1, 2, 3, 4]
        random_process = choice(algorithm)
        if random_process == 1:
            image = Image_flip(image)
        elif random_process == 2:
            image = Image_traslation(image)
        elif random_process == 3:
            image = Image_rotate(image)
        else:
            image = Image_noise(image)
        image_dir = path_write+str(enhance_num+existed_img-1).zfill(5)+'.jpg'
        cv2.imwrite(image_dir, image)
        enhance_num -= 1

3. Python根据关键词百度图片爬虫
这类代码用的是网上其他人的,为了防止侵权,贴出代码来源: https://www.cnblogs.com/zishengY/articles/9371765.html
为了方便你们用,我把代码移植过来了,只需要修改ROOT_DIR、KEYWORD和MAX_NUM就行,分别代表图片存放路径、关键词和需要下载的数量。

# coding:utf-8

import os
import re
import urllib
import shutil
import requests
import itertools


# ------------------------ Hyperparameter ------------------------

ROOT_DIR = 'D:/pic_save/'

# 存放所下载图片的文件夹
SAVE_DIR = ROOT_DIR
# 如有多个关键字,需用空格进行分隔
KEYWORD = ''
# 保存后的图片格式
SAVE_TYPE = '.jpg'
# 需要下载的图片数量
MAX_NUM = 2000


# ------------------------ URL decoding ------------------------
str_table = {
    '_z2C$q': ':',
    '_z&e3B': '.',
    'AzdH3F': '/'
}

char_table = {
    'w': 'a',
    'k': 'b',
    'v': 'c',
    '1': 'd',
    'j': 'e',
    'u': 'f',
    '2': 'g',
    'i': 'h',
    't': 'i',
    '3': 'j',
    'h': 'k',
    's': 'l',
    '4': 'm',
    'g': 'n',
    '5': 'o',
    'r': 'p',
    'q': 'q',
    '6': 'r',
    'f': 's',
    'p': 't',
    '7': 'u',
    'e': 'v',
    'o': 'w',
    '8': '1',
    'd': '2',
    'n': '3',
    '9': '4',
    'c': '5',
    'm': '6',
    '0': '7',
    'b': '8',
    'l': '9',
    'a': '0'
}
char_table = {ord(key): ord(value) for key, value in char_table.items()}

# ------------------------ Encoding ------------------------
def decode(url):
    for key, value in str_table.items():
        url = url.replace(key, value)
    return url.translate(char_table)

# ------------------------ Page scroll down ------------------------
def buildUrls():
    word = urllib.parse.quote(KEYWORD)
    url = r"http://image.baidu.com/search/acjson?tn=resultjson_com&ipn=rj&ct=201326592&fp=result&queryWord={word}&cl=2&lm=-1&ie=utf-8&oe=utf-8&st=-1&ic=0&word={word}&face=0&istype=2nc=1&pn={pn}&rn=60"
    urls = (url.format(word=word, pn=x) for x in itertools.count(start=0, step=60))
    return urls

re_url = re.compile(r'"objURL":"(.*?)"')

# ------------------------ Get imgURL ------------------------
def resolveImgUrl(html):
    imgUrls = [decode(x) for x in re_url.findall(html)]
    return imgUrls

# ------------------------ Download imgs ------------------------
def downImgs(imgUrl, dirpath, imgName):
    filename = os.path.join(dirpath, imgName)
    try:
        res = requests.get(imgUrl, timeout=15)
        if str(res.status_code)[0] == '4':
            print(str(res.status_code), ":", imgUrl)
            return False
    except Exception as e:
        print(e)
        return False
    with open(filename + SAVE_TYPE, 'wb') as f:
        f.write(res.content)

# ------------------------ Check save dir ------------------------
def mkDir():
    try:
        shutil.rmtree(SAVE_DIR)
    except:
        pass
    os.makedirs(SAVE_DIR)


# ------------------------ Main ------------------------
if __name__ == '__main__':

    print('\n\n', '= = ' * 25, ' Keyword Spider ', ' = =' * 25, '\n\n')
    mkDir()
    urls = buildUrls()
    idx = 0
    for url in urls:
        html = requests.get(url, timeout=10).content.decode('utf-8')
        imgUrls = resolveImgUrl(html)
        # Ending if no img
        if len(imgUrls) == 0:
            break
        for url in imgUrls:
            downImgs(url, SAVE_DIR, '{:>05d}'.format(idx + 1))
            print('  {:>05d}'.format(idx + 1))
            idx += 1
            if idx >= MAX_NUM:
                break
        if idx >= MAX_NUM:
            break
    print('\n\n', '= = ' * 25, ' Download ', idx, ' pic ', ' = =' * 25, '\n\n')

这里没有写出STOME和GAN扩充数据集的方法,这两类方法太复杂,代码我也不会哈哈哈哈哈哈。不过数据集最好能够实地采集,或者从网上收集;如果是自己通过已有数据集扩充的,最好不要超过原数据集图片数量,不然容易造成过拟合。

展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 游动-白 设计师: 上身试试
应支付0元
点击重新获取
扫码支付

支付成功即可阅读