【2021-4-2】手动码出AlexNet(3.主文件供参考)

声明一下,这只是基本的网络和运行结构。调参之后的模型和数据由于需要花费精力而且没有指导意义,不会放。

from Data_Channel import Data_Channel
from AlexNet_Model import AlexNet
import numpy as np
import matplotlib.pyplot as plt
'''
主训练函数,这个就简单了,直接调用数据通道就行。
创建三组数据通道,依次调用,传入模型进行训练。
'''
Alex_model = AlexNet(learning_rate=0.02, drop_out=0.8, n_classes=6)
Labels_OH = {"cloudy":np.tile(np.array([0,0,0,0,0,1]), (80,1)), "haze":np.tile(np.array([0,0,0,0,1,0]), (80,1)),
             "rainy":np.tile(np.array([0,0,0,1,0,0]), (80,1)), "snow":np.tile(np.array([0,0,1,0,0,0]), (80,1)),
             "sunny":np.tile(np.array([0,1,0,0,0,0]), (80,1)), "thunder":np.tile(np.array([1,0,0,0,0,0]), (80,1))}
DC_Dic = {"cloudy":Data_Channel(category="cloudy", pool_size=20), "haze":Data_Channel(category="haze", pool_size=20),
          "rainy":Data_Channel(category="rainy", pool_size=20), "snow":Data_Channel(category="snow", pool_size=20),
          "sunny":Data_Channel(category="sunny", pool_size=20), "thunder":Data_Channel(category="thunder", pool_size=20)}

DC_list = ["cloudy", "haze", "rainy", "snow", "sunny", "thunder"]

for i in range(50):
    print("running")
    Index = DC_list[i%6]
    labels_now = Labels_OH[Index]
    DC = DC_Dic[Index]
    DC.Renew_dataset()
    Alex_model.learn(DC.RF_pool, labels_now)

X = np.arange(len(Alex_model.Loss_list))
plt.plot(X, Alex_model.Loss_list, '-r')
plt.grid()
plt.show()


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值