其实在WeightedRandomSampler中,采样的权重针对的是每一个样本,所以我们可以确定好每个类对应的权重,再一一对应到样本上。并且,权重其实就是比值,num_samples就是一次采样的数目,里面的比值其实就是权重的比值。
class WeightedRandomSampler(Sampler):
r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights)."""
def __init__(self, weights, num_samples, replacement=True):
# ...省略类型检查
# weights用于确定生成索引的权重
self.weights = torch.as_tensor(weights, dtype=torch.double)
self.num_samples = num_samples
# 用于控制是否对数据进行有放回采样
self.replacement = replacement
def __iter__(self):
# 按照加权返回随机索引值
return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
对于Weighted Random Sampler类的__init__()来说,replacement参数依旧用于控制采样是否是有放回的;num_sampler用于控制生成的个数;
weights参数对应的是“样本”的权重