【人工智能学习之商品检测实战】

1 开发过程

  1. 拍摄并处理数据集
  2. 训练YOLOV8侦测商品
  3. 训练特征网络识别商品
  4. 商品跟踪与后处理

2 网络训练效果

2.1 分割网络

在这里插入图片描述
在这里插入图片描述
yolov8-n模型效果较差;
yolov8-m和yolov8-x模型耗时更长,但效果并没有非常突出

最终选择:yolov8-s-seg

2.2 特征网络

在这里插入图片描述
对比ResNet50、ResNet101、MobileNet、DenseNet121与DenseNet169
DenseNet169最为精准,在本数据集上提取特征能力和泛化性最强
最终选择:DenseNet169
最终余弦相似度测试:95.7%

3 跟踪与后处理

判断新旧物体进行跟踪,裁剪所需内容。
在这里插入图片描述

4 特征库优化

读取特征库时按类别进行字典嵌套
在这里插入图片描述

5 项目源码解析

5.1 yolo训练

train_yolo.py

作用:训练测试yolo模型

from ultralytics import YOLO
from PIL import Image
import cv2
# 版本有问题可以进行以下尝试
# import matplotlib
# matplotlib.use('TkAgg')

def yolo_seg_train():
    """Fine-tune the pretrained YOLOv8s segmentation checkpoint on the product dataset."""
    seg_model = YOLO(model=r"D:\zhrdpy_project\AutShop\model\yolov8s-seg.pt")
    seg_model.train(
        data="D:/zhrdpy_project/AutShop/data/YOLODataset/dataset.yaml",
        epochs=100,
        batch=-1,    # auto batch size
        workers=0,   # no dataloader worker subprocesses
        amp=False,   # mixed precision disabled
    )
def yolo_train():
    """Continue detector training from the previous best checkpoint.

    NOTE(review): the original constructed YOLO("yolov8n.yaml") first and then
    immediately overwrote it with the checkpoint load — dead work, removed.
    """
    model = YOLO("D:/zhrdpy_project/AutShop/runs/detect/train7/weights/best.pt")
    model.train(data="D:/zhrdpy_project/AutShop/data/YOLODataset/dataset.yaml", epochs=10, batch=-1, workers=0)
    # metrics = model.val()
def yolo_val():
    """Validate the trained segmentation checkpoint on the dataset's val split."""
    # model = YOLO("D:/zhrdpy_project/AutShop/model/yolov8n.yaml")
    model = YOLO(task='segment',model="D:/zhrdpy_project/AutShop/runs/segment/train_x/weights/best.pt")
    metrics = model.val(
        data='D:/zhrdpy_project/AutShop/data/YOLODataset/dataset.yaml',
        # imgsz=, batch= ,conf,iou,max_det,half,device,dnn
        save_json=True,  # save results to a JSON file              # default False
        save_hybrid=True,  # save hybrid labels (labels + extra predictions)  # default False
        # plots=True,         # show plots during validation        # default False
        rect=True,  # rectangular val: batches collated for minimal padding   # default False
    )  # no arguments needed, dataset and settings remembered
def yolo_test():
    """Run the trained segmentation model over a test video, saving rendered output."""
    # Checkpoint produced by training (runs/segment/train_x/weights/best.pt).
    model = YOLO("D:/zhrdpy_project/AutShop/runs/segment/train_x/weights/best.pt")
    # predict() accepts image/dir/Path/URL/video/PIL/ndarray; source="0" is the webcam.
    # results = model.predict(source="0")
    video_source = r"D:\zhrdpy_project\AutShop\data\test.mp4"
    results = model.predict(source=video_source, show=False, save=True)
    # success = model.export(format="onnx")
if __name__ == '__main__':
    # Uncomment exactly one entry point.
    # yolo_train()
    # yolo_seg_train()
    # yolo_test()
    yolo_val()

good_net.py

作用:特征网络

import torchvision.models as models
from torch import nn
import torch
from torch.nn import functional as F
from good_cls_data import one_hot_size
import math
class Arcsoftmax(nn.Module):
    """ArcFace-style additive angular margin softmax head.

    Projects L2-normalized features onto L2-normalized class weights, adds an
    angular margin ``m`` to the angle before the scaled softmax, and returns
    log-probabilities suitable for ``nn.NLLLoss``.
    """

    def __init__(self, feature_num, cls_num):
        super().__init__()
        # Learnable class-weight matrix, shape (feature_num, cls_num).
        self.w = nn.Parameter(torch.randn((feature_num, cls_num)))
        nn.init.kaiming_uniform_(self.w, a=math.sqrt(5))  # Kaiming uniform init of w (better than raw randn)

    def forward(self, feature, m=0.5, s=1):
        # Cosine similarity between normalized features and class weights.
        x = F.normalize(feature, dim=1)
        w = F.normalize(self.w, dim=0)
        # /10 keeps |cos| well inside (-1, 1), so acos and its gradient stay
        # finite (acos' gradient blows up at +/-1).
        cos = torch.matmul(x, w) / 10  #防止梯度爆炸 /10
        a = torch.acos(cos)
        # log-softmax with the margin m applied to every class angle.
        top = torch.exp(s * torch.cos(a + m))
        down = torch.sum(torch.exp(s * torch.cos(a)), dim=1, keepdim=True) - torch.exp(s * torch.cos(a)) + top
        out = torch.log(top/down)
        return out
    '''
    # 复杂但似乎没什么暖用的优化
    def forward(self, x, s=1, m=0.5):
        x_norm = F.normalize(x, dim=1)
        w_norm = F.normalize(self.w, dim=0)

        cos_theta = torch.matmul(x_norm, w_norm)
        theta = torch.acos(cos_theta.clamp(-1 + 1e-7, 1 - 1e-7))  # 添加钳位防止acos溢出
        cos_theta_m = cos_theta - m
        idx = cos_theta > math.pi - m  # 为防止溢出,对角度接近π的情况进行特殊处理
        cos_theta_m[idx] = torch.cos(theta[idx] + m)
        # 对于减法部分,需要广播以保持维度一致
        adjustment = torch.where(idx.unsqueeze(1), torch.exp(s * torch.cos(theta.unsqueeze(1))), torch.tensor(0., device=cos_theta.device))

        numerator = torch.exp(s * cos_theta_m)
        denominator = numerator.sum(dim=1, keepdim=True) - adjustment.sum(dim=1, keepdim=True) + numerator
        # 确保denominator非零
        denominator = torch.clamp(denominator, min=1e-10)
        arcface_loss = torch.log(numerator / denominator)
        
        这里的修改主要是为了确保减法操作前后的张量在形状上是一致的。
        通过使用unsqueeze(1)来增加维度,使得adjustment张量能够与numerator进行广播操作,
        同时使用where函数来确保只有在idx标记为True的位置上才进行exp(s * cos(theta))的计算,
        其他位置保持为0,这样可以正确地进行减法而不改变张量的形状。
        此外,为了避免对数函数中分母为零的问题,添加了clamp函数来限制最小值。

        return arcface_loss


'''




class GoodNet(nn.Module):
    """Product feature extractor: DenseNet169 backbone -> 512-d embedding + ArcFace head."""

    def __init__(self):
        super().__init__()
        # NLL pairs with the log-probabilities emitted by Arcsoftmax.
        self.nll_loss = nn.NLLLoss()
        self.loss_fn = nn.CrossEntropyLoss()
        # Backbone without pretrained weights; checkpoints named ***_new.pt match this setup.
        self.sub_net = models.densenet169(weights=None)
        # Project the backbone's 1000-d output down to a 512-d embedding.
        self.feature_net = nn.Sequential(
            nn.BatchNorm1d(1000),
            nn.LeakyReLU(0.1),
            nn.Linear(1000, 512, bias=False),
        )
        self.arc_softmax = Arcsoftmax(512, one_hot_size)

    def forward(self, x):
        """Return (embedding, arcface log-probabilities) for a batch of images."""
        embedding = self.feature_net(self.sub_net(x))
        log_probs = self.arc_softmax(embedding)
        return embedding, log_probs

    def get_feature(self, x):
        """Return only the 512-d embedding (registration / inference path)."""
        backbone_out = self.sub_net(x)
        return self.feature_net(backbone_out)

    def getSoftmaxLoss(self, outputs, labels):
        """NLL loss over the log-probabilities produced by the ArcFace head."""
        return self.nll_loss(outputs, labels)


if __name__ == '__main__':
    # Smoke test: build the network and load backbone/head weights separately.
    net = GoodNet()
    net.sub_net.load_state_dict(torch.load("weight/sub_net.pt"))
    net.feature_net.load_state_dict(torch.load("weight/feature_net.pt"))
    # net.load_state_dict(torch.load("weight/net.pt"))
    print(net)

dataset.py

作用:特征网络训练测试训练集加载器

import glob
import os.path

