loss下降auc下降_pytorch基础3:梯度下降算法

7b56cf3456184cc4534d2cec1324be4c.png

pytorch基础3:梯度下降算法

import matplotlib.pyplot as plt
import numpy as np

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1.0
a = []
b = []
# 随机猜测权重为1.0,而后开始贪婪算法

def forward(x):
    return x * w
# 定义线性函数

def cost(xs, ys):
    cost = 0
    for x, y in zip(xs, ys):
        y_pred = forward(x)
        cost += (y_pred - y) ** 2
    return cost / len(xs)
# 定义损失函数,从xs和ys中取出对应的值,计算其损失值

def gradient(xs, ys):
    grad = 0
    for x, y in zip(xs, ys):
        grad += 2 * x * (x * w - y)
    return grad / len(xs)
# 计算梯度,使其之后的优化梯度为损失函数的导数值,沿着梯度最大方向

print('Predict(before training)', 4, forward(4))
for epoch in range(100):
    cost_val = cost(x_data, y_data)
    grad_val = gradient(x_data, y_data)
    w -= 0.01 * grad_val
    print('Epoch', epoch, 'w=', w, 'loss=', cost_val)
    a.append(epoch)
    b.append(cost_val)
print('Predict (after training)', 4, forward(4))
# 开始训练,计算损失值和梯度值,每次步长为0.01,开始训练并存储训练次数和损失值

print(a)
print(b)
plt.plot(a, b)
plt.ylabel('cost')
plt.xlabel('Epoch')
plt.xlim([0, 100])
plt.ylim([0, 4])
plt.show()
#进行matplot绘图

结果展示:

