iris数据集介绍
由统计学家和植物学家Ronald Fisher在1936年收集并发布。该数据集中包含了150个样本,其中每个样本代表了一朵鸢尾花(iris flower),并且包含了四个特征(sepal length(花萼长度)、sepal width(花萼宽度)、petal length(花瓣长度)和petal width(花瓣宽度))以及对应的类别标签(iris setosa、iris versicolor和iris virginica)。
- 样本数量:150条
- 类别数量:3类
- 每类样本:50条
- 特征维度:4
读取数据集
import numpy as np
import torch
def load_iris(filename):
data = np.load(filename)
features = data['data']
labels = data['label']
return torch.tensor(features, dtype=torch.float64), torch.tensor(labels, dtype=torch.int64)
train_data, train_label = load_iris(r"../../Dataset/iris/iris_train.npz")
valid_data, valid_label = load_iris(r"../../Dataset/iris/iris_valid.npz")
print(train_data.shape, train_label.shape, valid_data.shape, valid_label.shape,)
input_dim = train_data.shape[1]
output_dim = int(train_label.max
文章介绍了鸢尾花数据集,包含150个样本,用于花卉分类。通过PyTorch构建了一个基于softmax的网络模型,使用交叉熵损失函数和梯度下降优化器进行训练,并进行了模型性能评估。
最低0.47元/天 解锁文章
4124

被折叠的 条评论
为什么被折叠?



