Python深度学习:数据加载

Python深度学习:数据加载

1、Dataset基类torch.utils.data.Dataset

以一个案例来描述如何使用Dataset来加载数据。

数据来源:https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection

数据介绍:用于骚扰短信识别的经典数据集,每行开头用ham和spam标识正常短信和骚扰短信。

from torch.utils.data import Dataset

data_path = "./smsspamcollection/SMSSpamCollection"


class MyDataset(Dataset):
    def __init__(self):
        self.lines = open(data_path, encoding="utf8").readlines()

    def __getitem__(self, item):
        cur_line = self.lines[item].strip()
        label = cur_line[:4].strip()
        content = cur_line[4:].strip()
        return label, content

    def __len__(self):
        return len(self.lines)


if __name__ == '__main__':
    myDataset = MyDataset()
    for i in range(len(myDataset)):
        print(f"No.{i + 1}: {myDataset[i]}")
    print(f"total:{len(myDataset)}")

在这里插入图片描述

2、迭代数据集

DataLoader(dataset=myDataset, batch_size=2, shuffle=True)

from torch.utils.data import DataLoader

myDataset = MyDataset()
data_loader = DataLoader(dataset=myDataset, batch_size=2, shuffle=True)

if __name__ == '__main__':
    for i in data_loader:
        print(i)

在这里插入图片描述

3、pytorch中自带的数据集

  • torchvision:图像,在torchvision.datasets
  • torchtext:文本,在torchtext.datasets

torchvision.datasets中的MNIST

import torchvision

dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True)

print(dataset[0])	# (<PIL.Image.Image image mode=L size=28x28 at 0x1A106A67438>, 5)

img = dataset[0][0]
img.show()

在这里插入图片描述

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是一个简单的Python深度学习数据分析示例代码,使用了Keras库和MNIST数据集来训练一个手写数字识别模型: ```python import numpy as np from keras.datasets import mnist from keras.models import Sequential from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D from keras.utils import np_utils # MNIST数据集 (X_train, y_train), (X_test, y_test) = mnist.load_data() # 调整数据格式 X_train = X_train.reshape(X_train.shape[0], 28, 28, 1) X_test = X_test.reshape(X_test.shape[0], 28, 28, 1) X_train = X_train.astype('float32') X_test = X_test.astype('float32') X_train /= 255 X_test /= 255 # 将标签转换为one-hot编码 Y_train = np_utils.to_categorical(y_train, 10) Y_test = np_utils.to_categorical(y_test, 10) # 定义模型 model = Sequential() model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1))) model.add(Conv2D(64, (3, 3), activation='relu')) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Dropout(0.25)) model.add(Flatten()) model.add(Dense(128, activation='relu')) model.add(Dropout(0.5)) model.add(Dense(10, activation='softmax')) # 编译模型 model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) # 训练模型 model.fit(X_train, Y_train, batch_size=32, epochs=10, verbose=1, validation_data=(X_test, Y_test)) # 评估模型 score = model.evaluate(X_test, Y_test, verbose=0) print('Test loss:', score[0]) print('Test accuracy:', score[1]) ``` 这个示例代码使用了卷积神经网络(CNN)来识别手写数字。它包括两个卷积层、一个池化层、两个dropout层和两个全连接层。训练过程使用了Adam优化器和交叉熵损失函数。最后,模型在测试数据集上获得了大约99%的准确率。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值