import cv2
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from good_net import one_hot_size
from good_cls_data import one_hot_dict

# Training pipeline: random flips then resize to the network's 320x320 input.
# NOTE: ToTensor runs first, so flip/resize operate on tensors (requires a
# torchvision version whose transforms accept tensor input).
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(p=0.5),  # horizontal flip with probability 0.5
    transforms.RandomVerticalFlip(p=0.5),  # vertical flip with probability 0.5
    transforms.Resize((320, 320), antialias=True)
])
# Evaluation pipeline: deterministic resize only.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((320, 320), antialias=True)
])
class TrainDataset(Dataset):
    """Training set: root/<cls>/<label>/<image> -> (label_index, image_tensor)."""

    def __init__(self, root=r"./data/CLSDataset_train", transform=train_transform):
        super().__init__()
        # Layout: root/<coarse_cls>/<label>/<image>. The label is the image's
        # parent directory name; os.path handles both / and \ separators
        # (the original rsplit('\\') only worked with Windows paths).
        img_paths = glob.glob(os.path.join(root, "*", "*", "*"))
        self.dataset = []
        for path in img_paths:
            label = os.path.basename(os.path.dirname(path))
            self.dataset.append((label, path))
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        label, img_path = self.dataset[idx]
        # Force 3 channels so RGBA/grayscale files can't break the 3-channel net.
        img = Image.open(img_path).convert('RGB')
        img_tensor = self.transform(img)
        # Class index as a 0-dim LongTensor — same value/type the original
        # produced via the redundant one_hot -> argmax round trip.
        label_idx = torch.tensor(one_hot_dict[label], dtype=torch.long)
        return label_idx, img_tensor

class TestDataset(Dataset):
    """Test set: root/<label>/<image> -> (label_index, image_tensor)."""

    def __init__(self, root=r"./data/CLSDataset_test", transform=test_transform):
        super().__init__()
        # Layout: root/<label>/<image>; the label is the parent directory name
        # (portable replacement for the Windows-only rsplit('\\')).
        img_paths = glob.glob(os.path.join(root, "*", "*"))
        self.dataset = []
        for path in img_paths:
            label = os.path.basename(os.path.dirname(path))
            self.dataset.append((label, path))
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        label, img_path = self.dataset[idx]
        img = Image.open(img_path).convert('RGB')  # guard against RGBA/grayscale files
        img_tensor = self.transform(img)
        # Same value/type as the original one_hot -> argmax round trip.
        label_idx = torch.tensor(one_hot_dict[label], dtype=torch.long)
        return label_idx, img_tensor



class TestDataset2(Dataset):
    """Batch-per-index test set: item ``i`` returns the i-th image of every class.

    Each __getitem__ yields one image per class (stacked), so a DataLoader with
    batch_size=1 covers every class once per iteration.
    """

    def __init__(self, root_dir=r"./data/CLSDataset_test2", transform=test_transform,
                 image_per_class=10):
        super().__init__()  # was missing in the original
        self.root_dir = root_dir
        # Keep only directories: a stray file in root_dir would otherwise make
        # the listdir below raise.
        self.classes = sorted(
            d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))
        )
        self.class_images = {
            cls: [os.path.join(cls, img) for img in os.listdir(os.path.join(root_dir, cls))]
            for cls in self.classes
        }
        self.transform = transform
        self.class_count = len(self.classes)
        self.image_per_class = image_per_class  # images per class (was hard-coded 10)

    def __len__(self):
        # One "batch" per image index; each batch holds class_count images.
        return self.image_per_class

    def __getitem__(self, index):
        images_of_batch = []
        labels_of_batch = []

        for class_index in range(self.class_count):
            class_name = self.classes[class_index]
            # Wrap around so every index yields an image for every class.
            img_path = self.class_images[class_name][index % self.image_per_class]
            img = Image.open(os.path.join(self.root_dir, img_path))

            if self.transform is not None:
                img = self.transform(img)

            images_of_batch.append(img)
            labels_of_batch.append(class_index)  # class id == position in self.classes

        # Pack the per-class images/labels into tensors for the caller.
        images_of_batch = torch.stack(images_of_batch)
        labels_of_batch = torch.tensor(labels_of_batch)

        return labels_of_batch, images_of_batch



if __name__ == '__main__':
    '''
    m = TestDataset()
    one_hot,img_tensor = m[3]
    img_path = r'D:\zhrdpy_project\AutShop\data\CLSDataset_train\bag\lao_mu_ji_tang_mian_dai_zhuang\3382.png'
    img = Image.open(img_path)
    img.show('1')
    imgcv2 = cv2.imread(img_path)
    cv2.imshow('2',imgcv2)
    cv2.waitKey(0)
    '''

    # Build the per-class batch dataset.
    dataset = TestDataset2()

    # batch_size=1 because each dataset item already contains one image per class.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    i = 0
    # Iterate once over the loader as a smoke test of labels/shapes.
    for labels ,images in dataloader:
        # test logic would go here
        print(labels)
        i+=1
        pass
    print(i)

good_cls_data.py

作用:通用分类文本信息

# Number of registered product classes (must equal len(one_hot_dict)).
one_hot_size = 25
# Registered product identifier (pinyin) -> class index for one-hot labels.
one_hot_dict = {
                'bai_shi_ke_le_guan_zhuang': 0,
                'ke_kou_ke_le_guan_zhuang': 1,
                'yong_chuang_tian_ya_guan_zhuang': 2,
                'jian_li_bao_guan_zhuang': 3,
                'le_shi_shu_pian_xiao_tong_huanggua_wei': 4,
                'xue_bi_guan_zhuang': 5,
                'qiao_ke_li_bing_gan_he_zhuang': 6,
                'bing_lu_ping_zhuang': 7,
                'le_shi_shu_pian_da_tong_huanggua_wei': 8,
                'le_shi_shu_pian_da_tong_ning_meng_wei': 9,
                'kang_shi_fu_hong_shao_niu_rou_mian_tong_zhuang': 10,
                'bing_ling_qi_bu_ding_he_zhuang': 11,
                'lao_tan_suan_cai_niu_rou_mian_tong_zhuang': 12,
                'a_sa_mu_nai_cha_ping_zhuang': 13,
                'mei_zhi_yuan_ping_zhuang': 14,
                'xing_qiu_bei_yi_tong': 15,
                'nong_fu_shan_quan_ping_zhuang': 16,
                'le_shi_shu_pian_dai_zhuang_niu_pai_wei': 17,
                'le_shi_shu_pian_dai_zhuang_huanggua_wei': 18,
                'lao_mu_ji_tang_mian_dai_zhuang': 19,
                'ou_de_cui_pian_dai_zhuang_jie_mo_wei': 20,
                'mai_xiang_ji_wei_kuai_dai_zhuang': 21,
                'wei_hua_bing_gan_he_zhuang': 22,
                'O_pao_nai_he_zhuang': 23,
                'chun_niu_nai_he_zhuang': 24,
                }
