python+pytorch实现双线性模型对鸢尾花数据集训练

import torch.nn as nn

import torch

from sklearn.datasets import load_iris

data = load_iris()

 

X1 = torch.tensor(data.data).float()

X2 = 0.01 * torch.rand(150, 2).float()


 

y = torch.tensor(data.target).long()


 

W = torch.randn(3, X1.shape[1], X2.shape[1], requires_grad=True) # 鸢尾花3各类别

b = torch.rand(150, 1, requires_grad=True)

 

cost = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.Adam([W, b])

 

for i in range(50000):

    optimizer.zero_grad()

    Y = torch.cat([torch.sum(torch.mm(X1, W[i,:,:]) * X2, dim=1).unsqueeze(1) for i in range(W.shape[0])], dim=1) + b

    loss = cost(Y, y)

    loss.backward()

    optimizer.step()

    print(loss)

 

_, id = torch.max(Y, 1)

 

correct = 0

for i in range(len(id)):

    if id[i] == y[i]:

        correct += 1

print(correct / 150)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
好的,下面我为您介绍一下如何使用 Python、SSD 和 PyTorch 实现目标检测。 1. 准备数据集和模型 首先,您需要准备好数据集,并使用 PyTorch 实现的 SSD 进行训练。可以使用一些开源数据集,如 COCO 数据集,或自己收集数据集。 2. 加载模型和图像 使用 PyTorch 加载训练好的模型,并将待检测的图像输入到模型中。可以使用 OpenCV 加载图像。这里需要注意,输入图像需要经过预处理,如缩放、归一化等操作。 ```python import torch import cv2 # 加载模型 model = torch.load('path/to/model.pth') model.eval() # 加载图像 image = cv2.imread('path/to/image.jpg') # 图像预处理 image = cv2.resize(image, (300, 300)) image = image.astype('float32') image = image / 255.0 image = (image - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225] image = image.transpose((2, 0, 1)) image = torch.from_numpy(image).unsqueeze(0) ``` 3. 进行目标检测 将预处理后的图像输入到模型中,模型将会输出目标在图像中的位置和分类信息。 ```python # 目标检测 with torch.no_grad(): outputs = model(image) detections = outputs.data # 解析检测结果 for i in range(detections.size(1)): j = 0 while detections[0, i, j, 0] >= 0.6: score = detections[0, i, j, 0] label_name = 'class_name' pt = (detections[0, i, j, 1:]*300).cpu().numpy() j += 1 # 在图像上绘制检测结果 cv2.rectangle(image, (int(pt[0]), int(pt[1])), (int(pt[2]), int(pt[3])), (0, 255, 0), 2) cv2.putText(image, label_name, (int(pt[0]), int(pt[1]-10)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2) ``` 4. 展示检测结果 可以使用 OpenCV 在图像上绘制检测结果,并展示出来。 ```python # 展示检测结果 cv2.imshow('image', image) cv2.waitKey(0) cv2.destroyAllWindows() ``` 需要注意的是,这里只是一个简单的示例,实际应用中还需要针对具体情况进行调整和优化。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值