torch.nn.parallel.data_parallel(函数)和 nn.DataParallel(类)定义在同一个源文件里:一个是函数,一个是类。这里贴出源代码;如果遇到 out of memory、a chunk memory 之类的报错,都可以对照这份源码看看。
import operator
import torch
import warnings
from itertools import chain
from ..modules import Module
from .scatter_gather import scatter_kwargs, gather
from .replicate import replicate
from .parallel_apply import parallel_apply
from torch.cuda._utils import _get_device_index
def _check_balance(device_ids):
imbalance_warn = """
There is an imbalance between your GPUs. You may want to exclude GPU {} which
has less than 75% of the memory or cores of GPU {}. You can do so by setting
the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
environment variable."""
device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]
def warn_imbalance(get_prop):
values = [get_prop(props) for props in dev_props]
min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
if min_val / max_val < 0.75:
warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))
return True
return False
if warn_imbalance(lambda props: props.total_memory):
if warn_imbalance(lambda props: props.multi_processor_count):
class DataParallel(Module):
r"""Implements data parallelism at the module level.
This container parallelizes the application of the given :attr:`module` by
splitting the input across the specified devices by chunking in the batch
dimension (other objects will be copied once per device). In the forward
pass, the module is replicated on each device, and each replica handles a
portion of the input.