First, look at the call:
train, val = random_split(dataset, [n_train, n_val])
def random_split(dataset, lengths):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths)).tolist()
    return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
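Going by the docstring, random_split randomly splits dataset into non-overlapping datasets of the given lengths. Here the lengths are [n_train, n_val], so it produces two groups, one with n_train samples and one with n_val samples. The samples are chosen at random: n_train and n_val items are drawn from dataset without replacement, so no sample appears twice.
The first check verifies that the values in lengths sum to len(dataset), because however many groups are produced, together they must cover exactly len(dataset) samples.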
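To see the check in action without PyTorch, here is a minimal sketch in plain Python. check_lengths is a hypothetical helper (it is not part of PyTorch); it only mirrors the comparison above on an ordinary list.

```python
def check_lengths(dataset, lengths):
    """Hypothetical helper mirroring the length check in random_split."""
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
    return True


data = list(range(10))

ok = check_lengths(data, [7, 3])   # 7 + 3 == 10, so the check passes

try:
    check_lengths(data, [7, 2])    # 7 + 2 != 10, so ValueError is raised
    failed = False
except ValueError:
    failed = True
```

So lengths has to account for every sample: leaving even one sample unassigned is treated as an error rather than silently dropped.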
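Next comes randperm, which returns a random permutation of the integers 0 to n-1 with no repeats (it starts at 0, not 1). Here it generates a permutation of 0 to len(dataset)-1 and converts it to a list.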
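The property that matters here is that the result is a permutation of 0..n-1, so every value is a valid index into the dataset and none repeats. The sketch below illustrates this with a pure-Python stand-in so it runs without PyTorch; randperm_list is a hypothetical name, and random.sample is swapped in for torch.randperm.

```python
import random


def randperm_list(n, seed=None):
    """Hypothetical pure-Python stand-in for torch.randperm(n).tolist()."""
    rng = random.Random(seed)
    # Sampling n items from range(n) without replacement yields
    # every integer 0..n-1 exactly once, in random order.
    return rng.sample(range(n), n)


indices = randperm_list(5, seed=0)

# Whatever the order, the values are exactly 0, 1, 2, 3, 4.
is_permutation = sorted(indices) == list(range(5))
```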
Finally, Subset.
Start with the for loop in the comprehension:
Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)
for offset, length in zip(...)
Look at _accumulate:
def _accumulate(iterable, fn=lambda x, y: x + y):
    'Return running totals'
    # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
    # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
    it = iter(iterable)
    try:
        total = next(it)
    except StopIteration:
        return
    yield total
    for element in it:
        total = fn(total, element)
        yield total
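From the examples in the comments, offset takes the values n_train, n_train + n_val, while length takes n_train, n_val.
So the slices are indices[0:n_train] for the first subset and indices[n_train:n_train + n_val] for the second, which cuts indices into exactly two segments.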
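The slicing arithmetic can be checked by hand with small numbers. This sketch uses itertools.accumulate from the standard library, which produces the same running totals as _accumulate; the values of n_train, n_val and the permutation are made up for illustration.

```python
from itertools import accumulate

n_train, n_val = 3, 2
lengths = [n_train, n_val]
indices = [4, 0, 3, 1, 2]  # an example permutation, fixed so the result is reproducible

# Running totals mark where each segment ends.
offsets = list(accumulate(lengths))  # [3, 5]

# Each segment starts at offset - length and ends at offset.
bounds = [(offset - length, offset) for offset, length in zip(offsets, lengths)]
parts = [indices[start:stop] for start, stop in bounds]

# bounds == [(0, 3), (3, 5)]
# parts  == [[4, 0, 3], [1, 2]]
```

Because each segment begins where the previous one ended, the segments never overlap and together use every index once.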
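In short, Subset is constructed twice, and the two subsets are put in a list and returned, which the call above unpacks into train and val.
Now look at Subset: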
class Subset(Dataset):
    r"""
    Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices
This is simple: the constructor just stores dataset and indices as attributes; no data is copied.
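The excerpt above shows only the constructor. In PyTorch, Subset also defines __getitem__ and __len__, which look items up through the stored indices. The sketch below is a plain-Python stand-in to illustrate that mapping; ListSubset is a hypothetical name, and it works on an ordinary list instead of a Dataset.

```python
class ListSubset:
    """Hypothetical stand-in for Subset that works on any indexable sequence."""

    def __init__(self, dataset, indices):
        # Only references are stored; the underlying data is not copied.
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        # Position idx in the subset maps to position indices[idx] in the full dataset.
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)


data = ["a", "b", "c", "d", "e"]
train = ListSubset(data, [4, 0, 3])
val = ListSubset(data, [1, 2])

train_items = [train[i] for i in range(len(train))]  # ['e', 'a', 'd']
val_items = [val[i] for i in range(len(val))]        # ['b', 'c']
```

Since both subsets point at the same underlying dataset, splitting costs only two index lists, however large the dataset is.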