用TensorFlow训练自己的第一个模型

现在学AI的一个优势就是:前人栽树后人乘凉,很多资料都已完善,而且有很多很棒的开源作品可以学习,感谢大佬们

项目

项目源码地址
视频教程地址
我在大佬的基础上基于此模型还加上了根据特征值缓存进行快速识别的方法,以应对超市某些未能正确识别的场景,针对项目中window.py文件进行修改和补充:

  1. 缓存查询方法
 def query_cache(self, image_features):
        if not cache:
            return None
        max_similarity = -1
        best_label = None
        for image_id, (cached_features, label) in cache.items():
            similarity = self.cosine_similarity(image_features, cached_features)
            if similarity > max_similarity:
                max_similarity = similarity
                best_label = label

        if max_similarity >= 0.5:
            return best_label
        else:
            return None
  1. 余弦相邻计算
 def cosine_similarity(self, features1, features2):
        dot_product = np.dot(features1.flatten(), features2.flatten())
        norm_features1 = np.linalg.norm(features1)
        norm_features2 = np.linalg.norm(features2)
        return dot_product / (norm_features1 * norm_features2)
  1. 缓存更新方法
 def update_cache(self):
        input_text = self.input_box.text()
        self.label = input_text or self.label
        self.result.setText(self.label)
        # 如果缓存已满,移除最久未使用的条目
        if len(cache) >= CACHE_CAPACITY:
            cache.popitem(last=False)
        # 添加新条目
        cache[self.image_id] = (self.image_features, self.label)
        self.input_box.clear()
        self.class_names.append(self.label)
  1. 获取图片哈希值
    def get_image_id_from_hash(self, img):
        buffer = img.tobytes()
        return hashlib.md5(buffer).hexdigest()
  1. 预测图片
    def predict_img(self):
        self.input_box.clear()
        img = Image.open('images/target.png')  # 读取图片
        self.image_id = self.get_image_id_from_hash(img)
        img = np.asarray(img)  # 将图片转化为numpy的数组
        start_time = time.time()  # 记录开始时间
        outputs = self.model.predict(img.reshape(1, 224, 224, 3), batch_size=1, )  # 将图片输入模型得到结果
        end_time = time.time()  # 记录结束时间
        elapsed_time = end_time - start_time  # 计算时间差
        print("运行时间:", elapsed_time, "秒")
        self.image_features = outputs
        result = self.query_cache(outputs)
        self.label = result
        if result is None:
            result_index = int(np.argmax(outputs))
            result = self.class_names[result_index]  # 获得对应的水果名称
            self.result.setText(result)
            self.label = result
        else:
            self.result.setText(result)  # 在界面上做显示
  1. UI改造
    def initUI(self):
        main_widget = QWidget()
        main_layout = QHBoxLayout()
        font = QFont('楷体', 15)

        # 主页面,设置组件并在组件放在布局上
        left_widget = QWidget()
        left_layout = QVBoxLayout()
        img_title = QLabel("样本")
        img_title.setFont(font)
        img_title.setAlignment(Qt.AlignCenter)
        self.img_label = QLabel()
        img_init = cv2.imread(self.to_predict_name)
        h, w, c = img_init.shape
        scale = 400 / h
        img_show = cv2.resize(img_init, (0, 0), fx=scale, fy=scale)
        cv2.imwrite("images/show.png", img_show)
        img_init = cv2.resize(img_init, (224, 224))
        cv2.imwrite('images/target.png', img_init)
        self.img_label.setPixmap(QPixmap("images/show.png"))
        left_layout.addWidget(img_title)
        left_layout.addWidget(self.img_label, 1, Qt.AlignCenter)
        left_widget.setLayout(left_layout)
        right_widget = QWidget()
        right_layout = QVBoxLayout()
        btn_change = QPushButton(" 上传图片 ")
        btn_change.clicked.connect(self.change_img)
        btn_change.setFont(font)
        btn_predict = QPushButton(" 开始识别 ")
        btn_predict.setFont(font)
        btn_predict.clicked.connect(self.predict_img)
        btn_update = QPushButton(" 更新缓存 ")
        btn_update.setFont(font)
        btn_update.clicked.connect(self.update_cache)
        label_result = QLabel(' 果蔬名称 ')
        self.result = QLabel("等待识别")
        label_result.setFont(QFont('楷体', 16))
        self.result.setFont(QFont('楷体', 24))
        self.input_box = QLineEdit()
        self.input_box.setPlaceholderText("请输入内容...")
        right_layout.addStretch()
        right_layout.addWidget(label_result, 0, Qt.AlignCenter)
        right_layout.addStretch()
        right_layout.addWidget(self.result, 0, Qt.AlignCenter)
        right_layout.addStretch()
        right_layout.addWidget(self.input_box)
        right_layout.addStretch()
        right_layout.addWidget(btn_change)
        right_layout.addWidget(btn_predict)
        right_layout.addWidget(btn_update)
        right_layout.addStretch()
        right_widget.setLayout(right_layout)
        main_layout.addWidget(left_widget)
        main_layout.addWidget(right_widget)
        main_widget.setLayout(main_layout)

        # 关于页面,设置组件并把组件放在布局上
        label_super = QLabel("作者:cpa")  # todo 更换作者信息
        label_super.setFont(QFont('楷体', 12))
        # label_super.setOpenExternalLinks(True)
        label_super.setAlignment(Qt.AlignRight)

        # 添加注释
        self.addTab(main_widget, '主页')
        self.setTabIcon(0, QIcon('images/主页面.png'))

运行:

在这里插入图片描述

待完善:

  1. 缓存部分目前市面上是需要将缓存值存入本地sqllite之类的数据库进行保存的,这样下次开机缓存数据不会丢失,这里只展示思路
  2. 缓存除了保存在本地外还可以上传云端进行增强学习然后下发最新模型在本地进行更新,形成完美闭环
    在这里插入图片描述

打包

  1. pip install pyinstaller
  2. pyinstaller -F -w (-i icofile) filename

说明:
filename表示你的Python程序文件名
-w 表示隐藏程序运行时的命令行窗口(不加-w会有黑色窗口)
括号内的为可选参数,-i icofile表示给程序加上图标,图标必须为.ico格式
icofile表示图标的位置,建议直接放在程序文件夹里面,这样子打包的时候直接写文件名就好

  1. pyinstaller -F -w -i 'test.ico' window.py
  2. 将图片文件和模型文件等window.py中用到的资源在打包后一起移入dist目录中,不然会资源找不到的错
  3. 发给你的小伙伴看看效果吧
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值