sklearn使用numpy ndarray或者pandas dataframe作为训练数据,调用fit()函数即可完成训练。
本部分我们先介绍一下sklearn的基本用法。
二分类
我们先看一个二分类问题,将mnist分类成数字5和非5两类:
# Binary classification demo: distinguish the digit 5 from all other digits
# on MNIST using a linear model trained with stochastic gradient descent.
import numpy as np  # needed for np.uint8 below

from sklearn.datasets import fetch_openml
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score

# Download MNIST (70,000 flattened 28x28 images -> 784 features).
# as_frame=False returns an ndarray so positional indexing like X[0] works.
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
X, y = mnist['data'], mnist['target']

# Simple split; OpenML labels arrive as strings, so cast them to integers.
X_train, X_test = X[:6000], X[6000:]
y_train, y_test = y[:6000].astype(np.uint8), y[6000:].astype(np.uint8)

# Binarize the target: True for the digit 5, False for everything else.
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)

# loss='hinge' makes SGDClassifier a linear SVM.
sgd_clf = SGDClassifier(loss='hinge')
sgd_clf.fit(X_train, y_train_5)
print(sgd_clf.predict([X[0]]))

# BUG FIX: the original passed the undefined name `sgd_model`;
# the fitted classifier is `sgd_clf`.
cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring='accuracy')
输出:
[False]
array([0.96 , 0.9575, 0.964 ])
回归
我们再看一个回归算法的示例,使用的是housing数据集,预测地区房产的中位数。
DOWNLOAD_ROOT = "https://raw.githubusercontent.com/ageron/handson-ml2/master/"
HOUSING_PATH = os.path.join("datasets", "housing")
HOUSING_URL = DOWNLOAD_ROOT + "datasets/housing/housing.tgz"

def fetch_housing_data(housing_url=HOUSING_URL, housing_path=HOUSING_PATH):
    """Download the housing dataset archive and extract it in place.

    Parameters
    ----------
    housing_url : str
        URL of the housing.tgz archive to download.
    housing_path : str
        Local directory to download into and extract under.
    """
    # exist_ok=True replaces the racy isdir-then-makedirs check.
    os.makedirs(housing_path, exist_ok=True)
    tgz_file = os.path.join(housing_path, 'housing.tgz')
    urllib.request.urlretrieve(housing_url, tgz_file)
    # Context manager guarantees the archive handle is closed even if
    # extraction raises (original used open/close with no try/finally).
    with tarfile.open(tgz_file) as housing_tgz:
        housing_tgz.extractall(path=housing_path)  # extract the files
# fetch_housing_data()  # run once to download the dataset

# Load the CSV extracted by fetch_housing_data().
housing = pd.read_csv(os.path.join(HOUSING_PATH, 'housing.csv'))

# Impute missing bedroom counts with the column median. Assign back
# instead of fillna(..., inplace=True) on a column slice, which is the
# pandas chained-assignment anti-pattern and may not update the frame.
median = housing['total_bedrooms'].median()
housing['total_bedrooms'] = housing['total_bedrooms'].fillna(median)

# Target is the district's median house value; drop it and the one
# categorical column ('ocean_proximity') to get a numeric feature matrix.
housing_label = housing['median_house_value']
housing_feature = housing.drop(['median_house_value', 'ocean_proximity'], axis=1)

from sklearn.linear_model import LinearRegression

lin_reg = LinearRegression()
lin_reg.fit(housing_feature, housing_label)
# BUG FIX: the original printed attributes of the undefined name `model`;
# the fitted estimator is `lin_reg`.
print(lin_reg.intercept_, lin_reg.coef_)
输出:
-3570118.0614940603 [-4.26104026e+04 -4.24754782e+04 1.14445085e+03 -6.62091740e+00
8.11609666e+01 -3.98732002e+01 7.93047225e+01 3.97522237e+04]