Predict(before training) 4 4.0
Epoch 0 w= 1.0933333333333333 loss= 4.666666666666667
Epoch 1 w= 1.1779555555555554 loss= 3.8362074074074086
Epoch 2 w= 1.2546797037037036 loss= 3.1535329869958857
Epoch 3 w= 1.3242429313580246 loss= 2.592344272332262
Epoch 4 w= 1.3873135910979424 loss= 2.1310222071581117
Epoch 5 w= 1.4444976559288012 loss= 1.7517949663820642
Epoch 6 w= 1.4963445413754464 loss= 1.440053319920117
Epoch 7 w= 1.5433523841804047 loss= 1.1837878313441108
Epoch 8 w= 1.5859728283235668 loss= 0.9731262101573632
Epoch 9 w= 1.6246153643467005 loss= 0.7999529948031382
Epoch 10 w= 1.659651263674342 loss= 0.6575969151946154
Epoch 11 w= 1.6914171457314033 loss= 0.5405738908195378
Epoch 12 w= 1.7202182121298057 loss= 0.44437576375991855
Epoch 13 w= 1.7463311789976905 loss= 0.365296627844598
Epoch 14 w= 1.7700069356245727 loss= 0.3002900634939416
Epoch 15 w= 1.7914729549662791 loss= 0.2468517784170642
Epoch 16 w= 1.8109354791694263 loss= 0.2029231330489788
Epoch 17 w= 1.8285815011136133 loss= 0.16681183417217407
Epoch 18 w= 1.8445805610096762 loss= 0.1371267415488235
Epoch 19 w= 1.8590863753154396 loss= 0.11272427607497944
Epoch 20 w= 1.872238313619332 loss= 0.09266436490145864
Epoch 21 w= 1.8841627376815275 loss= 0.07617422636521683
Epoch 22 w= 1.8949742154979183 loss= 0.06261859959338009
Epoch 23 w= 1.904776622051446 loss= 0.051475271914629306
Epoch 24 w= 1.9136641373266443 loss= 0.04231496130368814
Epoch 25 w= 1.9217221511761575 loss= 0.03478477885657844
Epoch 26 w= 1.9290280837330496 loss= 0.02859463421027894
Epoch 27 w= 1.9356521292512983 loss= 0.023506060193480772
Epoch 28 w= 1.9416579305211772 loss= 0.01932302619282764
Epoch 29 w= 1.9471031903392007 loss= 0.015884386331668398
Epoch 30 w= 1.952040225907542 loss= 0.01305767153735723
Epoch 31 w= 1.9565164714895047 loss= 0.010733986344664803
Epoch 32 w= 1.9605749341504843 loss= 0.008823813841374291
Epoch 33 w= 1.9642546069631057 loss= 0.007253567147113681
Epoch 34 w= 1.9675908436465492 loss= 0.005962754575689583
Epoch 35 w= 1.970615698239538 loss= 0.004901649272531298
Epoch 36 w= 1.9733582330705144 loss= 0.004029373553099482
Epoch 37 w= 1.975844797983933 loss= 0.0033123241439168096
Epoch 38 w= 1.9780992835054327 loss= 0.0027228776607060357
Epoch 39 w= 1.980143350378259 loss= 0.002238326453885249
Epoch 40 w= 1.9819966376762883 loss= 0.001840003826269386
Epoch 41 w= 1.983676951493168 loss= 0.0015125649231412608
Epoch 42 w= 1.9852004360204722 loss= 0.0012433955919298103
Epoch 43 w= 1.9865817286585614 loss= 0.0010221264385926248
Epoch 44 w= 1.987834100650429 loss= 0.0008402333603648631
Epoch 45 w= 1.9889695845897222 loss= 0.0006907091659248264
Epoch 46 w= 1.9899990900280147 loss= 0.0005677936325753796
Epoch 47 w= 1.9909325082920666 loss= 0.0004667516012495216
Epoch 48 w= 1.9917788075181404 loss= 0.000383690560742734
Epoch 49 w= 1.9925461188164473 loss= 0.00031541069384432885
Epoch 50 w= 1.9932418143935788 loss= 0.0002592816085930997
Epoch 51 w= 1.9938725783835114 loss= 0.0002131410058905752
Epoch 52 w= 1.994444471067717 loss= 0.00017521137977565514
Epoch 53 w= 1.9949629871013967 loss= 0.0001440315413480261
Epoch 54 w= 1.9954331083052663 loss= 0.0001184003283899171
Epoch 55 w= 1.9958593515301082 loss= 9.733033217332803e-05
Epoch 56 w= 1.9962458120539648 loss= 8.000985883901657e-05
Epoch 57 w= 1.9965962029289281 loss= 6.57716599593935e-05
Epoch 58 w= 1.9969138906555615 loss= 5.406722767150764e-05
Epoch 59 w= 1.997201927527709 loss= 4.444566413387458e-05
Epoch 60 w= 1.9974630809584561 loss= 3.65363112808981e-05
Epoch 61 w= 1.9976998600690001 loss= 3.0034471708953996e-05
Epoch 62 w= 1.9979145397958935 loss= 2.4689670610172655e-05
Epoch 63 w= 1.9981091827482769 loss= 2.0296006560253656e-05
Epoch 64 w= 1.9982856590251044 loss= 1.6684219437262796e-05
Epoch 65 w= 1.9984456641827613 loss= 1.3715169898293847e-05
Epoch 66 w= 1.9985907355257035 loss= 1.1274479219506377e-05
Epoch 67 w= 1.9987222668766378 loss= 9.268123006398985e-06
Epoch 68 w= 1.9988415219681517 loss= 7.61880902783969e-06
Epoch 69 w= 1.9989496465844576 loss= 6.262999634617916e-06
Epoch 70 w= 1.9990476795699081 loss= 5.1484640551938914e-06
Epoch 71 w= 1.9991365628100501 loss= 4.232266273994499e-06
Epoch 72 w= 1.999217150281112 loss= 3.479110977946351e-06
Epoch 73 w= 1.999290216254875 loss= 2.859983851026929e-06
Epoch 74 w= 1.9993564627377531 loss= 2.3510338359374262e-06
Epoch 75 w= 1.9994165262155628 loss= 1.932654303533636e-06
Epoch 76 w= 1.999470983768777 loss= 1.5887277332523938e-06
Epoch 77 w= 1.9995203586170245 loss= 1.3060048068548734e-06
Epoch 78 w= 1.9995651251461022 loss= 1.0735939958924364e-06
Epoch 79 w= 1.9996057134657994 loss= 8.825419799121559e-07
Epoch 80 w= 1.9996425135423248 loss= 7.254887315754342e-07
Epoch 81 w= 1.999675878945041 loss= 5.963839812987369e-07
Epoch 82 w= 1.999706130243504 loss= 4.902541385825727e-07
Epoch 83 w= 1.9997335580874436 loss= 4.0301069098738336e-07
Epoch 84 w= 1.9997584259992822 loss= 3.312926995781724e-07
Epoch 85 w= 1.9997809729060159 loss= 2.723373231729343e-07
Epoch 86 w= 1.9998014154347876 loss= 2.2387338352920307e-07
Epoch 87 w= 1.9998199499942075 loss= 1.8403387118941732e-07
Epoch 88 w= 1.9998367546614149 loss= 1.5128402140063082e-07
Epoch 89 w= 1.9998519908930161 loss= 1.2436218932547864e-07
Epoch 90 w= 1.9998658050763347 loss= 1.0223124683409346e-07
Epoch 91 w= 1.9998783299358769 loss= 8.403862850836479e-08
Epoch 92 w= 1.9998896858085284 loss= 6.908348768398496e-08
Epoch 93 w= 1.9998999817997325 loss= 5.678969725349543e-08
Epoch 94 w= 1.9999093168317574 loss= 4.66836551287917e-08
Epoch 95 w= 1.9999177805941268 loss= 3.8376039345125727e-08
Epoch 96 w= 1.9999254544053418 loss= 3.154680994333735e-08
Epoch 97 w= 1.9999324119941766 loss= 2.593287985380858e-08
Epoch 98 w= 1.9999387202080534 loss= 2.131797981222471e-08
Epoch 99 w= 1.9999444396553017 loss= 1.752432687141379e-08
Predict (after training) 4 7.999777758621207