# Deliberately unregistered products (kept only as reference samples in the test set):
#   le_shi_shu_pian_xiao_tong_shaokao_wei       (乐事薯片小桶烧烤味)
#   lao_tan_pao_jiao_niu_rou_mian_tong_zhuang   (老坛泡椒牛肉面桶装)
#   xiang_jiao_niu_nai_he_zhuang                (香蕉牛奶盒装)
# class names (Chinese display names):
# 百事可乐罐装,可口可乐罐装,勇闯天涯罐装,健力宝罐装,乐事薯片小罐黄瓜味,雪碧罐装,乐事薯片小罐烧烤味,冰露瓶装,乐事薯片大罐黄瓜味,乐事薯片大罐柠檬味,康师傅红烧牛肉面,老坛泡椒牛肉面,老坛酸菜牛肉面,阿萨姆奶茶瓶装,美汁源瓶装,星球杯一桶,农夫山泉瓶装,乐事薯片袋装牛排味,乐事薯片袋装黄瓜味,老母鸡汤面袋装,藕的脆片袋装芥末味,麦香鸡味块袋装,威化饼干盒装,O泡奶盒装,纯牛奶盒装,冰淇淋布丁盒装,巧克力饼干盒装,香蕉牛奶盒装,
# Chinese display name -> pinyin identifier (includes the unregistered products above).
name_dict = {
    '百事可乐罐装':'bai_shi_ke_le_guan_zhuang',
    '可口可乐罐装':'ke_kou_ke_le_guan_zhuang',
    '勇闯天涯罐装':'yong_chuang_tian_ya_guan_zhuang',
    '健力宝罐装':'jian_li_bao_guan_zhuang',
    '乐事薯片小桶黄瓜味':'le_shi_shu_pian_xiao_tong_huanggua_wei',
    '雪碧罐装':'xue_bi_guan_zhuang',
    '乐事薯片小桶烧烤味':'le_shi_shu_pian_xiao_tong_shaokao_wei',
    '冰露瓶装':'bing_lu_ping_zhuang',
    '乐事薯片大桶黄瓜味':'le_shi_shu_pian_da_tong_huanggua_wei',
    '乐事薯片大桶柠檬味':'le_shi_shu_pian_da_tong_ning_meng_wei',
    '康师傅红烧牛肉面':'kang_shi_fu_hong_shao_niu_rou_mian_tong_zhuang',
    '老坛泡椒牛肉面桶装':'lao_tan_pao_jiao_niu_rou_mian_tong_zhuang',
    '老坛酸菜牛肉面桶装':'lao_tan_suan_cai_niu_rou_mian_tong_zhuang',
    '阿萨姆奶茶瓶装':'a_sa_mu_nai_cha_ping_zhuang',
    '美汁源瓶装':'mei_zhi_yuan_ping_zhuang',
    '星球杯一桶':'xing_qiu_bei_yi_tong',
    '农夫山泉瓶装':'nong_fu_shan_quan_ping_zhuang',
    '乐事薯片袋装牛排味':'le_shi_shu_pian_dai_zhuang_niu_pai_wei',
    '乐事薯片袋装黄瓜味':'le_shi_shu_pian_dai_zhuang_huanggua_wei',
    '老母鸡汤面袋装':'lao_mu_ji_tang_mian_dai_zhuang',
    '藕的脆片袋装芥末味':'ou_de_cui_pian_dai_zhuang_jie_mo_wei',
    '麦香鸡味块袋装':'mai_xiang_ji_wei_kuai_dai_zhuang',
    '威化饼干盒装':'wei_hua_bing_gan_he_zhuang',
    'O泡奶盒装':'O_pao_nai_he_zhuang',
    '纯牛奶盒装':'chun_niu_nai_he_zhuang',
    '冰淇淋布丁盒装':'bing_ling_qi_bu_ding_he_zhuang',
    '巧克力饼干盒装':'qiao_ke_li_bing_gan_he_zhuang',
    '香蕉牛奶盒装':'xiang_jiao_niu_nai_he_zhuang'
}

# pinyin identifier -> coarse container category ('can'/'bottle'/'bucket'/'bag'/'box'),
# used to restrict feature comparison to same-shaped products.
cls_dict = {
    'bai_shi_ke_le_guan_zhuang':'can',
    'ke_kou_ke_le_guan_zhuang':'can',
    'yong_chuang_tian_ya_guan_zhuang':'can',
    'jian_li_bao_guan_zhuang':'can',
    'le_shi_shu_pian_xiao_tong_huanggua_wei':'bucket',
    'xue_bi_guan_zhuang':'can',
    'le_shi_shu_pian_xiao_tong_shaokao_wei':'bucket',
    'bing_lu_ping_zhuang':'bottle',
    'le_shi_shu_pian_da_tong_huanggua_wei':'bucket',
    'le_shi_shu_pian_da_tong_ning_meng_wei':'bucket',
    'kang_shi_fu_hong_shao_niu_rou_mian_tong_zhuang':'bucket',
    'lao_tan_pao_jiao_niu_rou_mian_tong_zhuang':'bucket',
    'lao_tan_suan_cai_niu_rou_mian_tong_zhuang':'bucket',
    'a_sa_mu_nai_cha_ping_zhuang':'bottle',
    'mei_zhi_yuan_ping_zhuang':'bottle',
    'xing_qiu_bei_yi_tong':'bucket',
    'nong_fu_shan_quan_ping_zhuang':'bottle',
    'le_shi_shu_pian_dai_zhuang_niu_pai_wei':'bag',
    'le_shi_shu_pian_dai_zhuang_huanggua_wei':'bag',
    'lao_mu_ji_tang_mian_dai_zhuang':'bag',
    'ou_de_cui_pian_dai_zhuang_jie_mo_wei':'bag',
    'mai_xiang_ji_wei_kuai_dai_zhuang':'bag',
    'wei_hua_bing_gan_he_zhuang':'box',
    'O_pao_nai_he_zhuang':'box',
    'chun_niu_nai_he_zhuang':'box',
    'bing_ling_qi_bu_ding_he_zhuang':'box',
    'qiao_ke_li_bing_gan_he_zhuang':'box',
    'xiang_jiao_niu_nai_he_zhuang':'box'
}

from torchvision import transforms

# Built once at import time: the original rebuilt the Compose pipeline on
# every call, which is wasted work in the per-frame inference loop.
_SIZE_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((320, 320), antialias=True)
])


def img_transforms(image):
    """Convert a PIL image / ndarray to a (3, 320, 320) float tensor."""
    return _SIZE_TRANSFORM(image)

save_feature.py

作用:保存和读取特征文件

import torch
import glob
import cv2
import os
from good_cls_data import cls_dict
import good_net
from good_cls_data import img_transforms
import json
# Fallback (empty) feature registry returned when no feature file exists yet.
mydict = {}
# Destination file for the serialized feature dictionary.
filepath = "dictionary.pt"


def add_dict_feature(feature, name, my_dict):
    """Register ``feature`` under ``name`` in ``my_dict`` (overwrites an existing entry)."""
    my_dict[name] = feature
    print(f"{name}已写入字典")


def save_dict_feature(my_dict):
    """Serialize the whole feature dictionary to ``filepath`` via torch.save."""
    torch.save(my_dict, filepath)
    print("已保存字典到文件:", filepath)

def load_dict_feature():
    """Load the saved feature file and regroup it by product category.

    Returns {category: {sample_name: feature}}, where the category comes from
    ``cls_dict`` keyed on the sample name minus its trailing "_<index>" suffix.
    Returns the shared empty dict when no feature file exists yet.
    """
    try:
        loaded_dict = torch.load(filepath)
    except FileNotFoundError:
        # Original used a bare except:, which silently relabelled *any* error
        # (corrupt file, KeyError, even KeyboardInterrupt) as "file missing".
        print("没有找到特征文件")
        return mydict
    cls_goodF_dict = {}
    for name, good_feature in loaded_dict.items():
        # Strip the trailing "_<index>" to recover the product identifier.
        good_name_front = name.rsplit('_', maxsplit=1)[0]
        cls_goodF_dict.setdefault(cls_dict[good_name_front], {})[name] = good_feature
    return cls_goodF_dict


def log_all(): # deprecated after the nested-dict optimization
    """Register features for every image under log_img (old, memory-hungry path)."""
    net = good_net.GoodNet()
    net.load_state_dict(torch.load("weight/best_new957.pt"))
    net.eval()
    loggood_dict = load_dict_feature()
    # Not enough RAM to hold everything: save folder by folder or memory blows up.
    base_path = 'D:/zhrdpy_project/AutShop/data/log_img'
    dir_paths = glob.glob(os.path.join(base_path, "*"))
    for dir_path in dir_paths:
        img_paths = glob.glob(os.path.join(dir_path, "*"))
        for path in img_paths:
            # NOTE(review): '\\' split is Windows-specific; '/'-separated paths break here.
            img_name = path.rsplit('\\', maxsplit=2)[-1]
            name = img_name.split('.', maxsplit=2)[0]
            img = cv2.imread(path)
            image_tensor = img_transforms(img)
            image_tensor = torch.unsqueeze(image_tensor,dim=0)
            # No torch.no_grad() here -- the autograd graph is kept, hence the memory pressure.
            feature = net.get_feature(image_tensor)
            add_dict_feature(feature, name, loggood_dict)
        save_dict_feature(loggood_dict)
        loggood_dict.clear()   # drop references so GC can reclaim tensor memory
        loggood_dict = load_dict_feature()
