svm torch导出pt并运行

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler
import torch
import torch.onnx

# 读取数据
def load_data():
    with open('Position_Salaries.csv', 'r') as f:
        data = np.loadtxt(f, delimiter=',', skiprows=1)
    X = data[:, 1:2] # 特征变量
    y = data[:, 2:] # 标签变量
    return X, y
    
# 特征处理和标准化
def feature_normalization(X, y):
    scaler_X = StandardScaler()
    X = scaler_X.fit_transform(X)
    scaler_y = StandardScaler()
    y = scaler_y.fit_transform(y)
    return X, y, scaler_X, scaler_y

# 创建SVR模型并训练
def train_model(X_train, y_train, kernel='rbf'):
    regressor = SVR(kernel=kernel).fit(X_train, y_train)
    return regressor

# 绘制散点图和预测曲线
def plot_result(X, y, regressor):
    plt.scatter(X, y)
    X_test = np.linspace(X.min(), X.max(), 100).reshape(-1, 1)
    y_pred = regressor.predict(X_test)
    plt.plot(X_test, y_pred, color='r')
    plt.show()

# 保存模型
def save_model(model, path):
    torch.save(model, path)

# 生成ONNX模型
def save_onnx(model, input_shape, output_path):
    xx = torch.randn(input_shape)
    with torch.no_grad():
        torch.onnx.export(model,
                          xx, 
                          output_path,
                          opset_version=11,
                          input_names=['input'], 
                          output_names=['output'])
def modelrun():
    model = torch.load('regressor.pt', map_location=torch.device('cpu'))#加载模型
    score = model.predict([[-1.5]])#模型预测
    print(score)

if __name__ == "__main__":
    X, y = load_data()
    X_train, y_train, scaler_X, scaler_y = feature_normalization(X, y)
    regressor = train_model(X_train, y_train)
    plot_result(X_train, y_train, regressor)
    save_model(regressor, 'regressorr.pt')
    #save_onnx(regressor, [1, 10], "regressor.onnx")
    modelrun()

生成.pt,并运行模型预测

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值