一、Thread + Queue实现 -> 基于下述视频的修正版本
代码来源: B站视频-第五集
Python实现生产者消费者爬虫_哔哩哔哩_bilibili
不知道有谁和我一样在学习Python多线程时也是看这个视频学习的,但在这个视频的代码部分却显然暴露出一些不完美的地方;
①程序开始打开了文件fout=open(...)却没有显式的进行关闭, 这可能带来写入中断、内存泄露等问题;
②程序并没有给出写明到底何时才完成所有任务,只能人工进行判断;
作为追求优雅的Python编写者自然不能容忍这样糟糕的代码,简单搜索得知:
①可以使用queue.task_done进行任务完成标注;因为Queue类 实例内部维护了一个变量叫unfinished_tasks; 当使用queue.put()方法时,unfinished_tasks ++; 当使用queue.task_done时,unfinished_tasks --;
②使用queue.join会在queue.unfinished_tasks != 0时保持阻塞
思考:
在视频代码部分, 我们可以通过queue.join的方法确保任务全部完成,然后关闭文件; 但是注意到:
①如果我们只检测消费者队列即html_queue中放进去的任务是否全部完成,显然不可行,因为存在生产者还没生产完产品,消费者就已经将产品全部消费的情况;
②只检测生产者队列即url_queue, 也不可行,因为存在生产者已经生产完产品,但消费者还没消费完产品的情况;
所以只能两个都检查,并且注意顺序: 先检查生产者队列是否为空, 为空证明生产完全部产品,再检查消费者队列,确保生产出的产品全部被消费者消费完; 对应到代码的85和87行;
在确保任务全部完成后,我们就可以安全的关闭文件了; 但是注意,此时线程没有结束,程序依旧在运行,一方面是由于视频代码中使用的是while True死循环,使得程序不可能结束; 另外一方面是因为queue.get()是一个阻塞操作;当queue.empty()为True时,程序就会阻塞住!
如果我们只是简单地在任务全部完成后为循环设置退出条件running=False; 然后在while循环中每次循环时都判断一次这个running, 还是会因为queue.get()的阻塞卡住;因为程序进入不到下一次循环去判断这个running了!
因此最后就是要解决queue.get()的阻塞问题了; 一种简单的办法就是设置超时,但我并不太认可这种做法,一来是你需要评估任务的最长用时是多少;其次是如果你设置超时时间过长了,程序就会浪费很多时间在无意义的等待上;设置的超时时间短了,可能queue还有东西需要put,你就认定queue的任务已经全部完成了,过早结束了任务;
所以我想到的策略是,不让queue.get()阻塞进入下一次循环不就好了?可以在任务确定结束后,人为的给queue加入任务,从而正常的进入下一次循环,最后结束线程! 代码见第91-94行
我猜想也许还有更好的解决方法,但我感觉目前这个策略是简单易懂的,极好的发挥了生产者消费者模式的效率优势! 完整代码如下:
from typing import List, Callable, IO
from re import finditer, compile
from queue import Queue
from random import randint
import requests
import threading
import time
class Spider(object):
def __init__(self, urls: List[str], url_queue: Queue, html_queue: Queue, save_obj: IO):
self.urls = urls
self.url_queue = url_queue
self.html_queue = html_queue
self.save_obj = save_obj
self.running = True
self.initialize_url_queue()
@staticmethod
def timeit(func: Callable) -> Callable:
def wrapper(*args, **kwargs):
start = time.time()
res = func(*args, **kwargs)
print(f"用时: {time.time() - start}s")
return res
return wrapper
def initialize_url_queue(self) -> None:
for url in self.urls:
self.url_queue.put(url)
@timeit
def single_thread(self) -> None:
for url in self.urls:
self.crawl(url)
@timeit
def multi_thread(self) -> None:
threads = [threading.Thread(target=self.crawl, args=(url, )) for url in self.urls]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
@staticmethod
def crawl(url: str) -> str:
return url if url == "" else requests.get(url).text
@staticmethod
def parse(html: str) -> finditer:
# 多组件的Pipeline技术架构->复杂事件分很多中间步骤一步步完成
if html == "":
return list()
regular_exp = compile(r'post-item-title" href="(https://.*?)" target="_blank">(.*?)</a>')
match_objs = regular_exp.finditer(html)
for match_obj in match_objs:
yield match_obj.groups()
def __producer(self) -> None:
# 为了更真实的模拟生产者模式,加入随机睡眠延长生产时间
while self.running:
print(threading.current_thread().name)
url = self.url_queue.get()
html = self.crawl(url)
self.html_queue.put(html)
time.sleep(randint(1, 3))
self.url_queue.task_done()
print(f"退出线程: {threading.current_thread().name}")
def __consumer(self) -> None:
while self.running:
print(threading.current_thread().name)
html = self.html_queue.get()
parses = self.parse(html)
for result in parses:
self.save_obj.write(str(result) + '\n')
time.sleep(randint(1, 3))
self.html_queue.task_done()
print(f"退出线程: {threading.current_thread().name}")
def exit_thread(self, producer_thread_count: int, consumer_thread_count: int) -> None:
# 证明urls中的任务都被consumer全部生产完成了
self.url_queue.join()
# 证明消费者已经将全部生成的产品消费完了
self.html_queue.join()
self.save_obj.close()
self.running = False
# 为了让程序不阻塞在queue.get()进入下一次循环,给queue加入空数据
for _ in range(producer_thread_count):
self.url_queue.put("")
for _ in range(consumer_thread_count - producer_thread_count):
self.html_queue.put("")
print("任务全部完成!!!")
def producer_consumer_pattern(self, producer_thread_count: int, consumer_thread_count: int):
for i in range(producer_thread_count):
threading.Thread(target=self.__producer, name=f"producer{i}...").start()
for j in range(consumer_thread_count):
threading.Thread(target=self.__consumer, name=f"consumer{j}...").start()
self.exit_thread(producer_thread_count, consumer_thread_count)
if __name__ == "__main__":
url_lst = [f"https://www.cnblogs.com/#p{page}" for page in range(1, 100)]
url__queue = Queue()
html__queue = Queue()
save_file = open(r"CrawlData.txt", 'w', encoding="utf-8")
spider = Spider(url_lst, url__queue, html__queue, save_file)
spider.producer_consumer_pattern(10, 10)
# spider.multi_thread()
# spider.single_thread()
过渡: 使用传统的Thread + Queue实现生产者消费者线程还是有些麻烦的,在Python中还有封装得更加完美的模块,即concurrent下的futures, 关于它很多B站视频均有介绍,因此不在此赘述。
二、使用concurrent.futures实现
常见避坑: 由于concurrent模块的高度封装性,如果不知其运行机制直接使用很容易写出一些不可预料的糟糕代码;因此必须要阐述一些重要的点:
①futures.ThreadPoolExecutor实例对象的submit方法是非阻塞的,它会直接返回一个future对象, 但是future.result()是阻塞的,只有当future.done()为True即任务完成时才能成功返回结果;
②futures.as_complete()会将 list[futures] -> 生成器对象,优先获取future.done()为True的future对象然后通过yield进行返回;
③with模式下创建的futures.ThreadPoolExecutor对象,在脱离with的作用域后会自行shutdown, 就像with模式open文件一样。而shutdown需要保证线程池中的任务全部完成, 这是一个阻塞操作!
在有了这仨点最重要的储备后,下述代码便变得极好理解了:
"""
使用ThreadPool实现生产者消费者模式
①map方法比较固定,一次将所有任务全部提交, 而且是阻塞的,直到所有任务都完成才会返回htmls结果
②map阻塞式返回的是list, 而submit非阻塞式返回future对象, 配合as_complete可以方便的实现生产者消费者模式
③使用with方法打开ThreadPoolExecutor, 要注意with作用域Executor会auto-shutdown,
而shutdown需要保证线程池内任务全部完成
"""
from concurrent import futures
from typing import List, Callable, IO
from re import compile
from random import randint
import requests
import time
class Spider(object):
def __init__(self, urls: List[str], save_obj: IO):
self.urls = urls
self.htmls = None
self.save_obj = save_obj
@staticmethod
def timeit(func: Callable) -> Callable:
def wrapper(*args, **kwargs):
start = time.time()
res = func(*args, **kwargs)
print(f"{func.__name__}用时: {time.time() - start}s")
return res
return wrapper
@staticmethod
def crawl(url: str) -> str:
print("爬取中...")
time.sleep(randint(1, 3))
return requests.get(url).text
@staticmethod
def parse(html: str) -> list:
print("解析中...")
time.sleep(randint(1, 3))
regular_exp = compile(r'post-item-title" href="(https://.*?)" target="_blank">(.*?)</a>')
match_objs = regular_exp.finditer(html)
return [match_obj.groups() for match_obj in match_objs]
def save_to_disk(self, response: futures):
for res in response.result():
self.save_obj.write(str(res) + '\n')
def __producer(self, producer_thread_count: int) -> futures.ThreadPoolExecutor:
producer_pool = futures.ThreadPoolExecutor(max_workers=producer_thread_count)
self.htmls = [producer_pool.submit(self.crawl, url) for url in self.urls]
# futures.wait(self.htmls, return_when=futures.FIRST_COMPLETED)
return producer_pool
@timeit
def __consumer(self, consumer_thread_count: int) -> None:
with futures.ThreadPoolExecutor(max_workers=consumer_thread_count) as consumer_pool:
for html in futures.as_completed(self.htmls):
future = consumer_pool.submit(self.parse, html.result())
future.add_done_callback(self.save_to_disk)
@timeit
def producer_consumer_pattern(self, producer_thread_count: int, consumer_thread_count: int) -> None:
producer_pool = self.__producer(producer_thread_count)
self.__consumer(consumer_thread_count)
producer_pool.shutdown()
self.save_obj.close()
if __name__ == "__main__":
url_lst = [f"https://www.cnblogs.com/#p{page}" for page in range(1, 100)]
save_file = open(r"CrawlData3.txt", 'w', encoding="utf-8")
spider = Spider(url_lst, save_file)
spider.producer_consumer_pattern(7, 10)
代码注意点:
①38-44行的parse从原来的生成器对象变成了函数; 这是因为直接调用生成器对象是不会执行其内部任何代码的, 为了让submit的任务代码可以被线程执行, 所以将生成器改成了函数;
②46行的__producer函数没有使用with模式创建futures.ThreadPoolExecutor对象,这是为了避免with模式对producer线程池的auto-shutdown, 使得主线程必须先执行完__producer全部任务才能继续执行__consumner;
(ps: 当然你也可以选择嵌套with, 外层包producer线程池, 内层包consumer线程池;
或者使用threading.Thread将__producer作为子线程进行执行,这两种办法都可以避免阻塞)
③63行增加回调函数,将数据解析和数据持久化的逻辑进行分离。