'''
当我通过模型得到特征向量,并将其存储到字典中时,原先的特征向量(torch.Tensor对象)所占用的内存并不会立即被释放。
这是因为Python的垃圾回收机制(Garbage Collector,GC)并不保证在对象不再被引用时立即回收其占用的内存,尤其是在处理大型数据结构时。
GC的工作机制是周期性的,它会在内存压力达到一定程度或经过一定的时间间隔后运行,来清理不再使用的对象。
在我的场景中,每一次模型计算产生的torch.Tensor对象都会占用一定的内存空间,即使该对象随后被放入字典并可能被新的条目覆盖,
只要该对象还在某个地方被引用(就比如在我的字典中),它的内存就不会被立即释放。这意味着,如果我不清空字典,那些曾经存储在字典中的torch.Tensor对象,
它们的生命周期还没有结束(即还有其他引用指向它们),那么它们占用的内存就会持续存在,直到GC运行并确定它们确实不再被引用,才会回收这部分内存。
GC不会回收那些torch.Tensor对象所以我的内存就会爆,当我手动清空字典(使用dict.clear()或重新初始化字典)时,字典内部对所有torch.Tensor对象的引用都被移除,
此时如果这些对象没有其他外部引用,它们将变成孤立的对象,不再被任何变量引用,这就使得它们满足了被GC回收的条件。一旦GC运行,它就能检测到这些孤立的torch.Tensor对象,
并释放它们占用的内存。我需要的仅仅只是计算结果,所以重新读取的特征向量只有计算结果而其之前计算产生的torch.Tensor对象已经被回收了,我的内存就不会爆。
'''
def log_all_without_grad():
    """Register features for every image under log_img (gradient-free, low-memory).

    Walks D:/.../log_img/<class>/<image>, extracts one feature per image with
    GoodNet.get_feature under torch.no_grad(), and saves the complete
    {name: feature} dict once at the end.
    """
    net = good_net.GoodNet()
    net.load_state_dict(torch.load("weight/best_new957.pt"))
    net.eval()
    loggood_dict = {}
    base_path = 'D:/zhrdpy_project/AutShop/data/log_img'
    dir_paths = glob.glob(os.path.join(base_path, "*"))
    for dir_path in dir_paths:
        img_paths = glob.glob(os.path.join(dir_path, "*"))
        for path in img_paths:
            # os.path.basename handles both / and \ (the old rsplit('\\')
            # silently returned the whole path on '/'-separated paths).
            img_name = os.path.basename(path)
            name = img_name.split('.', maxsplit=2)[0]
            img = cv2.imread(path)
            if img is None:
                # cv2.imread returns None for unreadable/non-image files;
                # skip instead of crashing inside img_transforms.
                print("无法读取图片:", path)
                continue
            image_tensor = img_transforms(img)
            image_tensor = torch.unsqueeze(image_tensor, dim=0)
            with torch.no_grad():  # no autograd graph -> far lower memory, faster inference
                feature = net.get_feature(image_tensor)
            add_dict_feature(feature, name, loggood_dict)
    save_dict_feature(loggood_dict)

def log_single():
    """Register one hard-coded image into the saved feature dictionary."""
    net = good_net.GoodNet()
    net.load_state_dict(torch.load("weight/best_new957.pt"))
    net.eval()
    # NOTE(review): load_dict_feature returns a dict nested by category, while
    # the feature below is added under a flat key -- verify the round trip.
    loggood_dict = load_dict_feature()
    path = 'D:/zhrdpy_project/AutShop/data/log_img/xiang_jiao_niu_nai_he_zhuang/xiang_jiao_niu_nai_he_zhuang_1040.jpg'
    # Feature key = file name without directory or extension.
    img_name = path.rsplit('/', maxsplit=2)[-1]
    name = img_name.split('.', maxsplit=2)[0]
    # Read -> 320x320 tensor -> add batch dimension -> extract feature.
    img = cv2.imread(path)
    image_tensor = torch.unsqueeze(img_transforms(img), dim=0)
    feature = net.get_feature(image_tensor)
    add_dict_feature(feature, name, loggood_dict)
    print(name + '成功注册')
    save_dict_feature(loggood_dict)

if __name__ == '__main__':
    # Rebuild the feature file from all logged images, then print it back.
    log_all_without_grad()
    # log_single()
    # print('加载注册信息——————————————————————————————————————————')
    show = load_dict_feature()
    print(show)

注意:

  • 当我通过模型得到特征向量,并将其存储到字典中时,原先的特征向量(torch.Tensor对象)所占用的内存并不会立即被释放。这是因为Python的垃圾回收机制(Garbage Collector,GC)并不保证在对象不再被引用时立即回收其占用的内存,尤其是在处理大型数据结构时。
    GC的工作机制是周期性的,它会在内存压力达到一定程度或经过一定的时间间隔后运行,来清理不再使用的对象。在我的场景中,每一次模型计算产生的torch.Tensor对象都会占用一定的内存空间,即使该对象随后被放入字典并可能被新的条目覆盖,只要该对象还在某个地方被引用(就比如在我的字典中),它的内存就不会被立即释放。这意味着,如果我不清空字典,那些曾经存储在字典中的torch.Tensor对象,它们的生命周期还没有结束(即还有其他引用指向它们),那么它们占用的内存就会持续存在,直到GC运行并确定它们确实不再被引用,才会回收这部分内存。
    GC不会回收那些torch.Tensor对象所以我的内存就会爆,当我手动清空字典(使用dict.clear()或重新初始化字典)时,字典内部对所有torch.Tensor对象的引用都被移除,此时如果这些对象没有其他外部引用,它们将变成孤立的对象,不再被任何变量引用,这就使得它们满足了被GC回收的条件。一旦GC运行,它就能检测到这些孤立的torch.Tensor对象,并释放它们占用的内存。我需要的仅仅只是计算结果,所以重新读取的特征向量只有计算结果而其之前计算产生的torch.Tensor对象已经被回收了,我的内存就不会爆。
  • 另外可以直接禁用梯度计算(一开始没想到哈哈)
    with torch.no_grad(): # 禁用梯度计算以节省内存和加速推理

analyse_good.py

作用:分析与跟踪商品

