Python multithreaded concurrent programming (producer-consumer pattern): reading, processing, and saving images concurrently to speed up processing

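Requirements

The task is to read images, process them (with various combinations of transforms), and then save the processed images to a specified folder. There are also quite a lot of images, so handling them one at a time in order would take a long time. That is what led me to multithreaded concurrent programming.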

Implementation

First, import the packages used here

import os
import threading
from queue import Queue
import cv2

Some helper functions

The following functions collect the files that have the given extensions
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

def get_all_files(base, extensions):
    """
    get all files in extensions from base folder, it's a generator
    """
    for root, _, files in sorted(os.walk(base, followlinks=True)):
        for file in sorted(files):
            if file.endswith(extensions):
                yield os.path.join(root, file)


def get_all_images(base, image_extensions):
    """get all images"""
    return get_all_files(base, image_extensions)
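
A quick way to check this helper is to point it at a throwaway directory. The sketch below is only an illustration: it repeats the function so that it runs on its own, uses a shortened extension list, and the file names and the `demo_listing` helper are made up for the demo.

```python
import os
import tempfile

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')  # shortened list, for this demo only


def get_all_files(base, extensions):
    """Yield every file under base whose name ends with one of extensions."""
    for root, _, files in sorted(os.walk(base, followlinks=True)):
        for file in sorted(files):
            if file.endswith(extensions):
                yield os.path.join(root, file)


def demo_listing():
    """Create a few empty files in a temp dir and list the images among them."""
    with tempfile.TemporaryDirectory() as base:
        os.makedirs(os.path.join(base, "sub"))
        for name in ("b.png", "a.jpg", "notes.txt", os.path.join("sub", "c.jpeg")):
            open(os.path.join(base, name), "w").close()
        # Return paths relative to base so the result is easy to read.
        return [os.path.relpath(p, base) for p in get_all_files(base, IMG_EXTENSIONS)]


found = demo_listing()
print(found)
```

The text file is skipped and the paths come back sorted, directory by directory. Note that `str.endswith` is case-sensitive, so a file named `A.JPG` would not match unless upper-case extensions are added to the tuple.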
Of the two functions below, one reads an image and the other converts RGB back to BGR
def default_loader_cv2(path):
    return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
    
def rgb_2_bgr(img):
    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

Below are the main processing functions

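def load_image(target_dir, source_file):
    """Load an image and work out where its result should be saved."""
    # Assumption: the result keeps the source file name inside target_dir.
    target_file = os.path.join(target_dir, os.path.basename(source_file))
    img = default_loader_cv2(source_file)
    return (target_file, img)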


def transform(stain_normalizer, img):
    """
    Description:
        - transform image method, basic resize here, you could do other transform here
    """
    return cv2.resize(img, (256, 256))


def save(save_path, img):
    """save image method"""
    cv2.imwrite(save_path, rgb_2_bgr(img))

Build the corresponding worker functions on top of the functions above

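def do_load_image(load_queue: Queue, trainsform_queue: Queue, target_dir: str):
    """Worker: take image paths from load_queue and pass loaded images downstream."""
    while True:
        file = load_queue.get()
        if file is None:  # sentinel: no more work
            break
        # Assumption: the result keeps the source file name inside target_dir.
        target_file = os.path.join(target_dir, os.path.basename(file))
        if not os.path.exists(target_file):  # skip images that were already transformed
            img = default_loader_cv2(file)
            trainsform_queue.put((target_file, img))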


def do_transforms(trainsform_queue: Queue, save_queue: Queue, stain_normalizer):
    while True:
        data = trainsform_queue.get()
        if data is None: break
        target_file, img = data
        img_norm = transform(stain_normalizer, img)
        save_queue.put((target_file, img_norm))


def do_save(save_queue:Queue):
    while True:
        data = save_queue.get()
        if data is None:  break
        target_file, img_norm = data
        save(target_file, img_norm)

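The main function

This is where the whole program starts. Pay special attention to the order in which the threads are started and stopped. If you get it wrong, the program will hang forever.
Producer-consumer examples usually show only two functions (one producer and one consumer). Here there are three: load is the producer for transform, and transform is the producer for save. Three queues carry the data from one stage to the next. The same idea can be extended to producer-consumer pipelines with more stages.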
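
The start and stop order described above can be sketched with plain integers standing in for images. This is a minimal illustration, not the code of this post: the `stage` and `run_pipeline` names, the thread counts, and the arithmetic used in place of load, transform, and save are all assumptions made for the demo.

```python
import threading
from queue import Queue


def stage(in_queue, out_queue, func):
    """Apply func to items from in_queue until a None sentinel arrives."""
    while True:
        item = in_queue.get()
        if item is None:  # sentinel: this worker is done
            break
        result = func(item)
        if out_queue is not None:
            out_queue.put(result)


def run_pipeline(items, n_load=2, n_transform=3, n_save=2):
    """Run three chained stages; items must not contain None."""
    load_q, transform_q, save_q = Queue(), Queue(), Queue()
    results = []
    lock = threading.Lock()

    def fake_save(x):
        with lock:
            results.append(x)

    stages = [
        (load_q, transform_q, lambda x: x + 1, n_load),        # stands in for "load"
        (transform_q, save_q, lambda x: x * 10, n_transform),  # stands in for "transform"
        (save_q, None, fake_save, n_save),                     # stands in for "save"
    ]
    groups = []
    for in_q, out_q, func, n in stages:
        threads = [threading.Thread(target=stage, args=(in_q, out_q, func))
                   for _ in range(n)]
        for t in threads:
            t.start()
        groups.append((in_q, threads))

    for item in items:
        load_q.put(item)

    # Shut down one stage at a time, upstream first: queue one sentinel per
    # worker, then join them all before touching the next stage.
    for in_q, threads in groups:
        for _ in threads:
            in_q.put(None)
        for t in threads:
            t.join()
    return sorted(results)


print(run_pipeline(range(5)))  # [10, 20, 30, 40, 50]
```

A stage only receives its sentinels after every worker of the stage before it has been joined, so no item can arrive behind a sentinel and be lost.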

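def main(source_dir, target_dir):
    # 4104 images took 224.6297s
    files = get_all_images(source_dir, IMG_EXTENSIONS)  # a generator can only be iterated once
    stain_normalizer = None  # placeholder: the basic resize in transform() does not use it
    # The load queue only holds file paths and is filled before any worker starts,
    # so it is left unbounded: put() on a full bounded queue would block forever here.
    # Transform is the slowest stage, so if memory becomes a concern, bound the
    # queues that hold images (trainsform_queue, save_queue) with maxsize instead.
    load_queue = Queue()
    trainsform_queue = Queue()
    save_queue = Queue()

    for file in files:
        load_queue.put(file)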
    
    # start load_threads
    load_threads = []
    for _ in range(2):
        t = threading.Thread(
            target=do_load_image,
            args=(load_queue, trainsform_queue, target_dir)
        )
        t.start()
        load_threads.append(t)

    # start transform_threads
    transform_threads = []
    for _ in range(6):
        t = threading.Thread(
            target=do_transforms,
            args=(trainsform_queue, save_queue, stain_normalizer)
        )
        t.start()
        transform_threads.append(t)

    # start save_threads
    save_threads = []
    for _ in range(4):
        t = threading.Thread(
            target=do_save,
            args=(save_queue,)
        )
        t.start()
        save_threads.append(t)
    
    # Send one sentinel per load thread so that every worker leaves its loop.
    # Do NOT call thread.join() inside this loop: the sentinel may be taken
    # by a different worker than the one being joined.
    for _ in load_threads:
        load_queue.put(None)

    for thread in load_threads:
        thread.join()

    # Once all load threads have finished, send the sentinels for the transform threads.
    # Do NOT call thread.join() inside this loop.
    for _ in transform_threads:
        trainsform_queue.put(None)

    for thread in transform_threads:
        thread.join()

    # Once all transform threads have finished, send the sentinels for the save threads.
    # Do NOT call thread.join() inside this loop.
    for _ in save_threads:
        save_queue.put(None)

    for thread in save_threads:
        thread.join()
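
Bounding a queue with `maxsize` is how a pipeline like this keeps memory in check when one stage is slower than the stage feeding it. The short sketch below shows the behaviour on its own; the `try_put` and `demo_bounded_queue` helpers and the item names are made up for illustration, and it uses `put_nowait` only so that the full queue is visible without blocking the demo.

```python
from queue import Full, Queue


def try_put(queue, item):
    """Try to enqueue item without waiting; report whether it fit."""
    try:
        queue.put_nowait(item)
        return True
    except Full:
        return False


def demo_bounded_queue():
    """Fill a queue of capacity 2, then show that a get() frees a slot."""
    q = Queue(maxsize=2)
    q.put("img-1")
    q.put("img-2")
    was_full = q.full()
    rejected = not try_put(q, "img-3")  # no room: a blocking put() would wait here
    q.get()                             # a consumer takes one item
    accepted = try_put(q, "img-3")      # now there is room again
    return was_full, rejected, accepted


print(demo_bounded_queue())  # (True, True, True)
```

With the ordinary blocking `put()`, a producer simply waits at the point where `put_nowait` raised `Full`, which is why a bounded queue must always have a running consumer on the other end.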

Sequential execution

def single_thread(source_dir, target_dir):
    # 4104 images took 486.4547s
    files = get_all_images(source_dir, IMG_EXTENSIONS)
    stain_normalizer = None  # placeholder: the basic resize in transform() does not use it
    for file in files:
        target_file, img = load_image(target_dir, file)
        img_transform = transform(stain_normalizer, img)
        save(target_file, img_transform)

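Results

Judging from the code, the single-threaded sequential version is considerably shorter than the multithreaded one, its structure is simpler, and it is unlikely to go wrong. However, it takes more than twice as long as the multithreaded version. The data set is 4104 three-channel PNG images of 512x512 pixels. The single-threaded run took 486.4547s, while the multithreaded run took 224.6297s. So although the multithreaded version needs somewhat more code, it is a good deal faster than sequential execution, and the extra effort is well worth it.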
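
The two timings above can be turned into a speedup and a per-image cost with a few lines of arithmetic. The numbers are the ones reported in this post; the variable names are only for illustration.

```python
n_images = 4104
single_s = 486.4547  # sequential run, seconds
multi_s = 224.6297   # multithreaded run, seconds

speedup = single_s / multi_s
per_image_single_ms = single_s / n_images * 1000
per_image_multi_ms = multi_s / n_images * 1000

print(f"speedup: {speedup:.2f}x")                          # speedup: 2.17x
print(f"single thread: {per_image_single_ms:.1f} ms/img")  # single thread: 118.5 ms/img
print(f"multithreaded: {per_image_multi_ms:.1f} ms/img")   # multithreaded: 54.7 ms/img
```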
