import torch
# torch.multiprocessing: a wrapper around Python's native multiprocessing module
import torch.multiprocessing as multiprocessing
# Sampler classes: SequentialSampler (iterate indices in order),
# RandomSampler (iterate indices in shuffled order), and BatchSampler
# (wraps another sampler to yield mini-batches of indices).
from .sampler import SequentialSampler, RandomSampler, BatchSampler
import collections
'''
The collections module provides several extra data types on top of the
built-in ones:
1. namedtuple(): a tuple subclass whose elements are accessible by name
2. deque: a double-ended queue with fast appends/pops at either end
3. Counter: a counter for tallying hashable objects
4. OrderedDict: a dictionary that remembers insertion order
5. defaultdict: a dictionary with a default value factory
'''
import sys
# The traceback module is used to extract and format exception stack traces
import traceback
# threading: thread control and handling
import threading
from torch._six import string_classes
# Pick the Queue module name appropriate for the running Python version
if sys.version_info[0] == 2:
import Queue as queue
else:
import queue
_use_shared_memory = False
"""Whether to use shared memory in default_collate"""
class ExceptionWrapper(object):
"Wraps an exception plus traceback to communicate across threads"
def __init__(self, exc_info):
self.exc_type = exc_info[0]
# format_exception:输出异常栈
# join:字符串操作函数,链接字符
'''a="abcd"
>>> ",".join(a)
'a,b,c,d'
'''
self.exc_msg = "".join(traceback.format_exception(*exc_info))
def _worker_loop(dataset, index_queue, data_queue, collate_fn):
global _use_shared_memory
_use_shared_memory = True
torch.set_num_threads(1)
while True:
r = index_queue.get()
if r is None:
data_queue.put(None)
break
idx, batch_indices = r
try:
samples = collate_fn([dataset[i] for i in batch_indices])
except Exception:
data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
else:
data_queue.put((idx, samples))
def _pin_memory_loop(in_queue, out_queue, done_event):
while True:
try:
r = in_queue.get()
except Exception:
if done_event.is_set():
return
raise
if r is
PyTorch代码学习-dataloader
最新推荐文章于 2024-04-23 23:21:34 发布