import torch
import torch.hub
import numpy as np
import cv2
import glob
import os
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
from good_cls_data import *
from save_feature import *
from good_net import GoodNet
from ultralytics import YOLO
import torch.nn.functional as F
import matplotlib
import time
# Force the TkAgg backend so plotting works outside notebooks.
matplotlib.use('TkAgg')
# Font for drawing Chinese product names on frames.
FONT = ImageFont.truetype('simsun.ttc', size=30)
# Palette used to distinguish drawn items.
COLOR = ['blue', 'green', 'yellow', 'orange', 'purple', 'brown', 'red']
# Model checkpoint paths.
YOLO_LOAD = r'D:\zhrdpy_project\AutShop\weight\best_YOLOs.pt'
MY_NET_LOAD = "weight/best_new957.pt"
# Re-checks performed on a tracked item before its label is considered final.
DETECT_TIMES = 30
# Inference device: CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GoodDetect:
    def __init__(self):
        # YOLO segmentation model for product detection.
        self.yolo_model = self.load_yolomodel().to(device)
        # Feature-extraction model for product identification.
        self.screen_model = self.load_mymodel().to(device)
        # Registered product features, nested by category.
        self.loggood_dict = load_dict_feature() # {cls:{'name':feature}}
        # Per-frame caches used when deciding whether an object is new.
        self.good_positions = []
        self.good_names = []
        # Tracked goods and their state.
        self.goods_dic = {} # {id:{'name_cn':'xx', 'good_cls':good_cls, 'position':good_position, 'detect_times':1}}
        self.id_num = 0
        # Initial (blank) crop region shown before any detection happens.
        self.region_mask = np.zeros((320, 320, 3), dtype=np.uint8)

    def load_yolomodel(self):
        """Load the YOLO segmentation model used for product detection."""
        model = YOLO(YOLO_LOAD)
        # NOTE(review): recent ultralytics releases ignore this attribute; the
        # threshold is normally passed as conf= to predict() -- verify.
        model.conf = 0.85
        return model

    def load_mymodel(self):
        """Build GoodNet, load the trained weights, and switch to eval mode."""
        feature_net = GoodNet()
        state = torch.load(MY_NET_LOAD)
        feature_net.load_state_dict(state)
        feature_net.eval()
        return feature_net

    def compera_good(self, target_good_feature, good_cls, threshold=0.96):
        """Identify a product by cosine similarity against registered features.

        Only features registered under the same coarse class ``good_cls`` are
        compared. Among registered samples whose similarity reaches
        ``threshold``, the product matched most often wins; vote ties are
        broken by the highest similarity. Returns the product name, or
        'unknown' when nothing reaches the threshold.
        """
        matching_goods = {}   # product name -> number of samples above threshold
        good_similarity = {}  # product name -> best similarity seen (tie-breaker)
        for cls, dic in self.loggood_dict.items():
            if good_cls != cls:
                continue
            for name, good_feature in dic.items():
                similarity = F.cosine_similarity(target_good_feature, good_feature.to(device)).item()
                if similarity < threshold:
                    continue
                print('相似度达到阈值:' + name + ':' + str(similarity))
                # Drop the trailing "_<index>" to recover the product name.
                good_name_front = name.rsplit('_', maxsplit=1)[-2]
                # Track the best similarity per product.
                if similarity > good_similarity.get(good_name_front, -1.0):
                    good_similarity[good_name_front] = similarity
                # Count how many registered samples of this product matched.
                matching_goods[good_name_front] = matching_goods.get(good_name_front, 0) + 1

        # Nothing reached the threshold.
        if not matching_goods:
            return 'unknown'

        # Hoisted: the original recomputed max(matching_goods.values()) for
        # every item inside the comprehension (quadratic).
        top_votes = max(matching_goods.values())
        most_common_goods = [k for k, v in matching_goods.items() if v == top_votes]

        # Break vote ties by highest similarity.
        if len(most_common_goods) > 1:
            return max(most_common_goods, key=lambda x: good_similarity[x])
        return most_common_goods[0]

    """弃用
    def compera_good(self,target_good_feature,config=0.90):
        max_similarity = -1  # 初始化最大相似度为负数,保证之后会被更新
        good_name = None
        for name, good_feature in self.loggood_dict.items():
            similarity = F.cosine_similarity(target_good_feature, good_feature)
            if similarity.item() > max_similarity:
                max_similarity = similarity.item()
                good_name = name
        print(max_similarity)
        if max_similarity < config:
            good_name = 'unkown'
            return good_name
        return good_name
    """

    def compare_position(self, box1, box2, iou_threshold=0.3):
        """Return True when the IoU of two (x1, y1, x2, y2) boxes reaches the threshold.

        Used to decide whether a detection in the current frame is the same
        physical item as one tracked in a previous frame.
        """
        x1_1, y1_1, x2_1, y2_1 = box1
        x1_2, y1_2, x2_2, y2_2 = box2

        # Areas of both boxes.
        area_1 = (x2_1 - x1_1) * (y2_1 - y1_1)
        area_2 = (x2_2 - x1_2) * (y2_2 - y1_2)

        # Intersection rectangle, clamped to zero when the boxes don't overlap.
        intersect_x1 = max(x1_1, x1_2)
        intersect_y1 = max(y1_1, y1_2)
        intersect_x2 = min(x2_1, x2_2)
        intersect_y2 = min(y2_1, y2_2)
        intersection_area = max(0, intersect_x2 - intersect_x1) * max(0, intersect_y2 - intersect_y1)

        # Guard against degenerate zero-area boxes: the original raised
        # ZeroDivisionError here.
        union_area = area_1 + area_2 - intersection_area
        if union_area <= 0:
            return False

        iou = intersection_area / union_area
        # Same item if the overlap is large enough.
        return iou >= iou_threshold

    def get_region_mask(self,img,rect,box2,ls):
        """Crop the detected product out of ``img`` and store an upright square
        crop in ``self.region_mask``.

        ``rect`` is a cv2.minAreaRect result, ``box2`` its four corner points
        reshaped to (1, 4, 2), and ``ls`` the segmentation contour used as the
        pixel mask.
        """
        # All-black single-channel mask the size of the original image.
        b_mask = np.zeros(img.shape[:2], np.uint8)
        # Fill the segmentation contour into the mask.
        cv2.drawContours(b_mask, [ls], -1, (255, 255, 255), cv2.FILLED)

        (height, width) = rect[1]
        width = int(width)
        height = int(height)

        # Corners of the rotated rect map onto an axis-aligned width x height rect.
        src_points = box2.squeeze().astype(np.float32)
        dst_points = np.float32([[0, 0], [width, 0], [width, height], [0, height]])

        # Perspective transform matrix.
        M = cv2.getPerspectiveTransform(src_points, dst_points)

        # 3-channel black canvas the size of img.
        isolated_3_channel = np.zeros_like(img)

        # Copy only the masked pixels from the original image.
        isolated_3_channel[b_mask > 0] = img[b_mask > 0]

        # Warp the isolated product onto the upright rectangle.
        warped_image = cv2.warpPerspective(isolated_3_channel, M, (width, height))

        # Rotate so the crop is always portrait (height >= width).
        region = warped_image
        if width > height:
            region = cv2.rotate(region, cv2.ROTATE_90_COUNTERCLOCKWISE)
            width, height = height, width

        # Square black canvas; center the crop horizontally.
        region_mask = np.zeros((height, height, 3), dtype=np.uint8)
        # Horizontal offset to center the crop.
        start_x = (height - width) // 2
        start_y = (height - height) // 2  # NOTE(review): always 0 -- crop is top-aligned

        # Place the crop on the canvas.
        region_mask[start_y:start_y + height, start_x:start_x + width, :] = region
        # cv2.imshow('test',region_mask)

        self.region_mask = region_mask

    def region_to_tensor(self):
        """Square crop -> (1, 3, 320, 320) tensor for the feature network."""
        # NOTE(review): trans_square is not defined in any module shown here
        # (presumably arrives via a star import) -- verify it exists at runtime.
        return img_transforms(trans_square(Image.fromarray(self.region_mask))).unsqueeze(dim=0)

    def name_to_cn(self, good_name):
        """Map a pinyin product identifier back to its Chinese display name.

        Returns '未注册商品' for 'unknown'. If the identifier is missing from
        ``name_dict`` (a data inconsistency), the pinyin name itself is
        returned instead of raising IndexError like the original [0] lookup.
        """
        if good_name == 'unknown':
            return '未注册商品'
        # Reverse lookup in name_dict (Chinese -> pinyin).
        return next((cn for cn, py in name_dict.items() if py == good_name), good_name)

    def detect_vidio(self, img):
        draw_img = np.copy(img)
        results = self.yolo_model(img)
        if results[0] is not None and results[0].masks is not None:
            number = len(results[0].masks.xy)
            good_names = []
            goods_cls = results[0].boxes.cls.tolist()
            for i in range(number):
                new_good = False # 默认不是新商品
                good_cls = results[0].names[goods_cls[i]]
                contours = np.array(results[0].masks.xy[i], dtype=np.int32)
                ls = contours.reshape(-1, 1, 2)
                ln = contours.reshape(1, -1, 1, 2)

                # 最小矩形
                # 返回值格式为 (center, (width, height), angle),angle 描述的是最小外接矩形的长边(width 边)相对于最上边水平线的旋转角度。
                rect = cv2.minAreaRect(ls)
                box = cv2.boxPoints(rect)
                box2 = np.int32(box).reshape(-1, 4, 2)

                # 在原图上绘制最小外接矩形
                img_contour = cv2.polylines(draw_img, box2, True, (0, 0, 255), 3)
                img_contour = cv2.drawContours(draw_img, tuple(ln), -1, (0, 255, 0), 3)

                # 最小矩形的最小矩形坐标
                if number == 1:
                    x1, y1, x2, y2 = results[0].boxes.xyxy.cpu().numpy().squeeze().astype(np.int32)
                else:
                    x1, y1, x2, y2 = results[0].boxes.xyxy.cpu().numpy().squeeze().astype(np.int32)[i]

                # 记录此帧位置,最小矩形的最小矩形,保持画面水平一致
                good_position = (x1, y1, x2, y2)

                old_id = 0 # 记录需要操作的已有商品的id
                # 与上一帧位置对比
                if not self.goods_dic:
                    # 没有记录一定是新商品
                    new_good = True
                else:
                    # 遍历字典
                    for id in self.goods_dic:
                        # 筛选同类商品
                        if self.goods_dic[id]['good_cls'] == good_cls:
                            # 同类商品进行IOU进一步判断是否是新商品
                            if self.compare_position(good_position,self.goods_dic[id]['position']):
                                new_good = False
                                old_id = id # 记录需要操作的已有商品的id
                                '''为使熔断机制提前放弃有问题的帧,对已有商品的确认操作应置后
                                # 如果是存在过的商品则需要跟踪检测,已存在物体进行多次复查
                                if self.goods_dic[id]['detect_times'] <= DETECT_TIMES:
                                    # 裁剪出目标商品区域
                                    self.get_region_mask(img, rect, box2, ls)
                                    # 开始计时
                                    start_time = time.time()
                                    # 传递mask获取特征
                                    with torch.no_grad():  # 禁用梯度计算以节省内存和加速推理
                                        feature = self.screen_model.get_feature(self.region_to_tensor().to(device))
                                    # 结束计时
                                    end_time = time.time()
                                    # 计算推理时间ms
                                    inference_time = (end_time - start_time)*1000
                                    # 比较余弦相似度获取名称
                                    good_name = self.compera_good(feature,good_cls)
                                    name_cn = self.name_to_cn(good_name)
                                    # 判断之前检测结果是否有误
                                    if name_cn == self.goods_dic[id]['name_cn']:
                                        self.goods_dic[id]['detect_times'] += 1
                                        print(f"本次模型处理{self.goods_dic[id]['name_cn']}用时: {inference_time:.1f} ms ---第{self.goods_dic[id]['detect_times']}次检测无误")
                                    else:
                                        print(f"本次模型处理{self.goods_dic[id]['name_cn']}用时: {inference_time:.1f} ms ---发现检测错误,已更正为{name_cn}!")
                                        self.goods_dic[id]['name_cn'] = name_cn
                                        self.goods_dic[id]['detect_times'] = 1
                                    # 更新本帧位置信息
                                    self.goods_dic[id]['position'] = good_position
                                # 已达侦测次数视为达标
                                else:
                                    print(f'{self.goods_dic[id]["name_cn"]}达到检测次数{DETECT_TIMES}标准,已确认,不再检测')
                                    '''
                                # 确定当前商品不是新商品即可直接break
                                break
                            else:
                                new_good = True
                        # 没有同类别商品一定是新的商品
                        else:
                            new_good = True

                # 对新商品进行处理
                if new_good:
                    # 添加进goods字典,新商品初始化
                    id = self.id_num
                    self.id_num += 1
                    self.goods_dic[id] = {'name_cn':'xx', 'good_cls':good_cls, 'position':good_position, 'detect_times':1}
                    # 裁剪出目标商品区域
                    self.get_region_mask(img, rect, box2, ls)
                    # 开始计时
                    start_time = time.time()
                    # 传递mask获取特征
                    with torch.no_grad():  # 禁用梯度计算以节省内存和加速推理
                        feature = self.screen_model.get_feature(self.region_to_tensor().to(device))
                    # 结束计时
                    end_time = time.time()
                    # 计算推理时间
                    inference_time = (end_time - start_time)*1000
                    # 比较余弦相似度获取名称
                    good_name = self.compera_good(feature,good_cls)
                    name_cn = self.name_to_cn(good_name)
                    self.goods_dic[id]['name_cn'] = name_cn
                    print(f"本次模型处理{self.goods_dic[id]['name_cn']}用时: {inference_time:.1f} ms ---侦测到新商品!!!")

                # 熔断机制
                if len(self.goods_dic) > number:
                    self.goods_dic.clear()
                    return draw_img, self.region_mask, ['请勿遮挡商品!']

                # 对已有商品进行处理
                if new_good == False:
                    id = old_id
                    # 如果是存在过的商品则需要跟踪检测,已存在物体进行多次复查
                    if self.goods_dic[id]['detect_times'] <= DETECT_TIMES:
                        # 裁剪出目标商品区域
                        self.get_region_mask(img, rect, box2, ls)
                        # 开始计时
                        start_time = time.time()
                        # 传递mask获取特征
                        with torch.no_grad():  # 禁用梯度计算以节省内存和加速推理
                            feature = self.screen_model.get_feature(self.region_to_tensor().to(device))
                        # 结束计时
                        end_time = time.time()
                        # 计算推理时间ms
                        inference_time = (end_time - start_time) * 1000
                        # 比较余弦相似度获取名称
                        good_name = self.compera_good(feature, good_cls)
                        name_cn = self.name_to_cn(good_name)
                        # 判断之前检测结果是否有误
                        if name_cn == self.goods_dic[id]['name_cn']:
                            self.goods_dic[id]['detect_times'] += 1
                            print(
                                f"本次模型处理{self.goods_dic[id]['name_cn']}用时: {inference_time:.1f} ms ---第{self.goods_dic[id]['detect_times']}次检测无误")
                        else:
                            print(
                                f"本次模型处理{self.goods_dic[id]['name_cn']}用时: {inference_time:.1f} ms ---发现检测错误,已更正为{name_cn}!")
                            self.goods_dic[id]['name_cn'] = name_cn
                            self.goods_dic[id]['detect_times'] = 1
                        # 更新本帧位置信息
                        self.goods_dic[id]['position'] = good_position
                    # 已达侦测次数视为达标
                    else:
                        print(f'{self.goods_dic[id]["name_cn"]}达到检测次数{DETECT_TIMES}标准,已确认,不再检测')

            # 返回当前帧信息
            if self.goods_dic:
                for idx,id in enumerate(self.goods_dic):
                    good_names.append(self.goods_dic[id]['name_cn'])
                    draw_img = cv2ImgAddText(draw_img, self.goods_dic[id]['name_cn'], self.goods_dic[id]['position'][0], self.goods_dic[id]['position'][1] - 30, textColor=COLOR[idx % len(COLOR)])

            return draw_img,self.region_mask,good_names
        return draw_img,self.region_mask,['没有商品']

    def detect_img(self, img, log=True, log_name='0'):
        """Run single-image detection / registration.

        YOLO segmentation finds each product instance; each instance is cut
        out along its mask, perspective-warped to an upright rectangle,
        padded to a square, and fed to the feature network.  With
        ``log=True`` the extracted feature is registered under ``log_name``;
        otherwise it is matched against the feature store by cosine
        similarity.

        Args:
            img: BGR image (numpy array) as read by cv2.
            log: True to register the detected good(s), False to recognize.
            log_name: registration label used when ``log`` is True.

        Returns:
            Tuple of (annotated BGR image, last square crop, list of result
            strings — one per detected instance, or a "no goods" message).
        """
        draw_img = np.copy(img)
        # Fallback crops (top-left 320x320) returned when nothing is detected.
        region = img[0:320, 0:320]
        region_mask = img[0:320, 0:320]
        results = self.yolo_model(img)
        if results[0] is not None and results[0].masks is not None:
            number = len(results[0].masks.xy)
            good_names = []
            goods_cls = results[0].boxes.cls.tolist()
            for i in range(number):
                good_cls = results[0].names[goods_cls[i]]
                # Mask polygon of instance i, as integer pixel coordinates.
                contours = np.array(results[0].masks.xy[i], dtype=np.int32)
                ls = contours.reshape(-1, 1, 2)    # OpenCV contour layout
                ln = contours.reshape(1, -1, 1, 2)
                # All-black single-channel mask, same H/W as the input image.
                b_mask = np.zeros(img.shape[:2], np.uint8)
                # Fill the mask polygon.
                cv2.drawContours(b_mask, [ls], -1, (255, 255, 255), cv2.FILLED)

                # Cut the masked region out of the original image (deprecated 2)
                # extracted_img = cv2.bitwise_and(img, img, mask=b_mask)
                # Mask composite image (deprecated)
                # isolated = np.dstack([img, b_mask])

                # Minimum-area rotated rectangle around the mask.
                # Returned as (center, (width, height), angle); the angle describes
                # the rotation of the rectangle's long ("width") edge relative to
                # the horizontal.
                rect = cv2.minAreaRect(ls)
                box = cv2.boxPoints(rect)
                box2 = np.int32(box).reshape(-1, 4, 2)

                # NOTE(review): cv2 reports rect[1] as (width, height); unpacking
                # it as (height, width) swaps the two.  The width>height rotation
                # below appears to compensate — confirm intent.
                (height, width) = rect[1]
                width = int(width)
                height = int(height)

                src_points = box2.squeeze().astype(np.float32)
                dst_points = np.float32([[0, 0], [width, 0], [width, height], [0, height]])

                # Perspective transform mapping the rotated box to an upright rect.
                M = cv2.getPerspectiveTransform(src_points, dst_points)

                # Create contour mask (second fill of b_mask; redundant with the
                # drawContours call above).
                _ = cv2.drawContours(b_mask, [ls], -1, (255, 255, 255), cv2.FILLED)
                # 3-channel black canvas, same size as img.
                isolated_3_channel = np.zeros_like(img)

                # Copy the original pixels through the mask onto the canvas.
                isolated_3_channel[b_mask > 0] = img[b_mask > 0]

                # Apply the perspective warp.
                warped_image = cv2.warpPerspective(isolated_3_channel, M, (width, height))

                # Upright minimum bounding rectangle: rotate so height >= width.
                region = warped_image
                if width > height:
                    region = cv2.rotate(region, cv2.ROTATE_90_COUNTERCLOCKWISE)
                    width, height = height, width

                # Square black canvas sized by the long side.
                region_mask = np.zeros((height, height, 3), dtype=np.uint8)
                # Horizontal offset centring the crop; start_y is always 0.
                start_x = (height - width) // 2
                start_y = (height - height) // 2

                # Paste the crop onto the square canvas.
                region_mask[start_y:start_y + height, start_x:start_x + width, :] = region
                # cv2.imshow('test',region_mask)

                warped_image = Image.fromarray(region_mask)
                warped_image = trans_square(warped_image)
                warped_image = img_transforms(warped_image)
                warped_image = warped_image.unsqueeze(dim=0)

                """
                # 弃用2
                angle = rect[-1]
                # 计算旋转角度,使短边为宽
                if rect[1][0] > rect[1][1]:
                    angle = -(90 - angle)

                # 获取旋转矩阵  正角度(正数)表示逆时针旋转。
                # 负角度(负数)表示顺时针旋转。
                center = (rect[0][0], rect[0][1])
                M = cv2.getRotationMatrix2D(center, angle, 1.0)
                # 旋转图像
                (h, w) = extracted_img.shape[:2]
                rotated = cv2.warpAffine(extracted_img, M, (w, h))
                # 提取旋转后的矩形区域
                width, height = rect[1][0], rect[1][1]
                if width > height:
                    width, height = height, width
                cropped_rotated = cv2.getRectSubPix(rotated, (int(width), int(height)), center)
                # 将numpy数组转换为PIL图像
                pil_image = Image.fromarray((cropped_rotated).astype('uint8'))
                # 转换tensor
                input_tensor = img_transforms(pil_image).unsqueeze(dim=0)
                """
                # Axis-aligned box of instance i; squeeze() drops the batch axis
                # when there is exactly one detection.
                if number == 1:
                    x1, y1, x2, y2 = results[0].boxes.xyxy.cpu().numpy().squeeze().astype(np.int32)
                else:
                    x1, y1, x2, y2 = results[0].boxes.xyxy.cpu().numpy().squeeze().astype(np.int32)[i]
                # Crop of the mask region (deprecated)
                # iso_crop = isolated[y1:y2, x1:x2]
                # iso_crop_3channel = iso_crop[:, :, :3]

                # Time the feature extraction.
                start_time = time.time()
                # NOTE(review): the tensor built above (warped_image) is not used
                # here; region_to_tensor() presumably re-reads a stored region —
                # confirm which input the feature net should receive.
                with torch.no_grad():  # disable autograd to save memory and speed up inference
                    feature = self.screen_model.get_feature(self.region_to_tensor().to(device))
                # Stop the timer.
                end_time = time.time()
                # Inference time in milliseconds.
                inference_time = (end_time - start_time)*1000
                print(f"本次模型处理用时: {inference_time:.1f} ms")

                if log == True:
                    # Registration mode: store the feature under the given name.
                    good_name = log_name
                    cn_name = good_name
                    add_dict_feature(feature,good_name,self.loggood_dict)
                    good_names.append(good_name + '成功注册')
                    # save_dict_feature(self.loggood_dict)
                else:
                    # Recognition mode: cosine-similarity lookup within the class.
                    good_name = self.compera_good(feature,good_cls)
                    if good_name == 'unknown':
                        cn_name = '未注册商品'
                    else:
                        # Reverse lookup: Chinese display name for the english key.
                        cn_name = [k for k, v in name_dict.items() if v == good_name][0]
                    good_names.append(cn_name)


                # Draw label, rotated box (red) and mask contour (green).
                draw_img = cv2ImgAddText(draw_img,cn_name,x1,y1-30,textColor=COLOR[i % len(COLOR)])
                img_contour = cv2.polylines(draw_img, box2, True, (0, 0, 255), 3)
                img_contour = cv2.drawContours(draw_img, tuple(ln), -1, (0, 255, 0), 3)
                # cv2.imwrite(base_path + '/' + f'{good_name}_{i}.jpg', region_mask, [cv2.IMWRITE_JPEG_QUALITY, 100])
            return draw_img,region_mask,good_names
        return draw_img,region,['没有商品']

