# -*- coding: utf-8 -*-
import torch
import optuna
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset
from torch import nn, optim
import pandas as pd
# -*- coding: utf-8 -*-
import torch
import optuna
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset
from torch import nn, optim
import pandas as pd
# Seed PyTorch's global random number generator once, at import time, so that
# torch-driven randomness (e.g. layer weight initialization) repeats run to run.
_RANDOM_SEED = 42
torch.manual_seed(_RANDOM_SEED)
# Prepare the data.
def prepare_data(test_size=0.2, batch_size=16, random_state=42):
    """Load Iris, split and standardize it, and wrap both splits in DataLoaders.

    NOTE(review): the original body of this function was elided from the
    published excerpt (it only returned undefined names, so any call raised
    NameError). This is a reconstruction built from the file's own imports
    (load_iris, train_test_split, StandardScaler, TensorDataset, DataLoader);
    confirm split ratio and batch size against the full version.

    Args:
        test_size: fraction of samples held out for validation.
        batch_size: mini-batch size used by both loaders.
        random_state: seed for the train/validation split.

    Returns:
        (train_loader, val_loader, X_train, X_val). X_train / X_val are
        float32 tensors of standardized features; the loaders yield
        (features, int64 class label) batches.
    """
    features, labels = load_iris(return_X_y=True)
    X_tr, X_va, y_tr, y_va = train_test_split(
        features,
        labels,
        test_size=test_size,
        random_state=random_state,
        stratify=labels,  # keep class proportions equal in both splits
    )

    # Fit the scaler on the training split only so validation statistics
    # do not leak into training.
    scaler = StandardScaler()
    X_tr = scaler.fit_transform(X_tr)
    X_va = scaler.transform(X_va)

    X_train = torch.tensor(X_tr, dtype=torch.float32)
    X_val = torch.tensor(X_va, dtype=torch.float32)
    # Labels are integer class indices: evaluate() compares them directly
    # against argmax predictions, and classification losses such as
    # nn.CrossEntropyLoss expect int64 targets.
    y_train = torch.tensor(y_tr, dtype=torch.long)
    y_val = torch.tensor(y_va, dtype=torch.long)

    train_loader = DataLoader(
        TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True
    )
    val_loader = DataLoader(
        TensorDataset(X_val, y_val), batch_size=batch_size, shuffle=False
    )
    return train_loader, val_loader, X_train, X_val
# Define the model class.
class MLPModel(nn.Module):
    """Two-layer perceptron classifier: Linear -> ReLU -> Linear.

    NOTE(review): the original class body was elided from the published
    excerpt (a class statement with no body does not parse). This minimal
    reconstruction matches the only visible call site,
    MLPModel(input_dim, hidden_dim, output_dim); confirm the architecture
    against the full version.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """Map a (batch, input_dim) float tensor to (batch, output_dim) raw scores."""
        # No softmax: evaluate() only needs the argmax, and losses such as
        # nn.CrossEntropyLoss expect unnormalized logits.
        return self.net(x)
# Define the training function (one epoch).
def train(model, optimizer, criterion, data_loader):
    """Run one optimization pass over ``data_loader``.

    Switches ``model`` to training mode, takes one optimizer step per batch,
    and returns the per-batch losses averaged with each batch weighted by its
    size. Raises ZeroDivisionError if the loader yields no samples.
    """
    model.train()
    loss_sum, seen = 0.0, 0
    for batch_x, batch_y in data_loader:
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_x), batch_y)
        batch_loss.backward()
        optimizer.step()
        n = batch_x.size(0)
        # Weight by batch size so a short final batch does not skew the mean.
        loss_sum += batch_loss.item() * n
        seen += n
    return loss_sum / seen
# Define the evaluation function (classification accuracy).
def evaluate(model, data_loader):
    """Return the fraction of samples whose argmax prediction equals the target.

    Switches ``model`` to eval mode and runs without gradient tracking.
    Targets are compared directly with predicted class indices, so they must
    be integer class labels. Raises ZeroDivisionError if the loader yields no
    samples.
    """
    model.eval()
    correct = 0
    total_samples = 0
    with torch.no_grad():
        for inputs, targets in data_loader:
            outputs = model(inputs)
            # argmax over the class dimension; `.data` is unnecessary (and
            # discouraged) inside no_grad, and the max values were discarded.
            predicted_labels = outputs.argmax(dim=1)
            correct += (predicted_labels == targets).sum().item()
            total_samples += targets.size(0)
    return correct / total_samples
# Optuna objective: train one model with the sampled hyperparameters and score it.
def objective(trial):
    """Train an MLPModel with hyperparameters sampled from ``trial``; return validation accuracy.

    Reads the module-level globals X_train, train_loader, val_loader and
    num_epochs, which the ``__main__`` block assigns before calling
    ``study.optimize``.

    NOTE(review): the original excerpt stopped after building the model and
    returned the undefined name ``val_accuracy`` (NameError). The optimizer
    (Adam) and loss (CrossEntropyLoss) below are reconstructions; confirm
    against the full version.
    """
    # Hyperparameter search space.
    hidden_dim = trial.suggest_int("hidden_dim", 2, 100)
    # Learning rates span four orders of magnitude, so sample log-uniformly;
    # a linear range would put ~90% of trials above 1e-2.
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)

    # Build the MLP; output_dim = 3 is the number of Iris classes.
    input_dim = X_train.shape[1]
    output_dim = 3
    model = MLPModel(input_dim, hidden_dim, output_dim)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    for _ in range(num_epochs):
        train(model, optimizer, criterion, train_loader)

    val_accuracy = evaluate(model, val_loader)
    return val_accuracy
if __name__ == '__main__':
    # Training epochs per trial (read as a global by objective()).
    num_epochs = 1000
    # Prepare the data; objective() also reads X_train, train_loader and
    # val_loader as globals.
    train_loader, val_loader, X_train, X_val = prepare_data()
    # Number of Optuna trials to run.
    num_trials = 10
    # Search hyperparameters with Optuna, persisting trials to SQLite.
    # - A fixed study_name plus load_if_exists=True makes a re-run resume the
    #   same study (also reachable via optuna.load_study(study_name='my_study',
    #   storage='sqlite:///example.db')) instead of creating a new auto-named
    #   study on every run.
    # - Seeding the sampler makes the sequence of suggested hyperparameters
    #   reproducible, matching the torch seed set at import time.
    study = optuna.create_study(
        study_name='my_study',
        storage='sqlite:///example.db',
        direction="maximize",
        sampler=optuna.samplers.TPESampler(seed=42),
        load_if_exists=True,
    )
    # Run the optimization.
    study.optimize(objective, n_trials=num_trials)
    # Report the best hyperparameters and their validation accuracy.
    best_params = study.best_params
    best_val_accuracy = study.best_value
    print("Best Parameters:", best_params)
    print("Best Validation Accuracy:", best_val_accuracy)
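# 以上为部分代码,更多请关注同名公众号 (The code above is an excerpt; the full version is available from the author's WeChat official account of the same name.)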