特征提取+聚类

该代码示例首先导入必要的库,包括PyTorch和OpenCV等。它定义了一个基于ResNet50的网络结构,然后加载一批图像,对图像进行预处理,通过ResNet提取特征。接着,利用KMeans算法对这些特征进行聚类,最后展示每个类别的样本图像。
摘要由CSDN通过智能技术生成
import os
import numpy as np
from sklearn.cluster import KMeans
import cv2
from imutils import build_montages
import torch.nn as nn
import torchvision.models as models
from PIL import Image
from torchvision import transforms

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        resnet50 = models.resnet50(pretrained=True)
        self.resnet = nn.Sequential(resnet50.conv1,
                                    resnet50.bn1,
                                    resnet50.relu,
                                    resnet50.maxpool,
                                    resnet50.layer1,
                                    resnet50.layer2,
                                    resnet50.layer3,
                                    resnet50.layer4)

    def forward(self, x):
        x = self.resnet(x)
        return x

net = Net().eval()

image_path = []
all_images = []
images = os.listdir('./images_lianghua_v3')

for image_name in images:
    image_path.append('./images_lianghua_v3/' + image_name)
for path in image_path:
    image = Image.open(path).convert('RGB')
    image = transforms.Resize([224,224])(image)
    image = transforms.ToTensor()(image)
    image = image.unsqueeze(0)
    image = net(image)
    print(image.shape)
    image = image.reshape(-1, )
    all_images.append(image.detach().numpy())

clt = KMeans(n_clusters=15)
clt.fit(all_images)
labelIDs = np.unique(clt.labels_)
print(labelIDs)
all_labels = 0
for labelID in labelIDs:
    idxs = np.where(clt.labels_ == labelID)[0]
    all_labels += len(idxs)
    print(len(idxs))
    idxs = np.random.choice(idxs, size=min(25, len(idxs)),replace=False)
    
    
    show_box = []
    for i in idxs:
        image = cv2.imread(image_path[i])
        image = cv2.resize(image, (96, 96))
        show_box.append(image)
    montage = build_montages(show_box, (96, 96), (5, 5))[0]
    title = "Type{}".format(labelID)
    cv2.imwrite(title + ".jpg",montage)
    # cv2.imshow(title, montage)
    # cv2.waitKey(0)
print("所有图片:",all_labels)

    

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

zyb-小波

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值