def cv2ImgAddText(img, text, left, top, textColor=(0, 255, 0)):
    """Render ``text`` at (left, top) on an image and return a BGR ndarray.

    OpenCV's own putText cannot draw CJK glyphs, so the image is routed
    through PIL: BGR -> RGB, draw with the module-level FONT, then back.
    """
    if isinstance(img, np.ndarray):
        # OpenCV images are BGR; PIL expects RGB.
        img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    ImageDraw.Draw(img).text((left, top), text, textColor, font=FONT)
    # Convert back to the OpenCV BGR layout.
    return cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)

def trans_square(image):
    """Pad a PIL image with black to a centered square.

    The canvas side equals the longer image axis; the image is centered
    along its shorter axis.  Equivalent to the previous transpose-based
    implementation, but without the double axis swap.

    Args:
        image: any PIL image; converted to RGB first.

    Returns:
        A square RGB PIL image.
    """
    img = np.array(image.convert('RGB'), dtype=np.uint8)
    img_h, img_w, img_c = img.shape
    if img_h != img_w:
        long_side = max(img_w, img_h)
        # Offsets that center the image on the square canvas; the offset
        # along the long axis is 0.
        y0 = (long_side - img_h) // 2
        x0 = (long_side - img_w) // 2
        background = np.zeros((long_side, long_side, img_c), dtype=np.uint8)
        background[y0:y0 + img_h, x0:x0 + img_w] = img
        img = background
    return Image.fromarray(img, 'RGB')

