from torch.nn.parallel._functions import Scatter
from torch.nn.parallel import DataParallel
import torch
# Copied from torch.nn.parallel and adapted so that DataParallel chunks lists
# across the target GPUs instead of duplicating them onto every device
# (which is really all this code is here for).
def scatter(inputs, target_gpus, dim=0):
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            # Chunk the list across the GPUs instead of copying it to each one.
            if len(target_gpus) == 0:
                return []
            num_gpus = len(target_gpus)
            num_samples = len(obj)
            samples_per_gpu = num_samples // num_gpus
            remaining_samples = num_samples % num_gpus
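            # e.g. 10 samples over 3 GPUs (hypothetical counts, just to
            # illustrate): samples_per_gpu is 3 with 1 left over, so the
            # chunk sizes come out as [4, 3, 3].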
            distributed_samples = []
            start_idx = 0
            for gpu_idx in range(num_gpus):
                gpu_samples_count = samples_per_gpu + (1 if gpu_idx < remaining_samples else 0)
                gpu_samples = obj[start_idx:start_idx + gpu_samples_count]
                distributed_samples.append(gpu_samples)
                start_idx += gpu_samples_count
            return distributed_samples
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        # Anything else is replicated as-is to every device.
        return [obj for _ in target_gpus]
    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell.
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None
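# Sketch of what the list branch changes (hypothetical sample objects): with
# samples = [s0, s1, s2, s3, s4] and two devices,
#   scatter((samples,), target_gpus=[0, 1]) -> [([s0, s1, s2],), ([s3, s4],)]
# i.e. each device gets its own slice (the leftover sample going to the first
# device), where the stock implementation would copy the whole list to both.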
def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
    # Pad the shorter sequence so every device gets an (args, kwargs) pair.
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    return tuple(inputs), tuple(kwargs)
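# Sketch of the resulting contract (hypothetical values a, b, two devices):
#   inputs, kwargs = scatter_kwargs(([a, b],), {}, [0, 1])
#   # inputs -> (([a],), ([b],)), kwargs -> ({}, {})
# i.e. one (args, kwargs) pair per device, with the shorter side padded so
# the two tuples line up element-wise.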
class DataParallelV2(DataParallel):
    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
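# Usage sketch (assumptions: at least two visible GPUs and a hypothetical
# `MyModel` whose forward takes a list of per-sample objects that cannot be
# stacked into a single tensor):
#
#   model = DataParallelV2(MyModel().cuda(), device_ids=[0, 1])
#   outputs = model(list_of_samples)  # each replica sees only its own chunk
#
# Stock DataParallel would copy the entire list to every replica; the scatter
# override above slices it instead.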