Python ORM:SQLAlchemy
最近在研究开源标注工具 Brat(Brat 官网),看到其后端使用 Python 做服务器来处理标注数据。心想着如何将数据通过 Python 保存到数据库,于是便 Google 了一下,找到了 SQLAlchemy 框架。官网地址在上篇文章里已经给出,需要的话可以到前面的文章查看,在此就不单独贴出来了。今天就把最近写的小 demo 分享出来,哪位大佬有疑问可以随时评论留言讨论。Talk is cheap, show me the code.
初始化 SQL
简单创建了两张表。
-- Annotation output text, one row per result.
-- doc_id presumably refers to textmark.document.id (no foreign key is declared; confirm).
create table textmark.annotation_result
(
    id          int auto_increment
        primary key,
    content     varchar(1000)                      null comment '标注结果',
    create_time datetime default CURRENT_TIMESTAMP null,
    -- "on update" refreshes update_time whenever the row is modified;
    -- with only a default it would stay frozen at the insertion time.
    update_time datetime default CURRENT_TIMESTAMP null on update CURRENT_TIMESTAMP,
    doc_id      int                                null
)
    charset = utf8;
-- Documents table: auto-increment surrogate key, a short name and the text body.
CREATE TABLE textmark.document
(
    id   INT AUTO_INCREMENT PRIMARY KEY,
    name VARCHAR(20)    NULL,
    text VARCHAR(20000) NULL
)
    CHARSET = utf8;
简单使用
根据官网以及自己的理解,写了这个小demo。
from contextlib import contextmanager
from sqlalchemy import Column, Integer, String, MetaData
from sqlalchemy import create_engine
from sqlalchemy.orm import registry
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.decl_api import DeclarativeMeta
# NOTE(review): the credentials are hard-coded for this local demo; load them
# from configuration or the environment before using this anywhere else.
DB_CONNECT_STRING = 'mysql+pymysql://root:rootroot@127.0.0.1:3306/textmark?charset=utf8&autocommit=true'
# create_engine() does not connect to the database when it returns. The first
# connection is opened lazily, the first time the engine is asked to perform
# work against the database ("lazy initialization").
# The legacy ``encoding`` argument is deliberately not passed: it is deprecated
# in SQLAlchemy 1.4, not needed on Python 3, and rejected by SQLAlchemy 2.0.
engine = create_engine(DB_CONNECT_STRING, echo=False, future=True, isolation_level='AUTOCOMMIT')
mapper_registry = registry()
Base: DeclarativeMeta = mapper_registry.generate_base()
# Session factory bound to the engine above.
DBSession = sessionmaker(bind=engine)
# session = DBSession()
metaData = MetaData()
# Manage the session lifetime with a context manager.
@contextmanager
def get_session():
    """Yield a new session from ``DBSession`` and always close it afterwards.

    Exceptions raised inside the ``with`` body propagate unchanged to the
    caller once the session has been closed.

    :author: zhangshaobo
    :return: session
    """
    # Create the session before entering ``try``: if the factory itself fails
    # there is nothing to close, and ``finally`` must not reference a name that
    # was never bound (that would mask the real error with UnboundLocalError).
    v_session = DBSession()
    try:
        yield v_session
    finally:
        v_session.close()
@mapper_registry.mapped
class Document:
    """ORM mapping for the ``document`` table.

    Mapped with the ``@mapper_registry.mapped`` decorator; inheriting from
    ``Base`` is the alternative way to declare a mapped class.
    """
    __tablename__ = "document"
    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(20))
    # Length kept in line with the table definition (``text varchar(20000)``).
    text = Column(String(20000))

    # def __str__(self):
    #     return u'%s\t%s\t%s' % (
    #         self.id,
    #         self.name,
    #         self.text
    #     )
    # Override the printed representation:
    # def __repr__(self):
    #     return f"document:[id={self.id!r},name={self.name !r},text={self.text}]"
@mapper_registry.mapped
class CvAnnotation:
    """Entity class for an annotation record, mapped to the ``annotation`` table."""
    # NOTE(review): no CREATE TABLE statement for ``annotation`` is visible
    # next to the other table definitions; confirm the table exists.
    __tablename__ = "annotation"

    # Auto-increment surrogate key.
    id = Column(Integer, autoincrement=True, primary_key=True)
    doc_id = Column(Integer)
    type_short = Column(String(20))
    anno_no = Column(String(20))
    # Presumably character offsets of the annotated span; confirm with callers.
    start_offset = Column(Integer)
    end_offset = Column(Integer)
    content = Column(String(200))
    business_type = Column(String(20))
    virtual_id = Column(String(20))
    remark = Column(String(200))

    # Override the printed representation:
    # def __repr__(self):
    #     return f"CvAnnotation:[id={self.id!r},doc_id={self.doc_id !r},type_short={self.type_short !r},anno_no={self.anno_no !r},start_offset={self.start_offset !r}" \
    #            f",end_offset={self.end_offset !r},content={self.content !r},business_type={self.business_type},virtual_id={self.virtual_id}]"
