可以输入参数改变的模型代码或者特征选择的代码均可
def learning_curve(parameter_value: list, label: np.ndarray, train_data=None, model=None, cv=5):
'''用于快捷绘制模型学习曲线的函数,可用于迭代模型函数或者特征选择函数的参数值,
可自动判断改变的函数参数值和对应模型名称和默认评估指标绘制学习曲线和打印最好训练结果.
1.parameter_value
可输入[input_style,...](步长默认为1,个数默认5)
input_style
有step num str muti四种
step:['step',start,end,step]-np.arange()
num:['num',start,end,num]-np.linespace()
str:['str',可用于迭代的单个或多个str]
muti:['muti',可用于迭代的单个或多个数字]-比如学习率['muti'0.001,0.01,0.1]
2.绘制train_data(训练数据)特征选择参数改变的学习曲线时,请在对应处放置train_data的改变代码,输入具体model
绘制model(机器学习模型)参数改变的学习曲线时,请在对应处放置model的改变代码,输入具体train_data
放置在下面循环代码处!!!
'''
st_time = time.time()
global changed_params
acc_list = []
previous_params = None
#判断输入参数值并作出裁决
if len(parameter_value) == 3 and parameter_value =='step':
parameter_value.append(1)
elif len(parameter_value) == 3 and parameter_value =='num':
parameter_value.append(5)
elif len(parameter_value) ==2:
iter_list = parameter_value[1:]
if len(parameter_value) ==4:
if parameter_value[0] == 'num' :
iter_list = np.linspace(parameter_value[1],parameter_value[2]+1e-100,parameter_value[3])
elif parameter_value[0] == 'step':
st, ed, step = tuple(parameter_value)
iter_list = np.arange(st,ed+1e-100,step)
elif parameter_value[0] == 'muti':
iter_list = parameter_value[1:]
elif parameter_value[0] == 'str':
iter_list = parameter_value[1:]
else:
iter_list = parameter_value[1:]
#循环训练
print('training start...')
for inx,i in enumerate(iter_list):
try:
'下面这一行添加模型训练中参数改变的模型或者特征选择中的参数改变的代码'
model = XGBRegressor(n_estimators=250,subsample=1,learning_rate=i,max_depth=6)# 代码添加处
params = model.get_params()
# 找出参数值变化的参数名字-仅使用于模型训练时
changed_params = []
if previous_params:
for param in params:
if params[param] != previous_params[param]:
changed_params.append(param)
previous_params = params.copy()
acc_list.append(cross_val_score(model, train_data, label, cv=cv).mean())
print(f'{inx+1}th training end')
except:
print('\n\033[31m***An error has occurred***\033[0m\nsomething possible error cause:\n1.未放置参数改变的代码\n2.未指定参数未改变的另外一方,两者默认为NoneType类型\n3.传入代码的参数得到了无效的数字(比如<=0的无效参数)或者数字类型(参数只接受整数或者小数)或数字范围(范围不符合参数规范)\n4.本函数默认迭代i为\033[31mfloat\033[0m类型,传入改变函数代码时可手动加入int(i)防止参数不接受float型而报错')
sys.exit(0)
print('\033[32mtraining finished successfully\033[0m')
# 下面是输入一个数字得到一个评分的情况
if len(acc_list) == 1:
plt.scatter(iter_list[0], acc_list[0])
plt.text(iter_list[0], acc_list[0],
s=f'singal parameter:{iter_list[0]}\nvalue:{np.round(acc_list[0], 4)}')
else:
plt.plot(iter_list, acc_list)
plt.scatter(iter_list, acc_list)
#判断输入模型以获取模型名字和默认评估指标
model_name = str(model).split("(")[0]
if 'Regressor' in model_name:
score = 'R2'
else:
score = 'accuracy'
parameter = changed_params[0] if changed_params != [] else "parameter"
ed_time = time.time()
print(f'training process takes {ed_time - st_time:0.2f}s')
#以上述训练结果和获取的模型名字和评估指标绘图
plt.xlabel(f'{parameter} changed range')
plt.ylabel(f'evaluation criteria:{score}')
plt.title(f'{model_name} model learning curve')
plt.show()
#打印最好参数值以及输出模型评估结果
print(f'best \033[31m{parameter}\033[0m value: {iter_list[np.argmax(acc_list)]}\nbest \033[31m{score}\033[0m score: {np.max(acc_list)}')
使用过程中由任何疑问可以提出来哦,这也是我写的第一篇博客