下面是合并显示的代码 showMerge.py:它从数据库读取此前跟踪模型与 OCR 模型保存的结果,并把两者叠加绘制到视频帧上显示。
from DAO import *
import os
import cv2
from PIL import Image, ImageFont, ImageDraw
import numpy as np
import math
import textwrap
import sys,getopt
def fetchFromTrack(videoFile, imageName):
    """Load the saved tracking boxes for one frame of a video.

    The lookup keys are written onto ``RectBoxTable``'s class attributes
    before calling ``RectBoxTable.find()``.

    Args:
        videoFile: video name stored with the boxes.
        imageName: frame/image name stored with the boxes.

    Returns:
        A list of ``[x0, y0, x1, y1, label]`` lists, one per stored row.
    """
    RectBoxTable.videoFile = videoFile
    RectBoxTable.imageFile = imageName
    return [
        [row.x0, row.y0, row.x1, row.y1, row.label]
        for row in RectBoxTable.find()
    ]
def drawRects(trackInfo, image):
    """Draw one tracking box and its label onto a PIL image, in place.

    Args:
        trackInfo: ``[left, top, right, bottom, label]`` as produced by
            ``fetchFromTrack``.
        image: PIL image to draw on; it is modified in place.
    """
    left, top, right, bottom, label = trackInfo
    width, height = image.size

    # Font size scales with the image height (3% of it, rounded).
    font_size = int(np.floor(3e-2 * height + 0.5))
    font = ImageFont.truetype("model/STXINWEI.TTF", size=font_size, encoding='utf-8')

    # Always draw at least one outline pass, even on small images where
    # (width + height) // 300 would be 0 and the box would vanish.
    thickness = max(1, (width + height) // 300)

    draw = ImageDraw.Draw(image)
    if hasattr(draw, "textsize"):
        label_size = draw.textsize(label, font)
    else:
        # ImageDraw.textsize is not available in newer Pillow releases.
        # Measure with textbbox at the origin instead; its right/bottom edge
        # is the extent the label occupies when drawn at text_origin.
        bbox = draw.textbbox((0, 0), label, font=font)
        label_size = (bbox[2], bbox[3])

    # Put the label above the box when there is room, otherwise just inside.
    if top - label_size[1] >= 0:
        text_origin = (left, top - label_size[1])
    else:
        text_origin = (left, top + 1)
    label_end = (text_origin[0] + label_size[0], text_origin[1] + label_size[1])

    # Emulate a thick outline with nested one-pixel rectangles.
    for i in range(thickness):
        draw.rectangle([left + i, top + i, right - i, bottom - i])
    draw.rectangle([text_origin, label_end])
    draw.text(text_origin, label, fill=(0, 0, 0), font=font)
#-------------------------------------------------------------------------------------
def order_points(pts):
    """Sort 2-D points by their angle around the centroid.

    Each point is keyed by ``atan2(y - cy, x - cx)`` where ``(cx, cy)`` is the
    mean of all points, and the points are returned in ascending key order.

    Args:
        pts: iterable of points indexable as ``p[0]`` (x) and ``p[1]`` (y).
            Lists, tuples, NumPy rows and one-shot iterators are accepted.

    Returns:
        A new list with the same points in sorted order; ``[]`` for empty
        input.
    """
    # Materialise once so that iterators are not exhausted by the centroid
    # pass before sorting, and so that emptiness can be checked safely
    # (a NumPy array has no unambiguous truth value).
    pts = list(pts)
    if not pts:
        return []

    count = len(pts)
    centroid_x = sum(p[0] for p in pts) / count
    centroid_y = sum(p[1] for p in pts) / count
    return sorted(
        pts,
        key=lambda p: math.atan2(p[1] - centroid_y, p[0] - centroid_x),
    )
def fetchFromOCR(videoFile, imageName):
    """Load the saved OCR polygons for one frame of a video.

    The lookup keys are written onto ``TextBoxTable``'s class attributes
    before calling ``TextBoxTable.find()``.

    Args:
        videoFile: video name stored with the polygons.
        imageName: frame/image name stored with the polygons.

    Returns:
        A list of ``[x0, y0, x1, y1, x2, y2, x3, y3, label]`` lists, one per
        stored row.
    """
    TextBoxTable.videoFile = videoFile
    TextBoxTable.imageFile = imageName
    return [
        [
            row.x0, row.y0, row.x1, row.y1,
            row.x2, row.y2, row.x3, row.y3,
            row.label,
        ]
        for row in TextBoxTable.find()
    ]
def drawOCR(ocrInfo, frame):
    """Draw one OCR polygon and its text label for a video frame.

    The polygon outline is drawn with OpenCV directly on ``frame``, so the
    caller's array is modified. A PIL image is then built from that frame and
    the label is drawn on it.

    Args:
        ocrInfo: flat sequence ``[x0, y0, x1, y1, x2, y2, x3, y3, label]`` as
            produced by ``fetchFromOCR``.
        frame: image array read by OpenCV.

    Returns:
        The PIL image built from ``frame`` with the label drawn on it.
    """
    *coords, label = ocrInfo
    points = np.reshape(np.asarray(coords), [-1, 2])

    # Closed green outline, 2 px wide, drawn in place on the frame.
    cv2.polylines(frame, np.int32([points]), 1, (0, 255, 0), 2)

    image = Image.fromarray(frame)
    # Font size scales with the image height (3% of it, rounded).
    font_size = int(np.floor(3e-2 * image.size[1] + 0.5))
    font = ImageFont.truetype("model/STXINWEI.TTF", size=font_size, encoding='utf-8')

    # Anchor the label one font-height above the first point in angular
    # order, clamped so it never starts above the top edge of the image.
    anchor = order_points(points)[0]
    label_xy = (anchor[0], max(anchor[1] - font_size, 0))
    ImageDraw.Draw(image).text(label_xy, label, (0, 0, 255), font=font)
    return image
#---------------------------------------------------------------------------------------------------------
def showVideo(videoFile):
    """Play a video with the saved OCR and tracking results drawn on each frame.

    Frames are read from ``static/video/<videoFile>``. Frame names are the
    1-based frame counter, which is what ``fetchFromOCR`` / ``fetchFromTrack``
    are queried with.

    Args:
        videoFile: file name of the video inside the ``static/video`` folder.
    """
    # 1. Open the video. Joining the path components separately keeps the
    #    path valid on non-Windows systems as well.
    capture = cv2.VideoCapture(os.path.join("static", "video", videoFile))
    imageName = 1
    try:
        while True:
            ret, frame = capture.read()
            if not ret:
                print('ret is False')
                break

            # Timestamp of the frame just read, converted from ms to seconds.
            timestamp = capture.get(cv2.CAP_PROP_POS_MSEC) / 1000.0
            print("帧时间戳{:.2f}秒".format(timestamp))

            # 2. Draw the saved OCR polygons and labels.
            image = None
            for ocrInfo in fetchFromOCR(videoFile, imageName):
                image = drawOCR(ocrInfo, frame)
                # drawOCR builds a fresh PIL image from `frame` on every call,
                # so write the labelled result back; otherwise only the last
                # box's label would survive.
                frame = np.array(image)
            if image is None:
                # No OCR boxes for this frame: wrap the raw frame instead.
                image = Image.fromarray(frame)

            # 3. Draw the saved tracking boxes and labels on the same image.
            for trackInfo in fetchFromTrack(videoFile, imageName):
                drawRects(trackInfo, image)

            # 4. Show the merged result.
            cv2.imshow("Merge_Result", np.array(image))
            cv2.waitKey(30)
            imageName += 1
    finally:
        # Release the video handle even if drawing or a query fails.
        capture.release()
def main(argv):
    """Command-line entry point: parse ``-v/--videoName`` and play that video.

    ``-h`` prints the usage line and returns without playing anything. An
    unknown option, or a missing/empty video name, prints the error usage
    line and exits with status 2.

    Args:
        argv: argument list without the program name (e.g. ``sys.argv[1:]``).
    """
    usage = "usage -v <videoName>"
    try:
        opts, _args = getopt.getopt(argv, "hv:", ["videoName="])
    except getopt.GetoptError:
        print("Error:" + usage)
        sys.exit(2)

    videoFile = ""
    for opt, arg in opts:
        if opt == "-h":
            print(usage)
            return
        if opt in ("-v", "--videoName"):
            videoFile = arg

    if not videoFile:
        # Without a video name there is nothing to show; report it instead
        # of silently doing nothing.
        print("Error:" + usage)
        sys.exit(2)
    showVideo(videoFile)


if __name__ == '__main__':
    main(sys.argv[1:])
    # For a quick manual run without CLI arguments: showVideo("boat1.mp4")
看完上面的代码你可能会有些困惑,因为数据表对应的类结构还没有给出。下面来看看这两个表对应的类(TextBoxTable 与 RectBoxTable):
# 导入模块
from flask import Flask
from flask_sqlalchemy import SQLAlchemy #ORM框架 依赖pymysql才能操作数据库
import pymysql
# Create the Flask application object.
app = Flask(__name__)
# NOTE(review): hard-coded secret key, presumably a placeholder; load it from
# configuration or the environment outside of local experiments.
app.secret_key = 'asdfasdf'
def getDb():
    """Configure the module-level Flask ``app`` for MySQL and return a SQLAlchemy handle.

    Returns:
        A ``SQLAlchemy`` object bound to ``app``.
    """
    # NOTE(review): the connection URI embeds the root credentials; move it
    # to configuration or the environment before sharing or deploying.
    app.config.update(
        SQLALCHEMY_DATABASE_URI="mysql+pymysql://root:root@127.0.0.1:3306/labelSoft",
        SQLALCHEMY_TRACK_MODIFICATIONS=True,
        # Commit pending changes automatically at the end of each request.
        # The misspelled duplicate 'SQLALCHEMY_COMMIT_TEARDOWN' that used to
        # be set here is not a recognised key and has been dropped.
        SQLALCHEMY_COMMIT_ON_TEARDOWN=True,
    )
    return SQLAlchemy(app)
# Module-level SQLAlchemy handle; the model classes in this file derive from db.Model.
db = getDb()
class TextBoxTable(db.Model):
    """ORM model for one OCR text box: a labelled four-corner polygon.

    ``videoFile`` and ``imageFile`` are plain class attributes (not columns).
    They hold the default filter values used by :meth:`find` and can be
    changed with ``TextBoxTable.imageFile = "..."``.
    """

    __tablename__ = 'TextBoxInfo'

    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    videoName = db.Column(db.String(25))  # original note: empty for a standalone image
    imageName = db.Column(db.String(25), nullable=False)
    label = db.Column(db.String(100), nullable=False)
    # Corner coordinates (x0, y0) .. (x3, y3).
    x0 = db.Column(db.FLOAT(8, 2), nullable=False)
    y0 = db.Column(db.FLOAT(8, 2), nullable=False)
    x1 = db.Column(db.FLOAT(8, 2), nullable=False)
    y1 = db.Column(db.FLOAT(8, 2), nullable=False)
    x2 = db.Column(db.FLOAT(8, 2), nullable=False)
    y2 = db.Column(db.FLOAT(8, 2), nullable=False)
    x3 = db.Column(db.FLOAT(8, 2), nullable=False)
    y3 = db.Column(db.FLOAT(8, 2), nullable=False)

    # Default filter values for find(); class-level, shared by all callers.
    videoFile = "null"
    imageFile = "image1.jpg"

    def __init__(self, label, x0, y0, x1, y1, x2, y2, x3, y3, image, video="null"):
        """Create a row for one text box on ``image`` (optionally within ``video``)."""
        self.videoName = video
        self.imageName = image
        self.label = label
        self.x0 = x0
        self.y0 = y0
        self.x1 = x1
        self.y1 = y1
        self.x2 = x2
        self.y2 = y2
        self.x3 = x3
        self.y3 = y3

    def __repr__(self):
        return '<BoxTable %r>' % self.label

    def save(self):
        """Insert this row and commit; roll the session back if the commit fails."""
        db.session.add(self)
        try:
            db.session.commit()
        except Exception:
            # Leave the session usable for later operations, then re-raise.
            db.session.rollback()
            raise

    @classmethod
    def find(cls, imageFile=None, videoFile=None):
        """Return all rows matching an image name and a video name.

        Args:
            imageFile: image name to match; defaults to ``cls.imageFile``.
            videoFile: video name to match; defaults to ``cls.videoFile``.

        Returns:
            A list of model instances (the query is executed with ``.all()``).
        """
        if imageFile is None:
            imageFile = cls.imageFile
        if videoFile is None:
            videoFile = cls.videoFile
        return cls.query.filter(
            cls.imageName == imageFile, cls.videoName == videoFile).all()

    def delete(self):
        """Delete this row and commit; roll the session back if the commit fails."""
        db.session.delete(self)
        try:
            db.session.commit()
        except Exception:
            db.session.rollback()
            raise
class RectBoxTable(db.Model):
    """ORM model for one tracking box: a labelled axis-aligned rectangle.

    ``videoFile`` and ``imageFile`` are plain class attributes (not columns).
    They hold the default filter values used by :meth:`find` and can be
    changed with ``RectBoxTable.imageFile = "..."``.
    """

    __tablename__ = 'RectBoxInfo'

    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    videoName = db.Column(db.String(25))  # original note: empty for a standalone image
    imageName = db.Column(db.String(25), nullable=False)
    label = db.Column(db.String(25), nullable=False)
    # Two corner points of the rectangle: (x0, y0) and (x1, y1).
    x0 = db.Column(db.FLOAT(8, 2), nullable=False)
    y0 = db.Column(db.FLOAT(8, 2), nullable=False)
    x1 = db.Column(db.FLOAT(8, 2), nullable=False)
    y1 = db.Column(db.FLOAT(8, 2), nullable=False)

    # Default filter values for find(); class-level, shared by all callers.
    videoFile = "null"
    imageFile = "image1.jpg"

    def __init__(self, label, x0, y0, x1, y1, image, video="null"):
        """Create a row for one rectangle on ``image`` (optionally within ``video``)."""
        self.videoName = video
        self.imageName = image
        self.label = label
        self.x0 = x0
        self.y0 = y0
        self.x1 = x1
        self.y1 = y1

    def __repr__(self):
        return '<BoxTable %r>' % self.label

    def save(self):
        """Insert this row and commit; roll the session back if the commit fails."""
        db.session.add(self)
        try:
            db.session.commit()
        except Exception:
            # Leave the session usable for later operations, then re-raise.
            db.session.rollback()
            raise

    @classmethod
    def find(cls, imageFile=None, videoFile=None):
        """Return all rows matching an image name and a video name.

        Args:
            imageFile: image name to match; defaults to ``cls.imageFile``.
            videoFile: video name to match; defaults to ``cls.videoFile``.

        Returns:
            A list of model instances (the query is executed with ``.all()``).
        """
        if imageFile is None:
            imageFile = cls.imageFile
        if videoFile is None:
            videoFile = cls.videoFile
        return cls.query.filter(
            cls.imageName == imageFile, cls.videoName == videoFile).all()

    def delete(self):
        """Delete this row and commit; roll the session back if the commit fails."""
        db.session.delete(self)
        try:
            db.session.commit()
        except Exception:
            db.session.rollback()
            raise
if __name__ == '__main__':
    # Create the tables for the models defined in this module. Newer
    # Flask-SQLAlchemy releases require an active application context for
    # create_all(); pushing one is harmless on older releases.
    with app.app_context():
        db.create_all()
补充:
# 1. Collect the distinct frame (image) names stored for the current video.
#    The video name is passed as a bound parameter instead of being formatted
#    into the SQL string, and the statement is wrapped in text() as required
#    by current SQLAlchemy releases.
sql = db.text('select DISTINCT imageName from TextBoxInfo WHERE videoname <=> :video')
data_query = db.session.execute(sql, {"video": TextBoxTable.videoFile})
frameList = [row[0] for row in data_query.fetchall()]
PS:SQLAlchemy下的操作参考link