L06_逻辑斯蒂回归

逻辑斯蒂回归用于解决二分类问题:模型输出样本属于正类的概率。

0.导入需要的包

import torch
import torch.nn.functional as F

1.准备数据

# Training set: three samples with a single input feature each, shape (3, 1).
x_data = torch.tensor([[1.0], [2.0], [3.0]])
# Binary targets (0 or 1), one per sample, stored as floats with shape (3, 1).
y_data = torch.tensor([[0.0], [0.0], [1.0]])

2.定义模型

class LogisticRegressionModel(torch.nn.Module):
    """Logistic regression on one input feature: a linear layer followed by a sigmoid."""

    def __init__(self):
        super().__init__()
        # One input feature mapped to one output value (a weight and a bias).
        self.linear = torch.nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Return sigmoid(linear(x)), i.e. a value in (0, 1) for each input row.

        Args:
            x: Tensor whose last dimension has size 1 (required by the linear layer).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last dimension of 1.
        """
        logits = self.linear(x)
        return torch.sigmoid(logits)


model = LogisticRegressionModel()

3.定义损失和优化器

# Loss: binary cross-entropy between the predicted probabilities and the 0/1
# targets, averaged over all elements (reduction='mean').
criterion = torch.nn.BCELoss(reduction='mean')
# Optimizer: stochastic gradient descent over every model parameter, step size 0.3.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.3)

4.训练

# Full-batch training: each epoch feeds the whole dataset through the model
# once and applies a single optimizer step.
for epoch in range(500):
    # Clear gradients accumulated by the previous backward pass.
    optimizer.zero_grad()

    # Forward pass and loss; the printed loss is measured before this epoch's update.
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print(f'epoch:{epoch},loss:{loss.item()}')

    # Backward pass, then update the parameters.
    loss.backward()
    optimizer.step()

# Evaluate the trained model at x = 4 and print the resulting probability.
x_test = torch.Tensor([4])
y_test = model(x_test)
print(f'预测:x=4,y={y_test.item()}')
epoch:0,loss:0.6622183918952942
epoch:1,loss:0.5972882509231567
epoch:2,loss:0.5675674080848694
epoch:3,loss:0.5535971522331238
epoch:4,loss:0.5457622408866882
epoch:5,loss:0.5401986241340637
epoch:6,loss:0.5354849696159363
epoch:7,loss:0.531114935874939
epoch:8,loss:0.526910126209259
epoch:9,loss:0.5228070616722107
epoch:10,loss:0.5187824368476868
epoch:11,loss:0.5148274302482605
epoch:12,loss:0.5109379291534424
epoch:13,loss:0.507111668586731
epoch:14,loss:0.5033470988273621
epoch:15,loss:0.49964264035224915
epoch:16,loss:0.4959971010684967
epoch:17,loss:0.49240922927856445
epoch:18,loss:0.48887792229652405
epoch:19,loss:0.4854017496109009
epoch:20,loss:0.48197975754737854
epoch:21,loss:0.4786108434200287
epoch:22,loss:0.47529372572898865
epoch:23,loss:0.472027450799942
epoch:24,loss:0.46881103515625
epoch:25,loss:0.4656432569026947
epoch:26,loss:0.4625232219696045
epoch:27,loss:0.45944997668266296
epoch:28,loss:0.4564225375652313
epoch:29,loss:0.4534399211406708
epoch:30,loss:0.4505012035369873
epoch:31,loss:0.44760558009147644
epoch:32,loss:0.4447520971298218
epoch:33,loss:0.44193992018699646
epoch:34,loss:0.4391683042049408
epoch:35,loss:0.4364362955093384
epoch:36,loss:0.4337431490421295
epoch:37,loss:0.4310881197452545
epoch:38,loss:0.4284704029560089
epoch:39,loss:0.425889253616333
epoch:40,loss:0.42334404587745667
epoch:41,loss:0.420833945274353
epoch:42,loss:0.41835835576057434
epoch:43,loss:0.41591647267341614
epoch:44,loss:0.4135077893733978
epoch:45,loss:0.41113153100013733
epoch:46,loss:0.40878722071647644
epoch:47,loss:0.4064740240573883
epoch:48,loss:0.4041915833950043
epoch:49,loss:0.40193918347358704
epoch:50,loss:0.3997161388397217
epoch:51,loss:0.39752212166786194
epoch:52,loss:0.39535650610923767
epoch:53,loss:0.3932185471057892
epoch:54,loss:0.39110803604125977
epoch:55,loss:0.3890243470668793
epoch:56,loss:0.3869668245315552
epoch:57,loss:0.38493525981903076
epoch:58,loss:0.38292887806892395
epoch:59,loss:0.380947470664978
epoch:60,loss:0.37899050116539
epoch:61,loss:0.37705734372138977
epoch:62,loss:0.37514781951904297
epoch:63,loss:0.37326136231422424
epoch:64,loss:0.3713975250720978
epoch:65,loss:0.3695560395717621
epoch:66,loss:0.36773642897605896
epoch:67,loss:0.365938276052475
epoch:68,loss:0.36416128277778625
epoch:69,loss:0.3624049723148346
epoch:70,loss:0.36066901683807373
epoch:71,loss:0.3589531481266022
epoch:72,loss:0.3572568893432617
epoch:73,loss:0.3555799424648285
epoch:74,loss:0.3539219796657562
epoch:75,loss:0.35228273272514343
epoch:76,loss:0.3506617844104767
epoch:77,loss:0.3490588366985321
epoch:78,loss:0.3474736511707306
epoch:79,loss:0.34590592980384827
epoch:80,loss:0.34435519576072693
epoch:81,loss:0.34282147884368896
epoch:82,loss:0.3413042724132538
epoch:83,loss:0.3398033380508423
epoch:84,loss:0.3383183777332306
epoch:85,loss:0.3368492126464844
epoch:86,loss:0.33539560437202454
epoch:87,loss:0.33395716547966003
epoch:88,loss:0.3325338065624237
epoch:89,loss:0.3311251401901245
epoch:90,loss:0.32973095774650574
epoch:91,loss:0.3283511698246002
epoch:92,loss:0.32698535919189453
epoch:93,loss:0.3256334066390991
epoch:94,loss:0.32429513335227966
epoch:95,loss:0.3229702413082123
epoch:96,loss:0.3216584622859955
epoch:97,loss:0.3203597366809845
epoch:98,loss:0.31907379627227783
epoch:99,loss:0.3178004324436188
epoch:100,loss:0.3165395259857178
epoch:101,loss:0.31529080867767334
epoch:102,loss:0.31405404210090637
epoch:103,loss:0.31282922625541687
epoch:104,loss:0.3116159737110138
epoch:105,loss:0.31041431427001953
epoch:106,loss:0.30922389030456543
epoch:107,loss:0.3080446422100067
epoch:108,loss:0.30687636137008667
epoch:109,loss:0.30571886897087097
epoch:110,loss:0.30457210540771484
epoch:111,loss:0.3034358322620392
epoch:112,loss:0.30230990052223206
epoch:113,loss:0.3011942207813263
epoch:114,loss:0.300088495016098
epoch:115,loss:0.29899275302886963
epoch:116,loss:0.2979068458080292
epoch:117,loss:0.2968304455280304
epoch:118,loss:0.29576361179351807
epoch:119,loss:0.29470619559288025
epoch:120,loss:0.29365792870521545
epoch:121,loss:0.2926187515258789
epoch:122,loss:0.29158857464790344
epoch:123,loss:0.2905672490596771
epoch:124,loss:0.2895546853542328
epoch:125,loss:0.28855079412460327
epoch:126,loss:0.2875552475452423
epoch:127,loss:0.28656816482543945
epoch:128,loss:0.2855893075466156
epoch:129,loss:0.28461864590644836
epoch:130,loss:0.28365597128868103
epoch:131,loss:0.2827012836933136
epoch:132,loss:0.28175440430641174
epoch:133,loss:0.2808153033256531
epoch:134,loss:0.2798837721347809
epoch:135,loss:0.2789597511291504
epoch:136,loss:0.2780431807041168
epoch:137,loss:0.2771339416503906
epoch:138,loss:0.276231974363327
epoch:139,loss:0.2753371298313141
epoch:140,loss:0.27444931864738464
epoch:141,loss:0.2735684812068939
epoch:142,loss:0.27269449830055237
epoch:143,loss:0.2718273103237152
epoch:144,loss:0.27096685767173767
epoch:145,loss:0.270113080739975
epoch:146,loss:0.269265741109848
epoch:147,loss:0.2684248387813568
epoch:148,loss:0.26759031414985657
epoch:149,loss:0.2667621076107025
epoch:150,loss:0.2659401595592499
epoch:151,loss:0.26512423157691956
epoch:152,loss:0.26431453227996826
epoch:153,loss:0.26351070404052734
epoch:154,loss:0.2627127468585968
epoch:155,loss:0.2619207501411438
epoch:156,loss:0.261134535074234
epoch:157,loss:0.2603539824485779
epoch:158,loss:0.259579062461853
epoch:159,loss:0.25880974531173706
epoch:160,loss:0.25804591178894043
epoch:161,loss:0.25728753209114075
epoch:162,loss:0.25653454661369324
epoch:163,loss:0.25578686594963074
epoch:164,loss:0.25504443049430847
epoch:165,loss:0.25430724024772644
epoch:166,loss:0.2535751461982727
epoch:167,loss:0.25284814834594727
epoch:168,loss:0.25212612748146057
epoch:169,loss:0.2514090836048126
epoch:170,loss:0.25069698691368103
epoch:171,loss:0.24998970329761505
epoch:172,loss:0.2492872029542923
epoch:173,loss:0.24858951568603516
epoch:174,loss:0.2478964924812317
epoch:175,loss:0.24720807373523712
epoch:176,loss:0.24652425944805145
epoch:177,loss:0.24584496021270752
epoch:178,loss:0.24517016112804413
epoch:179,loss:0.2444998025894165
epoch:180,loss:0.24383385479450226
epoch:181,loss:0.24317222833633423
epoch:182,loss:0.24251486361026764
epoch:183,loss:0.24186182022094727
epoch:184,loss:0.24121296405792236
epoch:185,loss:0.24056823551654816
epoch:186,loss:0.23992769420146942
epoch:187,loss:0.23929114639759064
epoch:188,loss:0.23865866661071777
epoch:189,loss:0.23803012073040009
epoch:190,loss:0.23740558326244354
epoch:191,loss:0.23678497970104218
epoch:192,loss:0.23616819083690643
epoch:193,loss:0.23555515706539154
epoch:194,loss:0.23494601249694824
epoch:195,loss:0.23434054851531982
epoch:196,loss:0.23373882472515106
epoch:197,loss:0.23314078152179718
epoch:198,loss:0.2325463443994522
epoch:199,loss:0.23195551335811615
epoch:200,loss:0.23136834800243378
epoch:201,loss:0.23078449070453644
epoch:202,loss:0.23020429909229279
epoch:203,loss:0.22962749004364014
epoch:204,loss:0.22905409336090088
epoch:205,loss:0.22848407924175262
epoch:206,loss:0.22791743278503418
epoch:207,loss:0.22735412418842316
epoch:208,loss:0.22679413855075836
epoch:209,loss:0.22623737156391144
epoch:210,loss:0.2256838083267212
epoch:211,loss:0.22513343393802643
epoch:212,loss:0.22458629310131073
epoch:213,loss:0.22404217720031738
epoch:214,loss:0.22350120544433594
epoch:215,loss:0.222963348031044
epoch:216,loss:0.222428560256958
epoch:217,loss:0.22189675271511078
epoch:218,loss:0.22136789560317993
epoch:219,loss:0.22084204852581024
epoch:220,loss:0.22031907737255096
epoch:221,loss:0.21979911625385284
epoch:222,loss:0.21928191184997559
epoch:223,loss:0.21876764297485352
epoch:224,loss:0.21825619041919708
epoch:225,loss:0.21774746477603912
epoch:226,loss:0.2172415852546692
epoch:227,loss:0.21673841774463654
epoch:228,loss:0.21623806655406952
epoch:229,loss:0.21574024856090546
epoch:230,loss:0.21524524688720703
epoch:231,loss:0.21475277841091156
epoch:232,loss:0.21426303684711456
epoch:233,loss:0.2137758731842041
epoch:234,loss:0.21329133212566376
epoch:235,loss:0.2128092497587204
epoch:236,loss:0.21232974529266357
epoch:237,loss:0.21185283362865448
epoch:238,loss:0.2113783359527588
epoch:239,loss:0.2109062224626541
epoch:240,loss:0.21043670177459717
epoch:241,loss:0.20996952056884766
epoch:242,loss:0.2095048427581787
epoch:243,loss:0.20904244482517242
epoch:244,loss:0.20858250558376312
epoch:245,loss:0.2081248015165329
epoch:246,loss:0.20766951143741608
epoch:247,loss:0.2072165459394455
epoch:248,loss:0.2067657709121704
epoch:249,loss:0.20631732046604156
epoch:250,loss:0.20587114989757538
epoch:251,loss:0.2054271548986435
epoch:252,loss:0.2049853652715683
epoch:253,loss:0.20454581081867218
epoch:254,loss:0.2041083574295044
epoch:255,loss:0.20367316901683807
epoch:256,loss:0.2032400220632553
epoch:257,loss:0.20280903577804565
epoch:258,loss:0.20238013565540314
epoch:259,loss:0.20195335149765015
epoch:260,loss:0.2015286237001419
epoch:261,loss:0.20110589265823364
epoch:262,loss:0.20068521797657013
epoch:263,loss:0.20026664435863495
epoch:264,loss:0.19984997808933258
epoch:265,loss:0.1994353085756302
epoch:266,loss:0.19902272522449493
epoch:267,loss:0.19861197471618652
epoch:268,loss:0.1982032060623169
epoch:269,loss:0.19779632985591888
epoch:270,loss:0.19739143550395966
epoch:271,loss:0.19698835909366608
epoch:272,loss:0.1965871900320053
epoch:273,loss:0.19618795812129974
epoch:274,loss:0.19579048454761505
epoch:275,loss:0.1953948736190796
epoch:276,loss:0.19500108063220978
epoch:277,loss:0.19460906088352203
epoch:278,loss:0.19421891868114471
epoch:279,loss:0.19383053481578827
epoch:280,loss:0.19344383478164673
epoch:281,loss:0.19305895268917084
epoch:282,loss:0.1926758736371994
epoch:283,loss:0.1922944039106369
epoch:284,loss:0.19191469252109528
epoch:285,loss:0.19153672456741333
epoch:286,loss:0.1911603957414627
epoch:287,loss:0.19078578054904938
epoch:288,loss:0.19041281938552856
epoch:289,loss:0.19004149734973907
epoch:290,loss:0.18967176973819733
epoch:291,loss:0.18930374085903168
epoch:292,loss:0.18893729150295258
epoch:293,loss:0.1885724812746048
epoch:294,loss:0.18820922076702118
epoch:295,loss:0.1878475695848465
epoch:296,loss:0.1874874383211136
epoch:297,loss:0.1871289610862732
epoch:298,loss:0.1867719143629074
epoch:299,loss:0.18641650676727295
epoch:300,loss:0.18606257438659668
epoch:301,loss:0.1857100874185562
epoch:302,loss:0.1853591948747635
epoch:303,loss:0.18500977754592896
epoch:304,loss:0.1846618503332138
epoch:305,loss:0.18431539833545685
epoch:306,loss:0.18397033214569092
epoch:307,loss:0.18362677097320557
epoch:308,loss:0.18328464031219482
epoch:309,loss:0.18294398486614227
epoch:310,loss:0.18260468542575836
epoch:311,loss:0.18226683139801025
epoch:312,loss:0.18193034827709198
epoch:313,loss:0.18159528076648712
epoch:314,loss:0.1812615990638733
epoch:315,loss:0.1809292584657669
epoch:316,loss:0.18059830367565155
epoch:317,loss:0.18026868999004364
epoch:318,loss:0.17994041740894318
epoch:319,loss:0.17961353063583374
epoch:320,loss:0.17928792536258698
epoch:321,loss:0.17896361649036407
epoch:322,loss:0.17864061892032623
epoch:323,loss:0.17831893265247345
epoch:324,loss:0.1779986023902893
epoch:325,loss:0.17767947912216187
epoch:326,loss:0.17736168205738068
epoch:327,loss:0.1770450919866562
epoch:328,loss:0.17672979831695557
epoch:329,loss:0.17641572654247284
epoch:330,loss:0.176102876663208
epoch:331,loss:0.17579126358032227
epoch:332,loss:0.1754808872938156
epoch:333,loss:0.17517171800136566
epoch:334,loss:0.1748637706041336
epoch:335,loss:0.17455708980560303
epoch:336,loss:0.17425145208835602
epoch:337,loss:0.17394711077213287
epoch:338,loss:0.17364394664764404
epoch:339,loss:0.17334191501140594
epoch:340,loss:0.17304104566574097
epoch:341,loss:0.1727413386106491
epoch:342,loss:0.17244277894496918
epoch:343,loss:0.17214536666870117
epoch:344,loss:0.1718490868806839
epoch:345,loss:0.17155392467975616
epoch:346,loss:0.17125992476940155
epoch:347,loss:0.17096692323684692
epoch:348,loss:0.1706750988960266
epoch:349,loss:0.17038439214229584
epoch:350,loss:0.17009474337100983
epoch:351,loss:0.16980619728565216
epoch:352,loss:0.16951870918273926
epoch:353,loss:0.1692323088645935
epoch:354,loss:0.16894693672657013
epoch:355,loss:0.1686626821756363
epoch:356,loss:0.16837947070598602
epoch:357,loss:0.16809721291065216
epoch:358,loss:0.16781608760356903
epoch:359,loss:0.1675359457731247
epoch:360,loss:0.16725683212280273
epoch:361,loss:0.16697876155376434
epoch:362,loss:0.16670171916484833
epoch:363,loss:0.1664256602525711
epoch:364,loss:0.16615058481693268
epoch:365,loss:0.16587650775909424
epoch:366,loss:0.16560344398021698
epoch:367,loss:0.16533131897449493
epoch:368,loss:0.16506020724773407
epoch:369,loss:0.16479013860225677
epoch:370,loss:0.16452090442180634
epoch:371,loss:0.16425266861915588
epoch:372,loss:0.16398541629314423
epoch:373,loss:0.16371913254261017
epoch:374,loss:0.16345375776290894
epoch:375,loss:0.1631893664598465
epoch:376,loss:0.1629258096218109
epoch:377,loss:0.16266322135925293
epoch:378,loss:0.16240155696868896
epoch:379,loss:0.162140890955925
epoch:380,loss:0.16188107430934906
epoch:381,loss:0.16162210702896118
epoch:382,loss:0.16136406362056732
epoch:383,loss:0.16110695898532867
epoch:384,loss:0.16085071861743927
epoch:385,loss:0.1605953723192215
epoch:386,loss:0.16034092009067535
epoch:387,loss:0.16008730232715607
epoch:388,loss:0.15983451902866364
epoch:389,loss:0.1595826894044876
epoch:390,loss:0.15933169424533844
epoch:391,loss:0.1590815633535385
epoch:392,loss:0.15883223712444305
epoch:393,loss:0.15858377516269684
epoch:394,loss:0.15833614766597748
epoch:395,loss:0.15808947384357452
epoch:396,loss:0.1578434854745865
epoch:397,loss:0.15759839117527008
epoch:398,loss:0.15735407173633575
epoch:399,loss:0.15711058676242828
epoch:400,loss:0.15686796605587006
epoch:401,loss:0.15662606060504913
epoch:402,loss:0.15638504922389984
epoch:403,loss:0.15614484250545502
epoch:404,loss:0.1559053212404251
epoch:405,loss:0.15566670894622803
epoch:406,loss:0.15542885661125183
epoch:407,loss:0.15519171953201294
epoch:408,loss:0.1549554169178009
epoch:409,loss:0.15471990406513214
epoch:410,loss:0.15448515117168427
epoch:411,loss:0.1542511135339737
epoch:412,loss:0.15401791036128998
epoch:413,loss:0.15378546714782715
epoch:414,loss:0.15355370938777924
epoch:415,loss:0.1533227413892746
epoch:416,loss:0.15309247374534607
epoch:417,loss:0.15286298096179962
epoch:418,loss:0.15263423323631287
epoch:419,loss:0.15240627527236938
epoch:420,loss:0.15217895805835724
epoch:421,loss:0.15195240080356598
epoch:422,loss:0.15172649919986725
epoch:423,loss:0.15150146186351776
epoch:424,loss:0.15127703547477722
epoch:425,loss:0.15105333924293518
epoch:426,loss:0.15083037316799164
epoch:427,loss:0.1506081223487854
epoch:428,loss:0.1503865122795105
epoch:429,loss:0.15016567707061768
epoch:430,loss:0.1499454528093338
epoch:431,loss:0.14972595870494843
epoch:432,loss:0.14950720965862274
epoch:433,loss:0.14928902685642242
epoch:434,loss:0.1490715593099594
epoch:435,loss:0.14885486662387848
epoch:436,loss:0.14863869547843933
epoch:437,loss:0.14842329919338226
epoch:438,loss:0.14820851385593414
epoch:439,loss:0.14799444377422333
epoch:440,loss:0.14778093993663788
epoch:441,loss:0.14756818115711212
epoch:442,loss:0.1473560482263565
epoch:443,loss:0.14714455604553223
epoch:444,loss:0.14693373441696167
epoch:445,loss:0.14672355353832245
epoch:446,loss:0.14651398360729218
epoch:447,loss:0.14630500972270966
epoch:448,loss:0.14609675109386444
epoch:449,loss:0.14588911831378937
epoch:450,loss:0.14568205177783966
epoch:451,loss:0.14547567069530487
epoch:452,loss:0.14526985585689545
epoch:453,loss:0.14506466686725616
epoch:454,loss:0.1448601484298706
epoch:455,loss:0.1446561962366104
epoch:456,loss:0.14445284008979797
epoch:457,loss:0.1442500799894333
epoch:458,loss:0.14404797554016113
epoch:459,loss:0.14384643733501434
epoch:460,loss:0.14364556968212128
epoch:461,loss:0.1434452086687088
epoch:462,loss:0.14324545860290527
epoch:463,loss:0.14304625988006592
epoch:464,loss:0.1428476721048355
epoch:465,loss:0.14264975488185883
epoch:466,loss:0.14245234429836273
epoch:467,loss:0.1422555297613144
epoch:468,loss:0.14205926656723022
epoch:469,loss:0.14186353981494904
epoch:470,loss:0.141668438911438
epoch:471,loss:0.1414739042520523
epoch:472,loss:0.141279935836792
epoch:473,loss:0.14108650386333466
epoch:474,loss:0.1408936232328415
epoch:475,loss:0.14070133864879608
epoch:476,loss:0.14050953090190887
epoch:477,loss:0.1403183788061142
epoch:478,loss:0.14012771844863892
epoch:479,loss:0.1399376094341278
epoch:480,loss:0.1397479921579361
epoch:481,loss:0.13955903053283691
epoch:482,loss:0.13937050104141235
epoch:483,loss:0.13918252289295197
epoch:484,loss:0.13899511098861694
epoch:485,loss:0.1388082504272461
epoch:486,loss:0.13862188160419464
epoch:487,loss:0.13843606412410736
epoch:488,loss:0.1382507085800171
epoch:489,loss:0.13806594908237457
epoch:490,loss:0.13788163661956787
epoch:491,loss:0.13769787549972534
epoch:492,loss:0.13751469552516937
epoch:493,loss:0.13733188807964325
epoch:494,loss:0.13714967668056488
epoch:495,loss:0.1369679570198059
epoch:496,loss:0.13678671419620514
epoch:497,loss:0.13660602271556854
epoch:498,loss:0.13642586767673492
epoch:499,loss:0.13624614477157593
预测:x=4,y=0.9917744994163513

5.图像展示

import numpy as np
import matplotlib.pyplot as plt

from pylab import mpl
# Select the SimHei font so the Chinese axis label and title can be displayed.
mpl.rcParams["font.sans-serif"] = ["SimHei"]

# Evaluate the trained model on 200 evenly spaced inputs in [0, 10].
x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view(200, 1)  # column of float32 inputs, shape (200, 1)
y_t = model(x_t)
# detach() drops the autograd graph before the NumPy conversion; it replaces
# the discouraged `.data` attribute access.
y = y_t.detach().numpy()

# Blue curve: model output for each input value.
plt.plot(x, y, color='blue')
# Red horizontal reference line at probability 0.5.
plt.plot([0, 10], [0.5, 0.5], color='red')

plt.xlabel('Hours')
plt.ylabel('通过的概率')
plt.title('学习时间和通过的概率关系图')
plt.grid(linestyle='--')
# Fix: the original wrote `plt.show` without parentheses, which only evaluates
# to the function object and never invokes it. Call it to render the figure.
plt.show()

在这里插入图片描述


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值