# -*- coding: utf-8 -*-
import csv
from random import seed
from random import randrange
from math import sqrt
def loadCSV(filename): # 加载数据,一行行的存入列表
dataSet = []
with open(filename, 'r') as file:
csvReader = csv.reader(file)
for line in csvReader:
dataSet.append(line)
return dataSet
# 除了标签列,其他列都转换为float类型
def column_to_float(dataSet):
featLen = len(dataSet[0]) - 1
for data in dataSet:
for column in range(featLen):
data[column] = float(data[column].strip())
# 将数据集随机分成N块,方便交叉验证,其中一块是测试集,其他四块是训练集
def spiltDataSet(dataSet, n_folds):
fold_size = int(len(dataSet) / n_folds)
dataSet_copy = list(dataSet)
dataSet_spilt = []
for i in range(n_folds):
fold = []
while len(fold) < fold_size: # 这里不能用if,if只是在第一次判断时起作用,while执行循环,直到条件不成立
index = randrange(len(dataSet_copy))
fold.append(dataSet_copy.pop(index)) # pop() 函数用于移除列表中的一个元素(默认最后一个元素),并且返回该元素的值。
dataSet_spilt.append(fold)
return dataSet_spilt
# 构造数据子集
def get_subsample(dataSet, ratio):
subdataSet = []
lenSubdata = round(len(dataSet) * ratio) # 返回浮点数
while len(subdataSet) < lenSubdata:
index = randrange(len(dataSet) - 1)
subdataSet.append(dataSet[index])
# print len(subdataSet)
return subdataSet
# 分割数据集
def data_spilt(dataSe
co-forest算法
最新推荐文章于 2024-11-09 18:29:17 发布