主要是在类中添加两个函数:
get_params 以及set_params
以下示例是线性回归的代码:
还要注意的是在对数据进行操作时在fit()函数中以及在predict()函数中都要使用深度拷贝的方式
class LinearRegression():
def __init__(self, num_iters=2000, alpha=0.001):
self.num_iters = num_iters
self.alpha = alpha
def compute_cost(self, X, y, w):
#计算代价函数
m = X.shape[0]
J = (1 / (2 * m)) * np.sum((np.dot(X, w) - y) ** 2)
return J
def gradient_descent(self, X, y, w, num_iters, alpha):
#计算梯度
m = X.shape[0]
#记录损失值
J_all = np.zeros((num_iters, 1))
#开始迭代
for i in range(num_iters):
J_all[i] = self.compute_cost(X, y, w) #计算loss 代价函数
#更新w
w = w - ((alpha / m) * np.dot(X.T, (X.dot(w) - y[:, np.newaxis])))
return w, J_all
def fit(self, Xn, yn):
#训练模型
X=np.copy(Xn)