solver = _check_solver(self.solver, self.penalty, self.dual)
if not isinstance(self.C, numbers.Number) or self.C < 0:
raise ValueError("Penalty term must be positive; got (C=%r)"
% self.C)
if self.penalty == 'elasticnet':
if (not isinstance(self.l1_ratio, numbers.Number) or
self.l1_ratio < 0 or self.l1_ratio > 1):
raise ValueError("l1_ratio must be between 0 and 1;"
" got (l1_ratio=%r)" % self.l1_ratio)
elif self.l1_ratio is not None:
warnings.warn("l1_ratio parameter is only used when penalty is "
"'elasticnet'. Got "
"(penalty={})".format(self.penalty))
if self.penalty == 'none':
if self.C != 1.0: # default values
warnings.warn(
"Setting penalty='none' will ignore the C and l1_ratio "
"parameters"
)
# Note that check for l1_ratio is done right above
C_ = np.inf
penalty = 'l2'
else:
C_ = self.C
penalty = self.penalty
if not isinstance(self.max_iter, numbers.Number) or self.max_iter < 0:
raise ValueError("Maximum number of iteration must be positive;"
" got (max_iter=%r)" % self.max_iter)
if not isinstance(self.tol, numbers.Number) or self.tol < 0:
raise ValueError("Tolerance for stopping criteria must be "
"positive; got (tol=%r)" % self.tol)
if solver == 'lbfgs':
_dtype = np.float64
else:
_dtype = [np.float64, np.float32]
X, y = self._validate_data(X, y, accept_sparse='csr', dtype=_dtype,
order="C",
accept_large_sparse=solver != 'liblinear')
check_classification_targets(y)
self.classes_ = np.unique(y)
multi_class = _check_multi_class(self.multi_class, solver,
len(self.classes_))
-
_check_solver
函数:该函数用于检查和确定所选择的求解器(solver)。求解器是用于解决优化问题的算法,根据参数的不同可以选择不同的求解器。 -
参数合法性检查:代码中使用了一系列条件判断来验证参数的合法性。例如,检查惩罚项(penalty)是否有效、C 值是否为正数、l1_ratio 是否在 0 到 1 之间等。
-
警告提示:当参数设置不符合预期时,通过
warnings.warn
方法输出警告信息,提醒用户当前参数可能被忽略或仅在特定情况下生效。 -
数据验证和处理:通过调用
_validate_data
方法对输入数据进行验证和处理,确保数据类型正确、接受稀疏矩阵(sparse matrix)、对数据排序,同时根据求解器的不同决定是否接受大规模稀疏矩阵。 -
目标类别处理:通过调用
check_classification_targets
方法对目标变量 y 进行验证,确保它是一个有效的分类目标。 -
设置多类别问题参数:通过调用
_check_multi_class
方法对多类别问题中的参数进行验证和设置。