def get_log_img():
    """Register every image under ``base_path`` (one sub-folder per product).

    The sub-folder name is used both as the registration label and as part
    of the saved crop's filename.  Paths are decomposed with ``os.path``
    instead of the previous ``'\\\\'``-based splitting, which only worked on
    Windows path separators.
    """
    gd = GoodDetect()
    base_path = 'D:/zhrdpy_project/AutShop/data/log_org'
    for img_path in glob.glob(os.path.join(base_path, "*", "*")):
        save_dir = os.path.dirname(img_path)
        good_name = os.path.basename(save_dir)          # product = parent folder name
        save_name = os.path.splitext(os.path.basename(img_path))[0]
        img = cv2.imread(img_path)
        # log=True registers the crop's feature under the folder name.
        draw_img, region, good_names = gd.detect_img(img, log=True, log_name=good_name)
        cv2.imwrite(os.path.join(save_dir, f'{good_name}_{save_name}.jpg'), region,
                    [cv2.IMWRITE_JPEG_QUALITY, 100])
        # os.remove(img_path)  # optionally delete the source after registration

if __name__ == '__main__':
    # get_log_img()

    # Quick manual check: recognize the goods in one test image and show
    # the annotated frame plus the extracted square crop.
    detector = GoodDetect()
    test_dir = 'D:/zhrdpy_project/AutShop/data/log_img_test/'
    frame = cv2.imread(test_dir + '2.jpg')
    annotated, crop, labels = detector.detect_img(frame, log=False)
    cv2.imshow("draw_img", annotated)
    cv2.imshow("region", crop)
    print(labels)
    cv2.waitKey(0)



shop_window.py

作用:程序的图形界面与接口

from tkinter import *
from tkinter import filedialog
from PIL import Image, ImageTk
from analyse_good import *
from good_cls_data import name_dict
import threading
lock = threading.Lock()

