pytorch
文章平均质量分 64
Lntano__y
在这里记录自己的学习历程,并分享给有需要的同学!
展开
-
from torch.utils.data import Dataset详解
torch.utils.data.Dataset 是 PyTorch 数据加载库中的一个重要类,用于定义自定义数据集。通过继承 Dataset 类,可以创建自己的数据集类,并实现数据的加载和处理逻辑。通过继承和实现 torch.utils.data.Dataset 类,可以灵活地创建自定义数据集,并与 DataLoader 结合使用,实现高效的数据加载和处理。(self, idx):支持索引操作,返回指定索引的样本。继承 Dataset 类,并实现。方法中加载文件数据。2.1 导入相关模块。原创 2024-07-19 16:38:23 · 581 阅读 · 0 评论 -
torch.no_grad()详解
torch.no_grad() 是一个用于禁用梯度计算的上下文管理器,适用于模型评估、推理等不需要梯度计算的场景。在使用 torch.no_grad() 时,通常还会将模型设置为评估模式(model.eval()),以确保某些层(如 dropout 和 batch normalization)在推理时的行为与训练时不同。torch.no_grad() 可以嵌套使用,内层的 torch.no_grad() 仍然会禁用梯度计算。进入 torch.no_grad() 上下文,临时禁用梯度计算。原创 2024-07-19 10:46:01 · 1382 阅读 · 0 评论 -
from torch.utils.data import TensorDataset 详解
这里,x 是一个形状为 (100, 10) 的张量,表示 100 个样本,每个样本有 10 个特征。y 是一个形状为 (100,) 的张量,表示 100 个样本的标签,取值在 0 和 1 之间(二分类问题)。在这个实例中,我们创建了一个包含 100 个样本的数据集,每个样本有 10 个特征,并将其划分为批次,每个批次包含 20 个样本。这里,我们创建了一个数据加载器 train_dataloader,每个批次包含 20 个样本,并且在每个 epoch 结束后会打乱数据(shuffle=True)。原创 2024-07-19 10:01:38 · 479 阅读 · 0 评论