该楼层疑似违规已被系统折叠 隐藏此楼查看此楼
源码:# coding: utf-8
import sys, os
sys.path.append(os.pardir) # 为了导入父目录的文件而进行的设定
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.multi_layer_net import MultiLayerNet
from common.util import shuffle_dataset
from common.trainer import Trainer
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)
# 为了实现高速化,减少训练数据
x_train = x_train[:500]
t_train = t_train[:500]
# 分割验证数据
validation_rate = 0.20
validation_num = x_train.shape[0] * validation_rate
x_train, t_train = shuffle_dataset(x_train, t_train)
x_val = x_train[:validation_num]
x_val = int(x_val)
t_val = t_train[:validation_num]
t_val = int(t_val)
x_train = x_train[validation_num:]
allt_train = t_train[validation_num:]
def __train(lr, weight_decay, epocs=50):
network = MultiLayerNet(input_size=784, hidden_size_list=[100, 100, 100, 100, 100, 100],
output_size=10, weight_decay_lambda=weight_decay)
trainer = Trainer(network, x_train, t_train, x_val, t_val,
epochs=epocs, mini_batch_size=100,
optimizer='sgd', optimizer_param={'lr': lr}, verbose=False)
trainer.train()
return trainer.test_acc_list, trainer.train_acc_list
# 超参数的随机搜索======================================
optimization_trial = 100
results_val = {}
results_train = {}
for _ in range(optimization_trial):
# 指定搜索的超参数的范围===============
weight_decay = 10 ** np.random.uniform(-8, -4)
lr = 10 ** np.random.uniform(-6, -2)
# ================================================
val_acc_list, train_acc_list = __train(lr, weight_decay)
print("val acc:" + str(val_acc_list[-1]) + " | lr:" + str(lr) + ", weight decay:" + str(weight_decay))
key = "lr:" + str(lr) + ", weight decay:" + str(weight_decay)
results_val[key] = val_acc_list
results_train[key] = train_acc_list
# 绘制图形========================================================
print("=========== Hyper-Parameter Optimization Result ===========")
graph_draw_num = 20
col_num = 5
row_num = int(np.ceil(graph_draw_num / col_num))
i = 0
for key, val_acc_list in sorted(results_val.items(), key=lambda x:x[1][-1], reverse=True):
print("Best-" + str(i+1) + "(val acc:" + str(val_acc_list[-1]) + ") | " + key)
plt.subplot(row_num, col_num, i+1)
plt.title("Best-" + str(i+1))
plt.ylim(0.0, 1.0)
if i % 5: plt.yticks([])
plt.xticks([])
x = np.arange(len(val_acc_list))
plt.plot(x, val_acc_list)
plt.plot(x, results_train[key], "--")
i += 1
if i >= graph_draw_num:
break
plt.show()
以下为报错:
Reloaded modules: dataset, dataset.mnist, common, common.multi_layer_net, common.layers, common.functions, common.util, common.gradient, common.trainer, common.optimizer
Traceback (most recent call last):
File "", line 1, in
runfile('C:/Users/11/Desktop/python学习/190124.py', wdir='C:/Users/11/Desktop/python学习')
File "C:\ProgramData\Anaconda3\lib\site-packages\spyder_kernels\customize\spydercustomize.py", line 704, in runfile
execfile(filename, namespace)
File "C:\ProgramData\Anaconda3\lib\site-packages\spyder_kernels\customize\spydercustomize.py", line 108, in execfile
exec(compile(f.read(), filename, 'exec'), namespace)
File "C:/Users/11/Desktop/python学习/190124.py", line 29, in
x_val = x_train[:validation_num]
TypeError: slice indices must be integers or None or have an __index__ method
runfile('C:/Users/11/Desktop/python学习/190124.py', wdir='C:/Users/11/Desktop/python学习')
Reloaded modules: dataset, dataset.mnist, common, common.multi_layer_net, common.layers, common.functions, common.util, common.gradient, common.trainer, common.optimizer
Traceback (most recent call last):
File "", line 1, in
runfile('C:/Users/11/Desktop/python学习/190124.py', wdir='C:/Users/11/Desktop/python学习')
File "C:\ProgramData\Anaconda3\lib\site-packages\spyder_kernels\customize\spydercustomize.py", line 704, in runfile
execfile(filename, namespace)
File "C:\ProgramData\Anaconda3\lib\site-packages\spyder_kernels\customize\spydercustomize.py", line 108, in execfile
exec(compile(f.read(), filename, 'exec'), namespace)
File "C:/Users/11/Desktop/python学习/190124.py", line 29, in
x_val = x_train[:validation_num]
TypeError: slice indices must be integers or None or have an __index__ method