@mapper_registry.mapped
class AnnotationResult:
    """Output result: entity mapped to the ``annotation_result`` table."""
    __tablename__ = "annotation_result"

    # Auto-increment surrogate key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    content = Column(String(1000))
    doc_id = Column(Integer)

    def __repr__(self):
        """Return a one-line summary showing id, doc_id and content."""
        return f"AnnotationResult:[id={self.id!r},doc_id={self.doc_id!r},content={self.content}]"
def get_annotation_by_doc_id(doc_id):
    """Return the list of CvAnnotation rows whose ``doc_id`` equals *doc_id*."""
    with get_session() as db:
        matching = db.query(CvAnnotation).filter(CvAnnotation.doc_id == doc_id)
        return matching.all()
def get_annotation_by_id_and_shorttype(doc_id, type_short):
    """Return the CvAnnotation rows matching both *doc_id* and *type_short*."""
    with get_session() as db:
        # Both criteria are passed to one filter() call, which ANDs them.
        criteria = (
            CvAnnotation.doc_id == doc_id,
            CvAnnotation.type_short == type_short,
        )
        return db.query(CvAnnotation).filter(*criteria).all()
# print(len(get_annotation_by_id_and_shorttype(1, 'T')))
def select_document(doc_id):
    """Return the first Document whose ``id`` equals *doc_id*, or None if there is none."""
    with get_session() as db:
        by_id = db.query(Document).filter(Document.id == doc_id)
        return by_id.first()
def save_annotation(annotation: CvAnnotation):
    """Add *annotation* to a fresh session and commit it."""
    with get_session() as db:
        db.add(annotation)
        db.commit()
def save_annotation_result(result: AnnotationResult):
    """Add *result* to a fresh session and commit it."""
    with get_session() as db:
        db.add(result)
        db.commit()
def update_annotation_result(result: AnnotationResult):
    """Overwrite the stored ``content`` of the row matching *result*.

    The row is selected by both ``result.id`` and ``result.doc_id``; only the
    ``content`` column is updated, using ``result.content``.
    """
    with get_session() as session:
        session.query(AnnotationResult).filter(
            AnnotationResult.id == result.id,
            AnnotationResult.doc_id == result.doc_id,
        ).update({"content": result.content})
        # commit must be *called*; a bare ``session.commit`` only references
        # the method and commits nothing.
        session.commit()
def get_annotation_result_by_doc_id(doc_id: int):
    """Return a copy of the first AnnotationResult for *doc_id*.

    :param doc_id: value compared against ``AnnotationResult.doc_id``
    :return: a new, session-independent AnnotationResult carrying the same
        id, content and doc_id, or None when no row matches
    """
    with get_session() as session:
        rlt = session.query(AnnotationResult).filter(AnnotationResult.doc_id == doc_id).first()
        # first() yields None when nothing matches; return that instead of
        # failing with AttributeError on ``rlt.id``.
        if rlt is None:
            return None
        # Copy the values into a fresh object while the session is still open.
        return AnnotationResult(id=rlt.id, content=rlt.content, doc_id=rlt.doc_id)
def get_content_by_task_obj_id(task_obj_id: int):
    """Placeholder: return the fixed string ``"s"``; no lookup is implemented yet.

    The previous body called ``session.exec()``, which is not a SQLAlchemy
    ``Session`` method, so every call failed with AttributeError before the
    placeholder value could be returned.

    :param task_obj_id: currently unused
    :return: the placeholder string ``"s"``
    """
    # TODO: implement the real lookup for task_obj_id (e.g. via
    # session.execute(<statement>) inside get_session()).
    return "s"
# rlt = select_document(1)
# print(type(rlt))
# print(str(rlt))
# print(rlt)
# save_annotation(
# CvAnnotation(doc_id=1, type_short='T', anno_no='1', start_offset='11', end_offset='15', business_type='测试',
# virtual_id='T1', remark='test remark'))
# print(str(select_document(1)))
# rlt = get_annotation_result_by_doc_id(1)
# content = rlt.content
# if content is not None:
# lines = content.split('\n')
# print(len(lines))
# for item in lines:
# print(item)
#
# print(u'doc_id:%s\tcontent:%s' % (rlt.doc_id, rlt.content))
# print(str(get_annotation_result_by_doc_id(1)))
# print(get_annotation_result_by_doc_id(1).content)
# rlt = get_annotation_result_by_doc_id(1).content
# lines = rlt.split('\n')
# print(len(lines))
# for i in lines:
# print(i)
# 更新标注结果
# content = get_annotation_result_by_doc_id(1).content
# print(content)
# content += 'T10\t疾病 814 820 moving\n'
# update_annotation_result(AnnotationResult(id=1, doc_id=1, content=content))