class Window_shop():
    """Tkinter front-end for the product-detection demo.

    Workflow: pick an image or video ("路径选择"), then either run detection
    ("开始检测") or register the shown product into the feature store
    ("开始注册", which first pops up a small naming dialog).
    """
    def __init__(self):
        self.root = Tk()
        self.img_Label = Label(self.root)     # main preview pane
        self.img_outLabel = Label(self.root)  # cropped-region preview pane
        self.txt = Text(self.root)            # result text box
        self.detect = GoodDetect()
        self.img = None        # currently loaded BGR frame (numpy array)
        self.Type = None       # extension of the selected file
        self.no_img = True     # True until a file has been chosen
        self.imgshow_width = 0
        self.imgshow_height = 0
        self.img_ratio = 0.5   # preview downscale factor
        self.var_name = StringVar()  # selected file path
        # Sentinel values meaning "name not entered yet".
        self.logname = 'logname'
        self.cnname = 'cnname'

    # Choose the input file
    def selectPath_file(self):
        """Open a file dialog; load an image, or a poster frame for a video."""
        path_ = filedialog.askopenfilename(filetypes=[("图片或视频", [".jpg",".png", ".MOV", ".mp4"])])
        self.var_name.set(path_)
        self.Type = self.var_name.get().rsplit('.', maxsplit=2)[-1]
        if self.Type == 'jpg' or self.Type == 'png':
            self.no_img = False
            self.img = cv2.imread(self.var_name.get())
            img = Image.fromarray(cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB))
            self.img_width = int(img.width*self.img_ratio)
            self.img_height = int(img.height*self.img_ratio)
            # Image.LANCZOS replaces Image.ANTIALIAS (the old alias was
            # removed in Pillow 10); it is the same resampling filter.
            img = img.resize((self.img_width, self.img_height), Image.LANCZOS)
            photo = ImageTk.PhotoImage(img)
            self.img_Label.config(image=photo)
            self.img_Label.image = photo  # keep a reference so Tk doesn't GC it
        if self.Type == "MOV" or self.Type == "mp4":
            self.no_img = False
            self.img = cv2.imread('data/vidio_show.jpg')
            img = Image.fromarray(cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB))
            photo = ImageTk.PhotoImage(img)
            self.img_Label.config(image=photo)
            self.img_Label.image = photo
            self.txt.delete(1.0, END)  # clear the text box
            export = self.var_name.get()
            self.txt.insert(END, export)  # show the chosen path

    def detect_img_start(self):
        """Run recognition on the loaded image and display the results."""
        self.txt.delete(1.0, END)  # clear the text box
        if self.no_img:
            print("请选择图片或视频")
            export = "请选择图片或视频"
            self.txt.insert(END, export)  # prompt the user to pick a file
            return 0

        draw_img,region,good_list = self.detect.detect_img(self.img,log=False)
        export = str(good_list)

        img_show = Image.fromarray(cv2.cvtColor(draw_img, cv2.COLOR_BGR2RGB))
        img_out = Image.fromarray(cv2.cvtColor(region, cv2.COLOR_BGR2RGB))

        # LANCZOS replaces the removed ANTIALIAS alias (Pillow 10+).
        img_show = img_show.resize((self.img_width, self.img_height), Image.LANCZOS)
        photo = ImageTk.PhotoImage(img_show)
        self.img_Label.config(image=photo)
        self.img_Label.image = photo

        img_out = img_out.resize((320, 320), Image.LANCZOS)
        photo_out = ImageTk.PhotoImage(img_out)
        self.img_outLabel.config(image=photo_out)
        self.img_outLabel.image = photo_out

        self.txt.insert(END, export)  # append the detection result

    def login_start(self):
        """Register the loaded image: ask for names, then store its feature."""
        self.txt.delete(1.0, END)  # clear the text box
        if self.no_img:
            print("请选择图片或视频")
            export = "请选择图片或视频"
            self.txt.insert(END, export)  # prompt the user to pick a file
            return 0

        with lock:
            # First click opens the naming dialog; once both names have been
            # entered, a later click performs the actual registration.
            if self.logname == 'logname' or self.cnname == 'cnname':
                self.get_good_name()
            if self.logname != 'logname' and self.cnname != 'cnname':
                draw_img, region, good_list = self.detect.detect_img(self.img, log=True, log_name=self.logname)
                export = str(good_list)

                img_show = Image.fromarray(cv2.cvtColor(draw_img, cv2.COLOR_BGR2RGB))
                img_out = Image.fromarray(cv2.cvtColor(region, cv2.COLOR_BGR2RGB))

                # LANCZOS replaces the removed ANTIALIAS alias (Pillow 10+).
                img_show = img_show.resize((self.img_width, self.img_height), Image.LANCZOS)
                photo = ImageTk.PhotoImage(img_show)
                self.img_Label.config(image=photo)
                self.img_Label.image = photo

                # img_out = img_out.resize((320, 320), Image.LANCZOS)
                photo_out = ImageTk.PhotoImage(img_out)
                self.img_outLabel.config(image=photo_out)
                self.img_outLabel.image = photo_out

                self.txt.insert(END, export)  # append the registration result
                # Reset the sentinels for the next registration.
                self.logname = 'logname'
                self.cnname = 'cnname'

    def log_mane(self,logname,cnname):
        """Dialog callback: remember both names and extend the name table.

        NOTE(review): the method name looks like a typo for ``log_name``;
        kept unchanged for backward compatibility with existing callers.
        """
        self.logname = logname
        self.cnname = cnname
        if cnname not in name_dict:
            name_dict[cnname] = logname

    def get_good_name(self):
        """Pop up a small dialog asking for registration / display names."""
        # A Toplevel child window, not a second Tk() root: multiple Tk roots
        # conflict over the default root and the event loop.
        namewin = Toplevel(self.root)
        namewin.geometry("250x130")
        namewin.title("商品命名窗口")
        # Two prompt labels
        logname = Label(namewin, text='注册名称', width=80)
        cnname = Label(namewin, text='中文名称', width=80)
        # Two entry fields
        entlog = Entry(namewin, width=100)
        entcn = Entry(namewin, width=100)
        # Two buttons: confirm stores the names, close destroys the dialog
        name_ok = Button(namewin, text='确认', command=lambda: self.log_mane(entlog.get(), entcn.get()))
        namewin_quit = Button(namewin, text='关闭', command=lambda: namewin.destroy())
        # -- widget layout --
        logname.place(x=20, y=10, width=80, height=20)
        cnname.place(x=20, y=40, width=80, height=20)
        entlog.place(x=120, y=10, width=80, height=20)
        entcn.place(x=120, y=40, width=80, height=20)
        name_ok.place(x=100, y=80, width=50, height=20)
        namewin_quit.place(x=170, y=80, width=50, height=20)

    def choose_imgorvidio(self):
        """Dispatch on file type: single-image detection or video tracking."""
        if self.Type == 'jpg' or self.Type == 'png':
            self.detect_img_start()
        if self.Type == "MOV" or self.Type == "mp4":
            cap = cv2.VideoCapture(self.var_name.get())
            while cap.isOpened():
                retval, frame = cap.read()
                if not retval:
                    print('can not read frame')
                    break
                # Per-frame detection with tracking.
                draw_img,region,good_list = self.detect.detect_vidio(frame)
                cv2.imshow("draw_img", draw_img)
                cv2.imshow("region", region)
                print(good_list)

                key = cv2.waitKey(42)  # ~24 fps playback; 'q' quits
                if key == ord('q'):
                    break
            # Release resources
            cap.release()
            cv2.destroyAllWindows()

    def run(self):
        """Build the window layout and enter the Tk main loop."""
        # Window
        self.root.title('商品自动检测')
        self.root.geometry('1000x800')  # note: the separator is the letter 'x', not '*'
        # Title banner
        lb_top = Label(self.root, text='商品自动检测程序',
                   bg='#d3fbfb',
                   fg='red',
                   font=('华文新魏', 32),
                   width=20,
                   height=2,
                   relief=SUNKEN)
        lb_top.pack()

        # Result text box
        self.txt.place(rely=0.8, relwidth=1, relheight=0.3)

        # Buttons
        btn2 = Button(self.root, text='开始检测', command=lambda: self.choose_imgorvidio()).place(relx=0.7, rely=0.14, relwidth=0.2, relheight=0.08)
        btn1 = Button(self.root, text='开始注册', command=lambda: self.login_start()).place(relx=0.4, rely=0.14, relwidth=0.2, relheight=0.08)
        btn0 = Button(self.root, text="路径选择", command=lambda: self.selectPath_file()).place(relx=0.1, rely=0.14, relwidth=0.2, relheight=0.08)

        # Image panes
        self.img_Label.place(relx=0.05, rely=0.25, relwidth=0.65, relheight=0.5)
        self.img_outLabel.place(relx=0.72, rely=0.25, relwidth=0.23, relheight=0.5)

        self.root.mainloop()

if __name__ == '__main__':
    # Launch the GUI.
    Window_shop().run()

test.py

作用:不包括在项目中,但可能用到的一些小方法

import os.path
import time
import glob
import os
import cv2
import torch
import good_net
def rename(img_folder, prefix='1000'):
    """Prepend ``prefix`` to the name of every file directly in ``img_folder``.

    Args:
        img_folder: directory whose files are renamed in place.
        prefix: string prepended to each existing filename.  Defaults to
            '1000', matching the previous hard-coded behavior.
    """
    for img_name in os.listdir(img_folder):  # snapshot of current entries
        src = os.path.join(img_folder, img_name)
        dst = os.path.join(img_folder, prefix + img_name)
        os.rename(src, dst)

def delete(base_path='D:/zhrdpy_project/AutShop/data/log_img'):
    """Delete every '.jpg' file found exactly three levels below ``base_path``.

    Args:
        base_path: root folder to sweep.  Defaults to the original
            hard-coded path, so existing ``delete()`` calls behave as before.
    """
    for img_path in glob.glob(os.path.join(base_path, "*", "*", "*")):
        # splitext is separator-agnostic, unlike the old '\\' rsplit.
        if os.path.splitext(img_path)[1] == '.jpg':
            os.remove(img_path)

def png2jpg(folder=r"D:\zhrdpy_project\AutShop\data\CLSDataset_test\box\xiang_jiao_niu_nai_he_zhuang"):
    """Re-encode every image in ``folder`` as a max-quality JPEG, deleting the source.

    Args:
        folder: directory to convert.  Defaults to the original hard-coded
            path so existing ``png2jpg()`` calls behave as before.
    """
    for path in glob.glob(os.path.join(folder, "*")):
        img = cv2.imread(path)
        # Build the target name portably (the old '\\' rsplit was Windows-only;
        # also fixes the 'targe_path' typo).
        target_dir = os.path.dirname(path)
        save_name = os.path.splitext(os.path.basename(path))[0]
        cv2.imwrite(os.path.join(target_dir, save_name + '.jpg'), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
        os.remove(path)

def main():
    """Batch-rename the images inside the hard-coded label folder.

    Point the path at the folder itself (e.g. the train/ or test/ subfolder).
    """
    folder = r'D:\zhrdpy_project\grayimg\label\lll'
    rename(folder)


def save_densnet():
    """Load the best GoodNet checkpoint into eval mode and print the model.

    The commented-out lines show how to export the densenet backbone and the
    feature head as separate state dicts.
    """
    net = good_net.GoodNet()
    net.load_state_dict(torch.load("weight/best.pt"))  # best checkpoint
    net.eval()
    print(net)
    # torch.save(net.sub_net.state_dict(), 'weight/sub_net.pt')  # densenet part
    # torch.save(net.feature_net.state_dict(), 'weight/feature_net.pt')  # feature_net part


if __name__=="__main__":
    save_densnet()


6 结语

一个简单的商品检测项目,数据集读者可以自行拍摄,录制视频抽帧即可训练模型了。
有什么交流意见可以评论或者私信我。
这里放一个展示视频:

商品检测效果视频

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值