机器学习手撕代码(0)数据
- 后面连续几篇博客把之前没写完的几个经典机器学习算法代码补了补,尽量精简了代码量,欢迎找bug。
- 这第0篇主要是说一下一些数据格式之类的准备,保证代码复制回去就能跑通。
文件树就是下面这个样子,不必须,import数据集文件没问题就行。
数据来源:kaggle葡萄酒预测
datasets文件夹下面放一个dataset.py文件,后面所有的模型都用这一个数据集。
dataset.py
import pandas as pd
import numpy as np
class DataSet:
def __init__(self,path,mode='cla',rad_seed = 2021):
data = pd.read_csv(path).dropna(axis=0, how='any')
data1 = data[:2000]
data2 = data[-2000:]
data = pd.concat([data1, data2]).reset_index().drop(['index'], axis=1)
data = data.replace('white', 0).replace('red', 1)
if mode == 'cla':
self.target_head = 'type'
elif mode == 'reg':
self.target_head = 'residual sugar'
self.data_head = data.columns.to_list()
self.data_head.remove(self.target_head)
self.target = data[self.target_head].to_numpy()
self.data = data[self.data_head].to_numpy()
if rad_seed is not False:
np.random.seed(rad_seed)
permutation = list(np.random.permutation(len(self.data)))
self.data = self.data[permutation]
self.target = self.target[permutation]
def get_data(self):
return self.data,self.target,self.target_head,self.data_head
后面文章的模型原理就不详述了,看我的不如看书,分享一下自己手撕的代码,做了最大简化,个人感觉简洁一些。