本文介绍了如何加载各种数据源,以生成可以用于sklearn使用的数据集。主要包括以下几类数据源:
- 预定义的公共数据源
- 内存中的数据
- csv文件
- 任意格式的数据文件
- 稀疏数据格式文件
sklearn使用的数据集一般为numpy ndarray,或者pandas dataframe。
import numpy as np
import pandas as pd
import sklearn
import os
import urllib
import tarfile
1、预定义的公共数据源
更多数据集请见:https://scikitlearn.com.cn/0.21.3/47/
minst数据集
以下示例用于判断图片是否数字5
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1)
X,y = pd.DataFrame.to_numpy(mnist['data']), pd.DataFrame.to_numpy(mnist['target'])
X_train, X_test = X[:6000], X[6000:]
y_train, y_test = y[:6000].astype(np.uint8), y[6000:].astype(np.uint8)
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)
from sklearn.linear_model import SGDClassifier
model = SGDClassifier(loss='hinge')
model.fit(X_train, y_train_5)
print(model.predict([X[0]]))
[ True]
iris数据集
这是一个非常著名的数据集,共有150朵鸢尾花,分别来自三个不同品种(山鸢尾、变色鸢尾和维吉尼亚鸢尾),数据里包含花的萼片以及花瓣的长度和宽度。
from sklearn import datasets
iris = datasets.load_iris()
我们看一下数据集。注意,sklearn的dataset都包含这些keys:
print(iris.keys