PyTorch Dataset and DataLoader

When building datasets in PyTorch, using torch.utils.data.Dataset and torch.utils.data.DataLoader greatly improves efficiency and covers most common needs.
This post mainly records my own understanding of torch.utils.data.Dataset and torch.utils.data.DataLoader.

1. torch.utils.data.Dataset:
PyTorch's original definition of Dataset is fairly simple:
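(The original screenshot is missing; the following is a rough paraphrase of the base class, whose exact source varies across PyTorch versions. The torch import is omitted here for brevity.)

```python
class Dataset:
    # the base class defines no __init__ and no __len__;
    # subclasses are expected to provide them
    def __getitem__(self, index):
        # subclasses must override this method
        raise NotImplementedError
```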

To define our own dataset, we subclass this class, and after subclassing we must override the two methods __len__(self) and __getitem__(self, index), for example:
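(The original example screenshot is missing; below is a minimal sketch with made-up data. In real code the class would subclass torch.utils.data.Dataset and typically return tensors.)

```python
class MyDataset:  # in practice: class MyDataset(torch.utils.data.Dataset)
    def __init__(self, samples, labels):
        # the attributes stored here are the properties the dataset carries
        self.samples = samples
        self.labels = labels

    def __len__(self):
        # number of samples in the dataset
        return len(self.samples)

    def __getitem__(self, index):
        # one sample; this is what DataLoader later collects into batches
        return self.samples[index], self.labels[index]

ds = MyDataset([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], [0, 1, 0])
print(len(ds))   # 3
print(ds[1])     # ([0.3, 0.4], 1)
```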

In my view, the most important parts are the three methods below, which any custom dataset must include:
__len__(): fairly easy to understand; it returns the number of samples;
__init__(): the constructor; the attributes set here are the properties your dataset carries;
__getitem__(): to me the most important method of the Dataset class, because it determines what torch.utils.data.DataLoader later returns. At first I could not tell when this method gets called, since nothing inside the class calls it, and unlike __init__() it is not invoked automatically. It turns out it is called later, when the dataset is consumed through torch.utils.data.DataLoader; more on that below.
Other methods can be defined as needed.

 

2. torch.utils.data.DataLoader
The definition of DataLoader is more involved and has more parameters.
When using the DataLoader:
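(The original usage screenshot is missing; a self-contained sketch of typical usage, with a toy dataset made up for illustration, might look like this:)

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):
    # toy dataset: sample i is the pair (i, i*i)
    def __len__(self):
        return 10

    def __getitem__(self, index):
        return torch.tensor(index), torch.tensor(index * index)

loader = DataLoader(SquaresDataset(), batch_size=4, shuffle=False)
for xs, ys in loader:        # each iteration yields one batch
    print(xs.tolist(), ys.tolist())
```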


When execution reaches the for loop over the DataLoader, Dataset's __getitem__() is entered. The DataLoader returns an iterator whose elements are batches of batch_size samples each, so we use in to obtain the individual batches.
While reading data, if the Dataset is poorly written, producing this iterator can be very time-consuming and will slow down training.
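Conceptually, each batch the DataLoader yields is built by calling __getitem__() once per sample index and grouping the results. A simplified pure-Python imitation (ignoring shuffling, collation into tensors, and worker processes) can make this concrete:

```python
def simple_loader(dataset, batch_size):
    # mimic what DataLoader does for a map-style dataset:
    # call __getitem__ for each index and group results into batches
    batch = []
    for index in range(len(dataset)):
        batch.append(dataset[index])   # -> dataset.__getitem__(index)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:                          # last, possibly smaller, batch
        yield batch

class ToyDataset:
    def __len__(self):
        return 10
    def __getitem__(self, index):
        return index

for batch in simple_loader(ToyDataset(), batch_size=4):
    print(batch)   # [0, 1, 2, 3] then [4, 5, 6, 7] then [8, 9]
```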

ACGAN stands for Auxiliary Classifier Generative Adversarial Network. It is a type of generative model that uses deep neural networks to generate new data samples that mimic a given dataset. An ACGAN also has an auxiliary classifier that helps generate samples with specific attributes or labels. PyTorch provides a simple and efficient way to build ACGAN models. To build an ACGAN in PyTorch, you would typically:

1. Define the generator and discriminator networks using PyTorch's nn.Module class.
2. Implement the loss functions for the generator and discriminator networks.
3. Train the ACGAN model using PyTorch's built-in optimizers and a training loop.

Here is a sketch of PyTorch code for building an ACGAN. The network architectures are left as placeholders, and num_epochs, batch_size, latent_dim, num_classes, and data_loader are assumed to be defined elsewhere. Note that the discriminator returns both a real/fake score and class logits, so the auxiliary classification loss can compare its label predictions against the true (or requested) labels:

```
import torch
import torch.nn as nn
import torch.optim as optim

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # define the generator network architecture here

    def forward(self, z, y):
        # generate new samples from noise vector z and label y
        ...

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # define the discriminator network architecture here

    def forward(self, x):
        # return (validity, class_logits): a real/fake score plus the
        # auxiliary classifier's label prediction for input x
        ...

# define loss functions
adversarial_loss = nn.BCELoss()
auxiliary_loss = nn.CrossEntropyLoss()

# initialize generator and discriminator networks
generator = Generator().cuda()
discriminator = Discriminator().cuda()

# define optimizer for each network
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# train ACGAN model
for epoch in range(num_epochs):
    for i, (real_images, real_labels) in enumerate(data_loader):
        real_images, real_labels = real_images.cuda(), real_labels.cuda()

        # ---- train discriminator ----
        optimizer_D.zero_grad()

        # real images
        real_validity, real_cls = discriminator(real_images)
        d_real_loss = adversarial_loss(real_validity, torch.ones_like(real_validity)) \
                      + auxiliary_loss(real_cls, real_labels)

        # fake images (detached so generator gradients are not computed here)
        z = torch.randn(batch_size, latent_dim).cuda()
        fake_labels = torch.randint(0, num_classes, (batch_size,)).cuda()
        fake_images = generator(z, fake_labels).detach()
        fake_validity, fake_cls = discriminator(fake_images)
        d_fake_loss = adversarial_loss(fake_validity, torch.zeros_like(fake_validity)) \
                      + auxiliary_loss(fake_cls, fake_labels)

        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        optimizer_D.step()

        # ---- train generator ----
        optimizer_G.zero_grad()
        gen_images = generator(z, fake_labels)
        gen_validity, gen_cls = discriminator(gen_images)
        g_loss = adversarial_loss(gen_validity, torch.ones_like(gen_validity)) \
                 + auxiliary_loss(gen_cls, fake_labels)
        g_loss.backward()
        optimizer_G.step()

        # print training progress
        print("[Epoch %d/%d] [Batch %d/%d] D_loss: %.4f G_loss: %.4f"
              % (epoch + 1, num_epochs, i + 1, len(data_loader),
                 d_loss.item(), g_loss.item()))
```

In the above code, we define a Generator and a Discriminator network, the loss functions, and the optimizers. We then train the ACGAN by alternating between updating the discriminator and the generator on batches of real and fake samples. The generator is trained to produce samples that fool the discriminator, while also producing samples whose auxiliary-classifier prediction matches the requested label.