PT之DNN:基于泰坦尼克号数据集(独热编码/标签编码)利用PyTorch框架的浅层神经网络算法(pth和onnx文件的模型导出和载入推理)实现二分类预测应用案例
目录
基于泰坦尼克号数据集(独热编码/标签编码)利用PyTorch框架的浅层神经网络算法(pth和onnx文件的模型导出和载入推理)实现二分类预测应用案例
# 1、定义数据集
# 定义入模特征
# 2、数据预处理
# 2.1、缺失值处理
# 2.2、特征编码
# 2.3、分离特征与标签
# 3、模型训练与评估
# 3.1、切分数据集
# 转换数据集为PyTorch的Tensor格式
# 3.2、定义模型:前馈神经网络
# 初始化模型
# 定义损失函数和优化器
# 3.3、训练模型(前向传播+反向优化)
# 3.4、模型评估并输出预测结果
# 3.5、模型导出与推理
# T1、导出+载入pth模型文件进行推理
# T2、导出+载入ONNX模型
相关文章
PT之DNN:基于泰坦尼克号数据集(独热编码/标签编码)利用PyTorch框架的浅层神经网络算法(pth和onnx文件的模型导出和载入推理)实现二分类预测应用案例
PT之DNN:基于泰坦尼克号数据集(独热编码/标签编码)利用PyTorch框架的浅层神经网络算法(pth和onnx文件的模型导出和载入推理)实现二分类预测应用案例实现代码
基于泰坦尼克号数据集(独热编码/标签编码)利用PyTorch框架的浅层神经网络算法(pth和onnx文件的模型导出和载入推理)实现二分类预测应用案例
# 1、定义数据集
D:\ProgramData\Anaconda3\python.exe E:/File_Python/Python_daydayup/20230512.py
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 PassengerId 891 non-null int64
1 Survived 891 non-null int64
2 Pclass 891 non-null int64
3 Name 891 non-null object
4 Sex 891 non-null object
5 Age 714 non-null float64
6 SibSp 891 non-null int64
7 Parch 891 non