- 设置超参数。
- (name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2),这句话设置了一个name和两个函数preprocess和d_input_func,前者用于预处理data,后者用于将data乘2。
- 定义函数get_distribution_sampler()通过输入均值和方差,返回正态分布的torch张量,维数是n维。
- 定义函数get_generator_input_sampler(),无输入,返回m×n维的torch张量。
- class Generator定义生成器,class Discriminator定义判别器。
- 定义函数extract(),用于将输入的v矩阵,转换为列表形式。
- 定义函数states(),返回d向量的平均值和标准差。
- 定义函数decorate_with_diffs(),第一行mean = torch.mean(data.data, 1, keepdim=True)用于返回按1维返回data.data的均值保存于mean;第二行mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0]),torch.mul(input, value, out=None)函数用于标量值value乘以input中的所有值,因此mean_broadcast里保存了一组与data维度相同的均值矩阵;第三行diffs = torch.pow(data - Variable(mean_broadcast), exponent),torch.pow(input, exponent, out
pytorch实现GAN代码详解
最新推荐文章于 2024-08-02 16:43:37 发布