【TF lite】从tensorflow模型训练到lite模型移植

前言

本文使用tensorflow下的ssdlite-mobilenet v2物体检测模型,并转换为tflite模型,并完成测试

1. 安装 TensorFlow Object Detection API

1.1 下载tensorflow-master和models-master

下载地址分别为https://github.com/tensorflow/tensorflowhttps://github.com/tensorflow/models

1.2 安装依赖项、编译工具

pip install matplotlib pillow lxml Cython pycocotools
sudo apt-get install protobuf-compiler

1.3 使用proto编译

cd models/research/
protoc object_detection/protos/*.proto --python_out=.

1.4 添加环境变量

在.bashrc中添加环境变量,路径根据实际情况补充完整,然后source更新环境变量

export PYTHONPATH=$PYTHONPATH:/.../models/research:/.../models/research/slim

1.5 测试models是否安装成功:

python object_detection/builders/model_builder_test.py

返回OK则OK

2. TF Record格式数据准备

使用label-image标注工具对样本进行标注,得到VOC格式数据。将所有的图片放入images/文件夹,标注得到的xml文件保存到merged_xml/文件夹内,并新建文件夹Annotations/。

2.1 训练集划分

新建train_test_split.py把xml数据集分为了train 、test、 validation三部分,并存储在Annotations文件夹中,train为训练集占76.5%,test为测试集10%,validation为验证集13.5%,train_test_split.py代码如下:

import os  
import random  
import time  
import shutil  
  
xmlfilepath=r'merged_xml'  
saveBasePath=r"./Annotations"  
  
trainval_percent=0.9  
train_percent=0.85  
total_xml = os.listdir(xmlfilepath)  
num=len(total_xml)  
list=range(num)  
tv=int(num*trainval_percent)  
tr=int(tv*train_percent)  
trainval= random.sample(list,tv)  
train=random.sample(trainval,tr)  
print("train and val size",tv)  
print("train size",tr)  
# print(total_xml[1])  
start = time.time()   
# print(trainval)  
# print(train)  
test_num=0  
val_num=0  
train_num=0  
# for directory in ['train','test',"val"]:  
#         xml_path = os.path.join(os.getcwd(), 'Annotations/{}'.format(directory))  
#         if(not os.path.exists(xml_path)):  
#             os.mkdir(xml_path)  
#         # shutil.copyfile(filePath, newfile)  
#         print(xml_path)  
for i  in list:  
    name=total_xml[i]  
            # print(i)  
    if i in trainval:  #train and val set  
    # ftrainval.write(name)  
        if i in train:  
            # ftrain.write(name)  
            # print("train")  
            # print(name)  
            # print("train: "+name+" "+str(train_num))  
            directory="train"  
            train_num+=1  
            xml_path = os.path.join(os.getcwd(), 'Annotations/{}'.format(directory))  
            if(not os.path.exists(xml_path)):  
                os.mkdir(xml_path)  
            filePath=os.path.join(xmlfilepath,name)  
            newfile=os.path.join(saveBasePath,os.path.join(directory,name))  
            shutil.copyfile(filePath, newfile)  
  
        else:  
            # fval.write(name)  
            # print("val")  
            # print("val: "+name+" "+str(val_num))  
            directory="validation"  
            xml_path = os.path.join(os.getcwd(), 'Annotations/{}'.format(directory))  
            if(not os.path.exists(xml_path)):  
                os.mkdir(xml_path)  
            val_num+=1  
            filePath=os.path.join(xmlfilepath,name)   
            newfile=os.path.join(saveBasePath,os.path.join(directory,name))  
            shutil.copyfile(filePath, newfile)  
            # print(name)  
    else:  #test set  
        # ftest.write(name)  
        # print("test")  
        # print("test: "+name+" "+str(test_num))  
        directory="test"  
        xml_path = os.path.join(os.getcwd(), 'Annotations/{}'.format(directory))  
        if(not os.path.exists(xml_path)):  
            os.mkdir(xml_path)  
        test_num+=1  
        filePath=os.path.join(xmlfilepath,name)  
        newfile=os.path.join(saveBasePath,os.path.join(directory,name))  
        shutil.copyfile(filePath, newfile)  
            # print(name)  
  
# End time  
end = time.time()  
seconds=end-star
  • 4
    点赞
  • 37
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值