import logging
import queue
import threading
import time

import torch
from torch.utils.data import DataLoader, Dataset
class MyDataset(Dataset):
    """Toy map-style dataset holding the integers ``0 .. size - 1``.

    Args:
        size: Number of samples to generate. Defaults to 100, the value
            that used to be hard-coded.
    """

    def __init__(self, size: int = 100) -> None:
        self.data = list(range(size))

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> int:
        """Return the sample stored at position ``idx``."""
        return self.data[idx]
def process_data(data, delay=1):
    """Simulate a slow processing step for one item or batch.

    Prints the payload and then blocks the calling thread for ``delay``
    seconds.

    Args:
        data: The payload to "process"; it is only printed.
        delay: Seconds to sleep to emulate work. Defaults to 1, the value
            that used to be hard-coded.
    """
    print("Processing data:", data)
    time.sleep(delay)
class DataProcessThread(threading.Thread):
    """Daemon worker thread that consumes batches from a shared queue.

    The worker repeatedly pulls an item from ``data_queue`` and hands it to
    the processing callable. A ``None`` item is the stop signal. Every item
    taken from the queue, including the stop signal and items whose
    processing raised, is acknowledged with ``task_done()`` so that
    ``data_queue.join()`` cannot block forever.

    Args:
        data_queue: Queue to pull batches from.
        processor: Optional callable invoked with each batch. When omitted,
            the module-level ``process_data`` function is used.
    """

    def __init__(self, data_queue, processor=None):
        super().__init__(daemon=True)
        self.data_queue = data_queue
        self.processor = processor

    def run(self):
        """Process queued batches until a ``None`` sentinel is received."""
        # Resolve the default lazily so the module-level function is only
        # looked up when no explicit processor was supplied.
        handler = process_data if self.processor is None else self.processor
        while True:
            batch = self.data_queue.get()
            try:
                if batch is None:
                    break
                handler(batch)
            except Exception:
                # A failing batch must not kill the worker: log it and keep
                # draining the queue so the remaining batches are handled.
                logging.getLogger(__name__).exception("Failed to process batch")
            finally:
                # Runs on success, on failure and on the sentinel ``break``,
                # keeping the queue's unfinished-task counter balanced.
                self.data_queue.task_done()
# Build the dataset and wrap it in a DataLoader that yields batches of four.
dataset = MyDataset()
dataloader = DataLoader(dataset, batch_size=4)

# Queue shared between the producer (this script) and the worker threads.
data_queue = queue.Queue()

# Create a fixed number of worker threads, then start them all.
num_threads = 4
threads = []
for _ in range(num_threads):
    threads.append(DataProcessThread(data_queue))
for thread in threads:
    thread.start()

# Hand every batch produced by the DataLoader to the workers.
for batch in dataloader:
    data_queue.put(batch)

# Block until every queued batch has been marked as done.
data_queue.join()

# Send one None sentinel per worker so each of them leaves its loop.
for _ in range(num_threads):
    data_queue.put(None)

# Wait for the worker threads to exit.
for thread in threads:
    thread.join()

print("All data has been processed.")
# Python multithreading framework
# (Web page residue: "latest recommended article published 2024-05-16 15:52:17")