python 训练神经网络时的小工具

1.删除当前文件夹下的所有jpg/png

import os
files= os.listdir()

for pic in files:  # 遍历文件夹

    if pic.endswith(".png"):
        os.remove(pic)
    elif pic.endswith(".jpg"):
        os.remove(pic)

2.mac下安装XGBoost

macOS终端输入:

 conda install py-xgboost

python文件直接import:

import xgboost as xgb

3. 实验数据写入csv文件保存

可以写一个函数:

import csv
import codecs
def wcsv(fname,vgs,vds,a1s,b1s,a2s,b2s):
    # fname为“xx.csv” 即csv文件的名字
    # vgs,vds,a1s,b1s,a2s,b2s为要保存的变量
    header = ['vg', "vd", "a1", "b1", "a2","b2"]
    # header为csv文件的表头
    f=codecs.open(fname,'w+','utf-8')
    writer=csv.writer(f)
    writer.writerow(header)
    # 逐行写入变量数据
    for vg,vd,a1,b1,a2,b2 in zip(vgs,vds,a1s,b1s,a2s,b2s):
        writer.writerow((vg,vd,a1,b1,a2,b2))

4.循环在一张图上将多条曲线

import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111)
for i in tvg.keys(): 
    # 取列表中某条曲线的点
    x = tvg[i]
    y = gids[i]
    ax.plot(x, y) 
plt.savefig(f"origin_{index}.jpg")
#保存图片

5.为网络参数设置不同的学习率

在optimizer里直接设置(以sgd为例子):

  optimizer = torch.optim.SGD([
    # 直接改每一项的lr即可
        {'params': wnn.net1.parameters(), 'lr':1e-3},
        {'params': wnn.net2.parameters(), 'lr': 1e-3},
        {'params': wnn.a1.parameters(), 'lr': 1e-3},
        {'params': wnn.a2.parameters(), 'lr': 1e-3},
        {'params': wnn.b1.parameters(), 'lr': 1e-5},
        {'params': wnn.b2.parameters(), 'lr': 1e-3}, ],
        weight_decay=0.0005,
    )

 

6.为网络参数设置自适应学习率

torch.optim.lr_scheduler 提供了几种方法来根据 epoch 的数量调整学习率。 torch.optim.lr_scheduler.ReduceLROnPlateau 允许基于一些验证测量来降低动态学习率。

以LambdaLR为例:

# 学习率调整策略
lambda1= lambda epoch: 0.95 ** epoch
# lr_lambda=[]用于填写optimizer对应优化参数的优化策略
# 此例为八个参数都选择lambda1的递减策略
scheduler=torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=[lambda1,lambda1,lambda1,lambda1,lambda1,lambda1,lambda1,lambda1])
 

记得在优化循环中加上scheduler.step() !! 

7.模型的保存与重用

保存训练好的模型:

torch.save(model, './model.pt')

重新调用该模型:

model=torch.load('./model.pt')

8.一个优化器同时优化多个网络

可以参考5中的写法,也可以使用 itertools.chain:(以Adam为例)

optimizer=torch.optim.Adam(itertools.chain(
    wnn.net1.parameters(),
    wnn.net2.parameters(),
    wnn.a1.parameters(),
    wnn.a2.parameters(),
    wnn.b1.parameters(),
    wnn.b2.parameters(),
    wnn.bias.parameters()
    ), 
    lr=wnn.plr)

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

chococolate

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值