# Collapsed page preview of the class below: class Node(object): def __init__(self, X, y, split=Split.null()): self.X = X; self.y = y; self.split = split; self.left = None; self.right = None; self.outcome = np.mean(self.y) if self.y.any() else None ...
class Node(object):
    """One node of a binary decision tree whose leaves predict a mean target.

    A node stores the samples that reached it (``X``, ``y``), the split rule
    used to route samples to its children, and ``outcome``, the mean of
    ``y``.  A node without children is a leaf.

    ``Split`` and ``split`` are defined elsewhere in the project.  This class
    only relies on ``Split.null()``, on split objects exposing
    ``split_attribute``, ``split_value`` and ``indexes(X)``, and on
    ``split(X, y)`` returning such an object.
    """

    def __init__(self, X, y, split=None):
        """Create a childless node holding ``X`` and ``y``.

        Parameters
        ----------
        X : samples held by this node (indexed as ``X[mask]`` in ``grow``).
        y : targets aligned with ``X``.
        split : split rule for this node.  ``None`` (the default) means
            ``Split.null()``; it is resolved per instance so that nodes never
            share one default object created at definition time.
        """
        self.X = X
        self.y = y
        self.split = Split.null() if split is None else split
        self.left = None
        self.right = None
        self.outcome = self._mean_or_none(y)

    @staticmethod
    def _mean_or_none(y):
        """Return the mean of ``y``, or None when ``y`` is None or empty."""
        # Test for emptiness by size: ``y.any()`` is also False for an
        # all-zero target vector, which would wrongly yield None.
        if y is None or np.size(y) == 0:
            return None
        return np.mean(y)

    def traverse(self):
        """Return all nodes of this subtree in breadth-first order.

        The node itself comes first; at each node the left child is visited
        before the right child.
        """
        from collections import deque

        visited = []
        # The queue of nodes still to visit starts out holding only this
        # node, i.e. the root of the subtree being traversed.
        pending = deque([self])
        while pending:
            node = pending.popleft()
            visited.append(node)
            for child in (node.left, node.right):
                if child is not None:
                    pending.append(child)
        return visited

    def predict(self, X):
        """Predict one value per row of ``X``.

        Each row is routed from this node down to a leaf: it goes left when
        ``row[split_attribute] < split_value`` and right otherwise.  The
        leaf's ``outcome`` is the prediction.

        Returns
        -------
        numpy.ndarray of length ``X.shape[0]``.
        """
        X = np.asarray(X)
        result = np.zeros(X.shape[0])
        for i, row in enumerate(X):
            node = self
            while not node.is_leaf:
                rule = node.split
                if row[rule.split_attribute] < rule.split_value:
                    node = node.left
                else:
                    node = node.right
            result[i] = node.outcome
        return result

    def fit(self, X=None, y=None, max_tree_size=10):
        """Grow the tree breadth-first starting from this node.

        Parameters
        ----------
        X, y : optional replacement data; when omitted, the data given at
            construction time is used.
        max_tree_size : growing stops once the node count reaches this value.
            The count is checked before each node is grown, so the final
            tree may exceed the limit by the children of the last grown node.

        Returns
        -------
        self
        """
        from collections import deque

        # Compare against None explicitly: ``X or self.X`` raises ValueError
        # for numpy arrays with more than one element.
        if X is not None:
            self.X = X
        if y is not None:
            self.y = y
            # New targets invalidate the mean computed at construction.
            self.outcome = self._mean_or_none(self.y)

        pending = deque([self])
        n_nodes = 1
        while pending and n_nodes < max_tree_size:
            node = pending.popleft()
            node.grow()
            for child in (node.left, node.right):
                if child is not None:
                    n_nodes += 1
                    pending.append(child)
        return self

    def grow(self, verbose=0):
        """Find a split for this node and create its two children.

        When ``split(self.X, self.y)`` equals ``Split.null()`` the node is
        left without new children.  ``verbose`` is accepted for interface
        compatibility and is not used.
        """
        if self.outcome is None:
            self.outcome = self._mean_or_none(self.y)
        self.split = split(self.X, self.y)
        if self.split == Split.null():
            return
        # Presumably a boolean mask over the rows of X, since it is negated
        # with ``~`` for the right child -- confirm against Split.indexes.
        indexes = self.split.indexes(self.X)
        self.left = Node(self.X[indexes], self.y[indexes])
        self.right = Node(self.X[~indexes], self.y[~indexes])

    @property
    def is_leaf(self):
        """True when the node has neither a left nor a right child."""
        return self.left is None and self.right is None

    def __str__(self):
        """Return a one-line summary: data shape, split rule and outcome."""
        # getattr covers both ``X is None`` and inputs without a ``shape``.
        shape = getattr(self.X, 'shape', None)
        return 'Node(X.shape=%s, split_attribute=%s, split_value=%s, outcome=%s)' % (
            shape,
            self.split.split_attribute,
            self.split.split_value,
            self.outcome)

    @property
    def pstr(self):
        """Summaries of all nodes in breadth-first order, one per line."""
        return '\n'.join([str(n) for n in self.traverse()])
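# Question: what does `nodes = [self]` mean?  It creates a list whose only
# element is the current node; the loop then uses it as the queue of nodes
# still to be visited.
# [expand]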