package src

import com.intellij.database.model.DasTable
import com.intellij.database.util.Case
import com.intellij.database.util.DasUtil

/**
 * Available context bindings:
 *   SELECTION   Iterable
 *   PROJECT     project
 *   FILES       files helper
 */
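// The correspondence between entity (dto), mapper (dao) and database table is specified manually here;
// in the IDEA Database window only the mappers configured below can be selected.
// tableName(key) : [mapper(dao), entity(dto)]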
// Ordered lookup from a column's type specification to the type-name string stored in each
// field's `type` entry. calcFields takes the first entry whose pattern is found in the
// specification, so the order matters; the last (empty) pattern matches any specification
// and therefore acts as the VARCHAR fallback.
typeMapping = [
        (~/(?i)int/)                      : "INTEGER",
        (~/(?i)float|double|decimal|real/): "DOUBLE",
        (~/(?i)datetime|timestamp/)       : "TIMESTAMP",
        (~/(?i)date/)                     : "TIMESTAMP",
        (~/(?i)time/)                     : "TIMESTAMP",
        (~/(?i)/)                         : "VARCHAR"
]

// Base package to which ".dao.[Name]Mapper" and ".to.[Name]TO" are appended.
// The package name has to be filled in by hand.
basePackage = "com.chitic.bank.mapping"
// Script entry point: ask for a target directory, then generate one file for every
// selected element that is a DasTable (anything else in SELECTION is ignored).
FILES.chooseDirectoryAndSave("Choose directory", "Choose where to store generated files") { dir ->
    def tables = SELECTION.filter { it instanceof DasTable }
    tables.each { table -> generate(table, dir) }
}
/**
 * Creates the mapper XML file for one table inside {@code dir}.
 *
 * The file is named after the capitalized mapper name of the table plus "Mapper.xml" and
 * its content is produced by the four-argument generate overload.
 *
 * @param table table to generate the file for
 * @param dir   directory that receives the file
 */
def generate(table, dir) {
    def baseName = mapperName(table.getName(), true)
    def fields = calcFields(table)
    def target = new File(dir, baseName + "Mapper.xml")
    // The XML declaration emitted by mappingsStart states encoding="UTF-8", so the file is
    // written as UTF-8 explicitly instead of relying on the platform default charset.
    target.withPrintWriter("UTF-8") { out -> generate(table, out, baseName, fields) }
}
/**
 * Writes every section of the mapper document for one table to {@code out}, in a fixed order:
 * start, result map, column list, selectById, deleteById, delete, insert, update,
 * selectList, end.
 *
 * @param table    table being generated; only its name is read here
 * @param out      writer that receives each section through println
 * @param baseName name fragment inserted into the dao and to class names
 * @param fields   field descriptors passed on to the section builders
 */
def generate(table, out, baseName, fields) {
    def baseResultMap = 'BaseResultMap'
    def base_Column_List = 'Base_Column_List'
    // The former `date` local (new Date().format(...)) was never read and has been removed.
    def tableName = table.getName()
    def dao = basePackage + ".dao.${baseName}Mapper"
    def to = basePackage + ".to.${baseName}TO"

    out.println mappingsStart(dao)
    out.println resultMap(baseResultMap, to, fields)
    out.println sql(fields, base_Column_List)
    out.println selectById(tableName, fields, baseResultMap, base_Column_List)
    out.println deleteById(tableName, fields)
    out.println delete(tableName, fields, to)
    out.println insert(tableName, fields, to)
    out.println update(tableName, fields, to)
    out.println selectList(tableName, fields, to, base_Column_List, baseResultMap)
    out.println mappingsEnd()
}

/**
 * Builds the result-map section: a header line, one line per field and a footer line.
 *
 * NOTE(review): baseResultMap, to and the individual fields are never referenced, so every
 * emitted line consists of whitespace only. The element markup was presumably lost from
 * these literals - confirm against the intended template.
 *
 * @param baseResultMap currently unused
 * @param to            currently unused
 * @param fields        collection whose size determines the number of body lines
 * @return the section text
 */
static def resultMap(baseResultMap, to, fields) {
    def rows = fields.collect { '\t\t\n' }.join('')
    return '\t\n' + rows + '\t\n'
}
/**
 * Collects one descriptor map per column of {@code table}.
 *
 * Each descriptor holds: comment (column comment), name (mapperName of the column name,
 * not capitalized), sqlFieldName (raw column name), type (value of the first typeMapping
 * entry whose pattern is found in the lower-cased type specification) and annos (empty string).
 *
 * @param table table whose columns are read through DasUtil.getColumns
 * @return the accumulated list of descriptor maps, in column order
 */
def calcFields(table) {
    DasUtil.getColumns(table).reduce([]) { collected, column ->
        def spec = Case.LOWER.apply(column.getDataType().getSpecification())
        def matched = typeMapping.find { pattern, ignored -> pattern.matcher(spec).find() }
        def descriptor = [
                comment     : column.getComment(),
                name        : mapperName(column.getName(), false),
                sqlFieldName: column.getName(),
                type        : matched.value,
                annos       : ""
        ]
        // The closure result becomes the accumulator for the next column.
        collected + [descriptor]
    }
}
/**
 * Converts an identifier into a camel-case name.
 *
 * The identifier is split into words with NameUtil.splitNameIntoWords, every word is
 * lower-cased and then capitalized, the words are joined, and every character matched by
 * the replacement pattern is substituted with an underscore.
 *
 * @param str        identifier to convert
 * @param capitalize when true the result keeps its leading capital; otherwise the first
 *                   character is lower-cased (names of length 0 or 1 are returned as built)
 * @return the converted name
 */
def mapperName(str, capitalize) {
    def s = com.intellij.psi.codeStyle.NameUtil.splitNameIntoWords(str)
            .collect { Case.LOWER.apply(it).capitalize() }
            .join("")
            .replaceAll(/[^\p{javaJavaIdentifierPart}[_]]/, "_")
    // Returned directly: the previous assignment to an undeclared `name` created a script
    // binding variable as a side effect. `<= 1` also keeps s[0] from throwing on an empty name.
    capitalize || s.length() <= 1 ? s : Case.LOWER.apply(s[0]) + s[1..-1]
}
//------------------------------------------------------------------------ mappings
/**
 * Returns the opening text of the document: the XML declaration followed by a newline.
 *
 * NOTE(review): {@code mapper} is not used; presumably an opening root element that
 * carried it is missing from the literal - confirm against the intended template.
 *
 * @param mapper currently unused
 */
static def mappingsStart(mapper) {
    return '<?xml version="1.0" encoding="UTF-8"?>\n'
}
//------------------------------------------------------------------------ mappings
/**
 * Returns the closing text of the document, which is currently the empty string.
 *
 * NOTE(review): presumably a closing root element is missing from the literal - confirm.
 */
static def mappingsEnd() {
    return ''
}
//------------------------------------------------------------------------ selectById
/**
 * Builds the select-by-id statement text for {@code tableName}.
 *
 * The keywords and the table name are separated by spaces; previously they were
 * concatenated directly, which fused them into a single token.
 *
 * NOTE(review): fields, baseResultMap and base_Column_List are unused and no column list
 * is emitted between "select" and "from"; presumably the surrounding markup and the
 * column-list reference are missing from the literal - confirm.
 *
 * @param tableName        table to select from
 * @param fields           currently unused
 * @param baseResultMap    currently unused
 * @param base_Column_List currently unused
 * @return the statement text
 */
static def selectById(tableName, fields, baseResultMap, base_Column_List) {
    return ' select from ' + tableName + ' where id=#{id}'
}
//------------------------------------------------------------------------ insert
/**
 * Builds the insert statement text: the "insert into" line, the column fragment from
 * testNotNullStr and the value-placeholder fragment from testNotNullStrSet, each followed
 * by a newline.
 *
 * "insert into" and the table name are separated by a space; previously they were fused.
 *
 * NOTE(review): parameterType is unused and both fragments end with a trailing comma;
 * presumably wrapping markup that handled this is missing from the literal - confirm.
 *
 * @param tableName     table to insert into
 * @param fields        field descriptors handed to the fragment builders
 * @param parameterType currently unused
 * @return the statement text
 */
static def insert(tableName, fields, parameterType) {
    def columns = testNotNullStr(fields)
    def placeholders = testNotNullStrSet(fields)
    return ' insert into ' + tableName + '\n' + columns + '\n' + placeholders + '\n'
}
//------------------------------------------------------------------------ update
/**
 * Builds the update-by-id statement text: the "update" line, the fragment from
 * testNotNullStrWhere and a trailing "where id=#{id}" line.
 *
 * "update" and the table name are separated by a space; previously they were fused.
 *
 * NOTE(review): parameterType is unused, and the fragment comes from testNotNullStrWhere,
 * whose entries start with "and" - that looks unusual for the assignment part of an
 * update. Presumably wrapping markup is missing from the literal - confirm.
 *
 * @param tableName     table to update
 * @param fields        field descriptors handed to testNotNullStrWhere
 * @param parameterType currently unused
 * @return the statement text
 */
static def update(tableName, fields, parameterType) {
    def assignments = testNotNullStrWhere(fields)
    return ' update ' + tableName + '\n' + assignments + '\nwhere id=#{id}'
}
//------------------------------------------------------------------------ deleteById
/**
 * Builds the delete-by-id statement text for {@code tableName}.
 *
 * "from", the table name and "where" are separated by spaces; previously they were
 * concatenated directly, which fused them into a single token.
 *
 * NOTE(review): fields is unused; presumably surrounding markup is missing from the
 * literal - confirm.
 *
 * @param tableName table to delete from
 * @param fields    currently unused
 * @return the statement text
 */
static def deleteById(tableName, fields) {
    return ' delete\nfrom ' + tableName + ' where id=#{id}'
}
//------------------------------------------------------------------------ delete
/**
 * Builds the conditional delete statement text: "delete from [table] where 1 = 1"
 * followed by the condition fragment from testNotNullStrWhere and a final newline.
 *
 * The keywords, the table name and the "1 = 1" guard are separated by spaces; previously
 * they were fused ("fromTABLEwhere1 = 1").
 *
 * NOTE(review): parameterType is unused; presumably surrounding markup is missing from
 * the literal - confirm.
 *
 * @param tableName     table to delete from
 * @param fields        field descriptors handed to testNotNullStrWhere
 * @param parameterType currently unused
 * @return the statement text
 */
static def delete(tableName, fields, parameterType) {
    def conditions = testNotNullStrWhere(fields)
    return ' delete from ' + tableName + ' where 1 = 1\n' + conditions + '\n'
}
//------------------------------------------------------------------------ selectList
/**
 * Builds the list-select statement text: "select from [table] where 1 = 1", the condition
 * fragment from testNotNullStrWhere, then "order by id desc".
 *
 * The keywords, the table name and the "1 = 1" guard are separated by spaces; previously
 * they were fused ("selectfromTABLEwhere1 = 1").
 *
 * NOTE(review): parameterType, base_Column_List and baseResultMap are unused and no column
 * list is emitted between "select" and "from"; presumably surrounding markup and the
 * column-list reference are missing from the literal - confirm.
 *
 * @param tableName        table to select from
 * @param fields           field descriptors handed to testNotNullStrWhere
 * @param parameterType    currently unused
 * @param base_Column_List currently unused
 * @param baseResultMap    currently unused
 * @return the statement text
 */
static def selectList(tableName, fields, parameterType, base_Column_List, baseResultMap) {
    def conditions = testNotNullStrWhere(fields)
    return ' select from ' + tableName + ' where 1 = 1\n' + conditions + 'order by id desc'
}
//------------------------------------------------------------------------ sql
/**
 * Builds the column-list section: a header line followed by the raw column names, one per
 * line, each indented by two tabs and separated by ",\n", then a single trailing space.
 *
 * An empty field list now yields just the header and the trailing space; previously the
 * trailing-separator removal (substring(0, length - 2)) threw on an empty list.
 *
 * NOTE(review): base_Column_List is unused; presumably the element markup that carried it
 * is missing from the literal - confirm.
 *
 * @param fields           descriptors providing sqlFieldName
 * @param base_Column_List currently unused
 * @return the section text
 */
static def sql(fields, base_Column_List) {
    def columns = fields.collect { '\t\t' + it.sqlFieldName }.join(',\n')
    return '\t\n' + columns + ' '
}

/**
 * Builds one " and [column] = #{[name]}" line per field.
 *
 * "and" and the column name are separated by a space; previously they were fused.
 *
 * NOTE(review): the method name suggests a per-field not-null test that is not present in
 * the emitted text - confirm against the intended template.
 *
 * @param fields descriptors providing sqlFieldName and name
 * @return the concatenated lines, empty when there are no fields
 */
static def testNotNullStrWhere(fields) {
    return fields.collect { ' and ' + it.sqlFieldName + ' = #{' + it.name + '}\n' }.join('')
}

/**
 * Builds one " #{[name]}," placeholder line per field (every line keeps its trailing comma).
 *
 * @param fields descriptors providing name
 * @return the concatenated lines, empty when there are no fields
 */
static def testNotNullStrSet(fields) {
    return fields.collect { ' #{' + it.name + '},\n' }.join('')
}

/**
 * Builds one column-name line per field: a space, a tab, the raw column name and a
 * trailing comma.
 *
 * @param fields descriptors providing sqlFieldName
 * @return the concatenated lines, empty when there are no fields
 */
static def testNotNullStr(fields) {
    return fields.collect { ' \t' + it.sqlFieldName + ',\n' }.join('')
}