4e7fcdca7e2278ac8cbbb373c3874ea8.png
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
这段代码主要是绘制用于比较不同分类器在二分类问题中的表现,通过绘制 ROC 曲线来比较每个分类器的性能。下面逐行解释代码: ```python plt.figure(figsize=(10, 8)) ``` 创建一个大小为 10 * 8 英寸的图像对象。 ```python plt.plot([0, 1], [0, 1], 'k--') ``` 绘制一条从 (0,0) 到 (1,1) 的直线,颜色为黑色,线型为虚线,用于表示随机猜测的分类器的表现。 ```python for name, model, color in zip(['KNN', 'LightGBM', 'XGBoost', 'Random Forest'], [knn_model, lgb_model, xgb_model, rf_model], ['#0e72cc', '#6ca30f', '#f59311', '#fa4343']): y_pred_prob = model.predict_proba(X_test)[:, 1] fpr, tpr, _ = roc_curve(y_test, y_pred_prob) auc_score = roc_auc_score(y_test, y_pred_prob) plt.plot(fpr, tpr, label=f'{name} (AUC={auc_score:.4f})', color=color) ``` 使用 `zip` 函数将分类器的名称、模型和颜色进行打包,进行循环遍历。在每次循环中,使用当前模型在测试集上进行预测,得到预测概率值 `y_pred_prob`,然后使用 `roc_curve` 函数计算得到真正率 `tpr` 和假正率 `fpr`,再使用 `roc_auc_score` 函数计算 AUC 值。最后,使用 `plt.plot` 函数绘制当前分类器的 ROC 曲线,并在图例中添加分类器名称和对应的 AUC 值。 ```python plt.xlabel('False positive rate') plt.ylabel('True positive rate') plt.title('ROC curve') plt.legend() plt.show() ``` 设置坐标轴的标签和标题,并显示图例和绘制的 ROC 曲线。 ```python print('KNN_AUC score:', auc_score_knn) print('LGB_AUC score:', auc_score_lgb) print('XGB_AUC score:', auc_score_xgb) print('RF_AUC score:', auc_score_rf) ``` 打印每个分类器的 AUC 值。但是这段代码中没有定义 `auc_score_knn`、`auc_score_lgb`、`auc_score_xgb` 和 `auc_score_rf`,所以这些变量可能是在其他地方定义的。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值