Python config.batch_size Code Examples

This article compiles common usage scenarios of config.batch_size in Python, including data loading, model training, and adversarial training, with plenty of code examples to help readers understand and apply it. The examples come from several real projects, such as VQA2.0, Reinforce-Paraphrase-Generation, and TextGAN-PyTorch, and cover data preprocessing, model training, and evaluation.

This article collects typical usage examples of config.batch_size in Python. If you have been wondering how config.batch_size is used, what it does, or what example code looks like, the curated examples below may help. You can also explore further usage examples of the config module that it belongs to.

A total of 29 code examples of config.batch_size are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
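Every example below reads hyper-parameters from a project-level config module (imported as config or cfg). The original projects' config files are not reproduced on this page, so the sketch below is a hypothetical, minimal stand-in that shows how such a module typically exposes batch_size and the other attributes referenced later; all values are illustrative, not taken from the projects.

# config.py -- hypothetical minimal sketch; attribute names mirror the examples below, values are illustrative
batch_size = 64                                    # mini-batch size shared by data loaders and samplers
data_workers = 4                                   # DataLoader worker processes (Example 1)
preprocessed_trainval_path = 'data/trainval.h5'    # preprocessed feature paths (Example 1)
preprocessed_test_path = 'data/test.h5'
vocab_path = 'data/vocab.txt'                      # vocabulary file (Example 2)
vocab_size = 50000                                 # vocabulary size (Examples 2, 4, 5)
train_data_path = 'data/train.bin'                 # training data (Example 2)
log_root = 'log/'                                  # root directory for logs and checkpoints (Example 2)
CUDA = True                                        # move tensors to the GPU (Examples 3-5)
samples_num = 10000                                # number of generated samples used for evaluation (Examples 6, 7)
loss_type = 'standard'                             # GAN loss variant passed to get_losses (Examples 4, 5)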

Example 1: get_loader

Upvotes: 6

# Required import: import config  [as alias]
# Or: from config import batch_size  [as alias]
def get_loader(train=False, val=False, test=False, trainval=False):
    """ Returns a data loader for the desired split """
    split = VQA(
        utils.path_for(train=train, val=val, test=test, trainval=trainval, question=True),
        utils.path_for(train=train, val=val, test=test, trainval=trainval, answer=True),
        config.preprocessed_trainval_path if not test else config.preprocessed_test_path,
        answerable_only=train or trainval,
        dummy_answers=test,
    )
    loader = torch.utils.data.DataLoader(
        split,
        batch_size=config.batch_size,
        shuffle=train or trainval,  # only shuffle the data in training
        pin_memory=True,
        num_workers=config.data_workers,
        collate_fn=collate_fn,
    )
    return loader

Author: KaihuaTang, Project: VQA2.0-Recent-Approachs-2018.pytorch, Lines of code: 20
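The function above returns a standard torch.utils.data.DataLoader, so training code simply iterates over it to get batches of config.batch_size items. The VQA dataset class and utils.path_for are project-specific and not reproduced here; the following self-contained sketch shows the same DataLoader pattern with a stand-in dataset.

import torch
from torch.utils.data import DataLoader, TensorDataset

batch_size = 32  # stands in for config.batch_size
# a dummy dataset of 256 (question, answer) pairs, just to exercise the loader
dataset = TensorDataset(torch.randn(256, 10), torch.randint(0, 2, (256,)))
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

for questions, answers in loader:
    # every batch except possibly the last has exactly batch_size rows
    print(questions.shape, answers.shape)
    break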

Example 2: __init__

Upvotes: 6

# Required import: import config  [as alias]
# Or: from config import batch_size  [as alias]
def __init__(self):
    self.vocab = Vocab(config.vocab_path, config.vocab_size)
    self.batcher = Batcher(config.train_data_path, self.vocab, mode='train',
                           batch_size=config.batch_size, single_pass=False)
    time.sleep(5)

    if not os.path.exists(config.log_root):
        os.mkdir(config.log_root)

    self.model_dir = os.path.join(config.log_root, 'train_model')
    if not os.path.exists(self.model_dir):
        os.mkdir(self.model_dir)

    self.eval_log = os.path.join(config.log_root, 'eval_log')
    if not os.path.exists(self.eval_log):
        os.mkdir(self.eval_log)

    self.summary_writer = tf.compat.v1.summary.FileWriter(self.eval_log)

Author: wyu-du, Project: Reinforce-Paraphrase-Generation, Lines of code: 19

Example 3: adv_train_generator

Upvotes: 6

# Required import: import config  [as alias]
# Or: from config import batch_size  [as alias]
def adv_train_generator(self, g_step):
    """
    The generator is trained with an MLE-like objective.
    """
    total_g_loss = 0
    for step in range(g_step):
        inp, target = GenDataIter.prepare(self.gen.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA)

        # ===Train===
        rewards = self.get_mali_reward(target)
        adv_loss = self.gen.adv_loss(inp, target, rewards)
        self.optimize(self.gen_adv_opt, adv_loss)
        total_g_loss += adv_loss.item()

    # ===Test===
    self.log.info('[ADV-GEN]: g_loss = %.4f, %s' % (total_g_loss, self.cal_metrics(fmt_str=True)))

Author: williamSYSU, Project: TextGAN-PyTorch, Lines of code: 18

Example 4: adv_train_generator

Upvotes: 6

# Required import: import config  [as alias]
# Or: from config import batch_size  [as alias]
def adv_train_generator(self, g_step):
    total_loss = 0
    for step in range(g_step):
        real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float()
        gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
        if cfg.CUDA:
            real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda()

        # ===Train===
        d_out_real = self.dis(real_samples)
        d_out_fake = self.dis(gen_samples)
        g_loss, _ = get_losses(d_out_real, d_out_fake, cfg.loss_type)

        self.optimize(self.gen_adv_opt, g_loss, self.gen)
        total_loss += g_loss.item()

    return total_loss / g_step if g_step != 0 else 0

Author: williamSYSU, Project: TextGAN-PyTorch, Lines of code: 19

Example 5: adv_train_discriminator

Upvotes: 6

# Required import: import config  [as alias]
# Or: from config import batch_size  [as alias]
def adv_train_discriminator(self, d_step):
    total_loss = 0
    for step in range(d_step):
        real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float()
        gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
        if cfg.CUDA:
            real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda()

        # ===Train===
        d_out_real = self.dis(real_samples)
        d_out_fake = self.dis(gen_samples)
        _, d_loss = get_losses(d_out_real, d_out_fake, cfg.loss_type)

        self.optimize(self.dis_opt, d_loss, self.dis)
        total_loss += d_loss.item()

    return total_loss / d_step if d_step != 0 else 0

Author: williamSYSU, Project: TextGAN-PyTorch, Lines of code: 19
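Examples 4 and 5 both turn a batch of integer token IDs into float one-hot tensors before feeding them to the discriminator. A minimal, self-contained sketch of that conversion is shown below; the vocabulary size and batch shape here are illustrative, not taken from the project's config.

import torch
import torch.nn.functional as F

vocab_size = 5000                                     # stands in for cfg.vocab_size
target = torch.randint(0, vocab_size, (64, 20))       # token IDs of shape (batch, seq_len)
real_samples = F.one_hot(target, vocab_size).float()  # shape (batch, seq_len, vocab_size)
print(real_samples.shape)                             # torch.Size([64, 20, 5000])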

Example 6: cal_metrics

Upvotes: 6

# Required import: import config  [as alias]
# Or: from config import batch_size  [as alias]
def cal_metrics(self, fmt_str=False):
    """
    Calculate metrics
    :param fmt_str: if True, return a formatted string for logging
    """
    with torch.no_grad():
        # Prepare data for evaluation
        gen_data = GenDataIter(self.gen.sample(cfg.samples_num, 4 * cfg.batch_size))

        # Reset metrics
        self.nll_oracle.reset(self.oracle, gen_data.loader)
        self.nll_gen.reset(self.gen, self.oracle_data.loader)
        self.nll_div.reset(self.gen, gen_data.loader)

    if fmt_str:
        return ', '.join(['%s = %s' % (metric.get_name(), metric.get_score()) for metric in self.all_metrics])
    else:
        return [metric.get_score() for metric in self.all_metrics]

Author: williamSYSU, Project: TextGAN-PyTorch, Lines of code: 20

Example 7: train_discriminator

Upvotes: 6

# Required import: import config  [as alias]
# Or: from config import batch_size  [as alias]
def train_discriminator(self, d_step, d_epoch, phase='MLE'):
    """
    Train the discriminator on real samples (positive) and samples generated by the generator (negative).
    Samples are drawn d_step times, and the discriminator is trained for d_epoch epochs per step.
    """
    # prepare loader for validation
    global d_loss, train_acc
    for step in range(d_step):
        # prepare loader for training
        pos_samples = self.train_data.target  # do not re-sample the Oracle data
        neg_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size)
        dis_data = DisDa  # (snippet truncated in the source)
