// 利用 IDEA 自带功能生成 POJO(适用于 MyBatis)

import com.intellij.database.model.DasTable
import com.intellij.database.model.ObjectKind
import com.intellij.database.util.Case
import com.intellij.database.util.DasUtil

/**
 * mybatis do 和 mapper 生成工具
 *
 *
 * 1、将该文件导入到 "临时文件和控制台-扩展-Database Tools and SQL-schema"
 * 2、修改author 和 baseModelDO
 * 3、选择idea 中的数据库 选中需要生成代码的表
 * 4、右键-工具-脚本化扩展程序-选择该脚本
 * 5、在弹出的文件夹选择框中选择需要存放文件的包
 * 结构如下
 * 选择的文件包
 * --dao
 * ----mapper(存放mapper类)
 * ----model(存放do类)
 * --service
 * ----xx(service接口)
 * ----impl(存放接口实现类)
 * --dto (需要根据实际情况手动移动到指定的目录)
 *
 */

/**
 * Available context bindings:
 * SELECTION Iterable<DasObject>
 * PROJECT project
 * FILES files helper
 */

// Base package of the generated code. Left empty here; the entry-point callback overwrites it
// with the package derived from the directory the user picks.
packageName = ""
// Author name written into the generated "@author" Javadoc tags.
author = "xxx"
// Fully-qualified name of the base class for generated DO classes (the DO generator imports it).
baseModelDO = "xx.common.model.BaseDO"
// Java property names (camelCase, as produced by javaName) that the DO and DTO generators skip
// when emitting fields.
bases = ["id", "deleteFlag", "createdBy", "updatedBy", "dateCreated", "dateUpdated", "createdName", "updatedName", "createdTime", "updatedTime"]

// Column type -> Java type mapping. calcFields lower-cases the column's type specification and
// picks the FIRST entry whose pattern is found anywhere in it (substring match, not full match),
// so order matters: the narrow integer types must stay ahead of the generic "int" entry, and the
// catch-all "(?i)" pattern (matches every string) must stay last.
// NOTE(review): because matching is by substring, a type such as "point" or "interval" also
// hits the "int" entry and maps to Long - confirm no such columns exist.
typeMapping = [
        (~/(?i)tinyint|smallint|mediumint/)       : "Integer",
        (~/(?i)int/)                              : "Long",
        (~/(?i)bool|bit/)                         : "Boolean",
        (~/(?i)float|double|decimal|real/)        : "BigDecimal",
        (~/(?i)year|datetime|timestamp|date|time/): "Date",
        (~/(?i)blob|binary|bfile|clob|raw|image/) : "InputStream",
        (~/(?i)/)                                 : "String"
]

/**
 * Script entry point.
 *
 * FILES.chooseDirectoryAndSave opens a directory chooser and, when it closes, calls the closure
 * with the directory the user selected. That directory is treated as the base package:
 * its package name is derived from the path, the dao/service/dto sub-directories are created,
 * and sources are generated for every selected database object that is a plain table.
 */
FILES.chooseDirectoryAndSave("Choose directory", "Choose where to store generated files") { dir ->
    packageName = getPackageNameForTemp(dir)
    initDDDStructure(dir)

    def selectedTables = SELECTION.filter { candidate ->
        candidate instanceof DasTable && candidate.getKind() == ObjectKind.TABLE
    }
    selectedTables.each { table -> generate(table, dir) }
}

/**
 * Creates the package directories the generators write into, below the chosen directory:
 *   dao/mapper   - mapper interfaces and mapper XML files
 *   dao/model    - DO classes
 *   service      - service interfaces
 *   service/impl - service implementations
 *   dto          - DTO classes (move them to their final location by hand if needed)
 * Directories that already exist are left untouched.
 *
 * @param dir the directory chosen by the user (anything whose toString() is its path)
 */
def initDDDStructure(dir) {
    def root = new File(dir.toString())
    // Build nested Files segment by segment instead of concatenating "\\": the backslash is only
    // a separator on Windows, elsewhere it produced directories literally named "dao\mapper".
    [["dao"], ["dao", "mapper"], ["dao", "model"],
     ["service"], ["service", "impl"],
     ["dto"]].each { segments ->
        def target = segments.inject(root) { parent, name -> new File(parent, name) }
        if (!target.exists()) {
            target.mkdirs()
        }
    }
}


/**
 * Generates all source files for one table below the chosen directory:
 *   service/<C>Service.java, service/impl/<C>ServiceImpl.java,
 *   dao/mapper/<C>Mapper.java, dao/mapper/<C>Mapper.xml,
 *   dao/model/<C>DO.java, dto/<C>DTO.java
 * where <C> is the class name derived from the table name. Files that already exist are never
 * overwritten, so re-running the script keeps hand-edited sources.
 *
 * @param table the DasTable to generate code for
 * @param dir   the directory chosen by the user (the base package)
 */
def generate(table, dir) {
    def className = javaClassName(table.getName(), true)
    def fields = calcFields(table)

    def root = dir.toString()
    def rootDir = new File(root)
    // File(parent, child) instead of "\\"-concatenation: correct on every OS, not only Windows.
    def serviceDir = new File(rootDir, "service")
    def serviceImplDir = new File(serviceDir, "impl")
    def mapperDir = new File(new File(rootDir, "dao"), "mapper")
    def modelDir = new File(new File(rootDir, "dao"), "model")
    def dtoDir = new File(rootDir, "dto")

    // Writes `target` (UTF-8) with `generator` unless it already exists.
    def writeIfAbsent = { File target, Closure generator ->
        if (target.exists()) {
            return
        }
        target.parentFile.mkdirs()
        target.withPrintWriter("UTF-8", generator)
    }

    writeIfAbsent(new File(serviceDir, className + "Service.java")) { out ->
        generateService(out, className, root)
    }
    writeIfAbsent(new File(serviceImplDir, className + "ServiceImpl.java")) { out ->
        generateServiceImpl(out, className, root)
    }
    writeIfAbsent(new File(mapperDir, className + "Mapper.java")) { out ->
        generateMapper(out, className, fields, table, mapperDir.path)
    }
    writeIfAbsent(new File(mapperDir, className + "Mapper.xml")) { out ->
        generateMapperXml(out, className, fields, table, mapperDir.path)
    }
    writeIfAbsent(new File(modelDir, className + "DO.java")) { out ->
        generateDO(out, className, fields, table, modelDir.path)
    }
    writeIfAbsent(new File(dtoDir, className + "DTO.java")) { out ->
        generateDTO(out, className, fields, table, dtoDir.path)
    }
}


/**
 * Writes the {@code <ClassName>DTO} class for {@code table} to {@code out}.
 *
 * The DTO is a Lombok/Swagger-annotated Serializable bean with one field per column whose
 * property name is not listed in {@code bases}.
 *
 * @param out       writer for the target .java file
 * @param className class name derived from the table name
 * @param fields    column descriptors from calcFields (uses keys name, type, commoent)
 * @param table     the DasTable being generated; its comment feeds @ApiModel
 * @param dir       directory the file is written to; determines the package declaration
 */
def generateDTO(out, className, fields, table, dir) {
    // Makes comment text safe inside a Java string literal; a missing comment becomes "" rather
    // than the word "null". Unescaped quotes/backslashes/newlines would not compile.
    def literal = { text ->
        text == null ? "" : text.toString()
                .replace("\\", "\\\\")
                .replace("\"", "\\\"")
                .replace("\r", "\\r")
                .replace("\n", "\\n")
    }
    // Base-class columns are not emitted, so only the remaining fields decide the extra imports
    // (otherwise e.g. createdTime would add an unused java.util.Date import).
    def ownFields = fields.findAll { !bases.contains(it.name) }
    Set types = ownFields.collect { it.type } as Set

    out.println "package " + getPackageName(dir)
    out.println ""
    out.println "import com.fasterxml.jackson.annotation.JsonIgnoreProperties;"
    out.println ""
    out.println "import java.io.Serializable;"
    out.println ""
    out.println "import io.swagger.annotations.ApiModel;\n" +
            "import io.swagger.annotations.ApiModelProperty;\n" +
            "import lombok.AllArgsConstructor;\n" +
            "import lombok.Builder;\n" +
            "import lombok.Data;\n" +
            "import lombok.NoArgsConstructor;"

    if (types.contains("Date")) {
        out.println "import java.util.Date;"
    }
    if (types.contains("InputStream")) {
        out.println "import java.io.InputStream;"
    }
    if (types.contains("BigDecimal")) {
        out.println "import java.math.BigDecimal;"
    }
    out.println ""

    out.println "/**\n" +
            " * @author: $author\n" +
            " */"
    out.println ""
    out.println "@Data\n" +
            "@Builder\n" +
            "@NoArgsConstructor\n" +
            "@AllArgsConstructor\n" +
            "@JsonIgnoreProperties(ignoreUnknown = true)\n" +
            "@ApiModel(\"" + literal(table.getComment()) + "DTO\")"
    out.println "public class " + className + "DTO implements Serializable {"
    out.println ""

    // One private field per non-base column; Lombok @Data generates the accessors.
    ownFields.each {
        out.println "    @ApiModelProperty(\"" + literal(it.commoent) + "\")"
        out.println "    private ${it.type} ${it.name};"
        out.println ""
    }
    out.println ""
    out.println "}"
}


/**
 * Writes the {@code <ClassName>DO} MyBatis-Plus entity for {@code table} to {@code out}.
 *
 * The class carries @TableName(<table name>), extends the class configured in
 * {@code baseModelDO}, and declares one @TableField field per column whose property name is not
 * listed in {@code bases}.
 *
 * @param out       writer for the target .java file
 * @param className class name derived from the table name
 * @param fields    column descriptors from calcFields (a List of Maps: colName, name, type, commoent, ...)
 * @param table     the DasTable being generated
 * @param dir       directory the file is written to; determines the package declaration
 */
def generateDO(out, className, fields, table, dir) {
    // Simple name of the configured base class ("BaseDO" for "xx.common.model.BaseDO"), so the
    // extends clause always matches the import instead of a hard-coded "BaseDO".
    def baseSimpleName = baseModelDO.substring(baseModelDO.lastIndexOf('.') + 1)
    // Column comments go into a Javadoc block: a missing comment becomes "" (not "null") and a
    // "*/" inside the comment would terminate the Javadoc early.
    def javadoc = { text -> text == null ? "" : text.toString().replace("*/", "* /") }
    // Base-class columns are not emitted, so only the remaining fields decide the extra imports.
    def ownFields = fields.findAll { !bases.contains(it.name) }
    Set types = ownFields.collect { it.type } as Set

    out.println "package " + getPackageName(dir)
    out.println ""
    out.println "import com.baomidou.mybatisplus.annotation.TableField;\n" +
            "import com.baomidou.mybatisplus.annotation.TableName;\n" +
            "\n" +
            "import java.io.Serializable;\n" +
            "\n" +
            "import $baseModelDO;\n" +
            "import lombok.AllArgsConstructor;\n" +
            "import lombok.Builder;\n" +
            "import lombok.Data;\n" +
            "import lombok.NoArgsConstructor;"

    if (types.contains("Date")) {
        out.println "import java.util.Date;"
    }
    if (types.contains("InputStream")) {
        out.println "import java.io.InputStream;"
    }
    if (types.contains("BigDecimal")) {
        out.println "import java.math.BigDecimal;"
    }
    out.println ""

    out.println "/**\n" +
            " * @author: $author\n" +
            " */"
    out.println ""
    out.println "@Data\n" +
            "@Builder\n" +
            "@NoArgsConstructor\n" +
            "@AllArgsConstructor"
    out.println "@TableName(value = \"${table.getName()}\")"
    out.println "public class " + className + "DO extends " + baseSimpleName + " implements Serializable {"
    out.println ""

    // One private field per non-base column; Lombok @Data generates the accessors.
    ownFields.each {
        out.println "    /**\n" +
                "     * " + javadoc(it.commoent) + "\n" +
                "     */"
        out.println "    @TableField(value = \"${it.colName}\")"
        out.println "    private ${it.type} ${it.name};"
        out.println ""
    }
    out.println ""
    out.println "}"
}


/**
 * Writes the {@code <ClassName>Mapper} interface: a MyBatis @Mapper extending MyBatis-Plus
 * BaseMapper over the table's DO class.
 *
 * @param out       writer for the target .java file
 * @param className class name derived from the table name
 * @param fields    unused; kept for signature parity with the other generators
 * @param table     unused; kept for signature parity with the other generators
 * @param dir       directory the file is written to; determines the package declaration
 */
def generateMapper(out, className, fields, table, dir) {
    def doType = className + "DO"
    def mapperType = className + "Mapper"

    [
            "package " + getPackageName(dir),
            "",
            "import com.baomidou.mybatisplus.core.mapper.BaseMapper;",
            "",
            "import org.apache.ibatis.annotations.Mapper;",
            "",
            "import " + packageName + ".dao.model." + doType + ";",
            "",
            "/**\n * @author: " + author + "\n */",
            "",
            "@Mapper",
            "public interface " + mapperType + " extends BaseMapper<" + doType + "> {",
            "",
            "}"
    ].each { line -> out.println(line) }
}


/**
 * Writes an empty MyBatis mapper XML whose namespace is the fully-qualified name of the
 * {@code <ClassName>Mapper} interface in the package of {@code dir}.
 *
 * @param out       writer for the target .xml file
 * @param className class name derived from the table name
 * @param fields    unused; kept for signature parity with the other generators
 * @param table     unused; kept for signature parity with the other generators
 * @param dir       directory the file is written to; its package prefixes the namespace
 */
def generateMapperXml(out, className, fields, table, dir) {
    def namespace = getPackageNameForTemp(dir) + "." + className + "Mapper"

    out.println '<?xml version="1.0" encoding="UTF-8"?>'
    out.println '<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">'
    out.println '<mapper namespace="' + namespace + '">'
    out.println '</mapper>'
}


/**
 * Writes the {@code <ClassName>ServiceImpl} class: a Spring @Service extending MyBatis-Plus
 * ServiceImpl and implementing the generated service interface, with the mapper autowired.
 *
 * @param out       writer for the target .java file
 * @param className class name derived from the table name
 * @param dir       the directory chosen by the user; its package is the base package
 */
def generateServiceImpl(out, className, dir) {
    // The base package is the same for every line; compute it once.
    def base = getPackageNameForTemp(dir)
    def mapperType = className + "Mapper"
    def doType = className + "DO"
    def serviceType = className + "Service"

    def lines = [
            "package " + base + ".service.impl;",
            "",
            "import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;",
            "",
            "import org.springframework.beans.factory.annotation.Autowired;",
            "import org.springframework.stereotype.Service;",
            "",
            "import " + base + ".dao.mapper." + mapperType + ";",
            "import " + base + ".dao.model." + doType + ";",
            "import " + base + ".service." + serviceType + ";",
            "import lombok.extern.slf4j.Slf4j;",
            "",
            "/**\n * @author: " + author + "\n */",
            "@Slf4j",
            "@Service",
            "public class " + serviceType + "Impl extends ServiceImpl<" + mapperType + ", " + doType + ">\n" +
                    "    implements " + serviceType + " {",
            "",
            "    @Autowired",
            "    private " + mapperType + " " + getFirstLowerJavaName(className) + "Mapper;",
            "",
            "}"
    ]
    lines.each { line -> out.println(line) }
}

/**
 * Writes the {@code <ClassName>Service} interface, a MyBatis-Plus IService over the table's DO.
 *
 * @param out       writer for the target .java file
 * @param className class name derived from the table name
 * @param dir       the directory chosen by the user; its package is the base package
 */
def generateService(out, className, dir) {
    def basePackage = getPackageNameForTemp(dir)

    out.println "package " + basePackage + ".service;"
    out.println ""
    out.println "import com.baomidou.mybatisplus.extension.service.IService;"
    out.println ""
    out.println "import " + basePackage + ".dao.model." + className + "DO;"
    out.println ""
    out.println "/**\n" +
            " * @author: $author\n" +
            " */"
    out.println ""
    // Top-level type declaration starts at column 0 (previously emitted with a stray leading space).
    out.println "public interface " + className + "Service extends IService<" + className + "DO> {"
    out.println ""
    out.println "}"
}


/**
 * Writes a package-info.java body consisting of the package declaration for {@code dir}.
 * Currently only referenced from commented-out code in initDDDStructure.
 */
def generatePackageInfo(out, dir) {
    def declaration = "package " + getPackageNameForTemp(dir) + ";"
    out.println(declaration)
}

// Returns the package name derived from the directory path, terminated with ";" so it can
// follow "package " directly in generated sources.
def getPackageName(dir) {
    def pkg = getPackageNameForTemp(dir)
    return pkg + ";"
}

/**
 * Derives a Java package name from a directory path.
 *
 * Both separators ("\" and "/") become ".". When the path follows the Maven/Gradle layout
 * ".../src/<sourceSet>/java/<package dirs>", the part after the last such "src/x/java" is
 * returned ("" when the path is the source root itself). Otherwise the previous heuristic is
 * kept: strip everything up to the last "src" (plus an optional ".main.java."), then drop any
 * leading dots.
 *
 * Examples: "/p/src/main/java/com/acme" -> "com.acme", "C:\p\src\test\java\com\acme" -> "com.acme".
 *
 * @param dir a path (File or String); only its toString() is used
 * @return the package name, without a trailing ";"
 */
def getPackageNameForTemp(dir) {
    def dotted = dir.toString().replace('\\', '.').replace('/', '.')
    // Anchor on a whole "src" segment followed by "<sourceSet>.java" so that package names that
    // merely contain "src" (e.g. "datasrc") are not mistaken for the source root.
    def layout = dotted =~ '^.*\\.src\\.[^.]+\\.java(?:\\.(.*))?$'
    if (layout.matches()) {
        return layout.group(1) ?: ""
    }
    return dotted.replaceAll("^.*src(\\.main\\.java\\.)?", "").replaceAll("^\\.+", "")
}


/**
 * Converts every column of {@code table} into a field descriptor Map with these keys:
 *   colName  - raw column name
 *   name     - camelCase Java property name (via javaName)
 *   type     - Java type simple name from the first matching typeMapping entry
 *   commoent - column comment (this spelling is the key the DO/DTO generators read)
 *   annos    - JPA-style annotation text (@Column, plus @Id/@GeneratedValue for an "id"
 *              column); no generator visible in this script reads it
 * The descriptors are returned as a List, in the order DasUtil.getColumns yields the columns.
 */
def calcFields(table) {
    def result = []
    DasUtil.getColumns(table).each { column ->
        def columnName = column.getName()
        def spec = Case.LOWER.apply(column.getDataType().getSpecification())
        def javaType = typeMapping.find { pattern, mapped -> pattern.matcher(spec).find() }.value
        def descriptor = [
                colName : columnName,
                name    : javaName(columnName, false),
                type    : javaType,
                commoent: column.getComment(),
                annos   : "\t@Column(name = \"" + columnName + "\" )"]
        // A column named "id" (case-insensitive) additionally gets the primary-key annotations.
        if ("id".equals(Case.LOWER.apply(columnName))) {
            descriptor.annos += "\n\t@Id\n\t@GeneratedValue"
        }
        result << descriptor
    }
    result
}

// Builds a class name from a table name. The name is split on every non-alphanumeric character,
// each part is lower-cased and capitalized, and the parts are joined
// (e.g. "t_user_info" -> "TUserInfo"). Tables named with a "t_" prefix therefore keep a leading
// "T"; strip the prefix here if that is unwanted, or call javaName at the call sites instead.
// When capitalize is false the first character is lower-cased.
def javaClassName(str, capitalize) {
    def s = str.split(/[^\p{Alnum}]/).collect { part -> Case.LOWER.apply(part).capitalize() }.join("")
    // Same guard as javaName: names of 0 or 1 characters are returned as-is, because the
    // head/tail split below needs at least two characters.
    capitalize || s.length() <= 1 ? s : Case.LOWER.apply(s[0]) + s[1..-1]
}
/**
 * Maps a database column name to a camelCase Java identifier.
 *
 * The name is split into words with IntelliJ's NameUtil.splitNameIntoWords, each word is
 * lower-cased and capitalized, the words are joined, and any character that is not a Java
 * identifier part is replaced by "_". Unless {@code capitalize} is true or the result is a single
 * character, its first character is then lower-cased.
 */
def javaName(str, capitalize) {
    def words = com.intellij.psi.codeStyle.NameUtil.splitNameIntoWords(str)
    def joined = words.collect { word -> Case.LOWER.apply(word).capitalize() }.join("")
    def identifier = joined.replaceAll(/[^\p{javaJavaIdentifierPart}[_]]/, "_")
    if (capitalize || identifier.length() == 1) {
        return identifier
    }
    return Case.LOWER.apply(identifier[0]) + identifier[1..-1]
}

// Lower-cases the first character of str (e.g. "UserInfo" -> "userInfo"); a single-character
// string is returned unchanged.
def getFirstLowerJavaName(str) {
    if (str.length() == 1) {
        return str
    }
    def head = Case.LOWER.apply(str[0])
    return head + str[1..-1]
}

/**
 * 类名以表名(其中表名的首字母大写)命名
 */
/*def javaName(str, capitalize) {
    def s = str.split(/(?<=[^\p{IsLetter}])/).collect { Case.LOWER.apply(it).capitalize() }
            .join("").replaceAll(/[^\p{javaJavaIdentifierPart}]/, "_")
    capitalize || s.length() == 1? s : Case.LOWER.apply(s[0]) + s[1..-1]
}*/

// True when content is non-null and its string form is not blank after trim().
def isNotEmpty(content) {
    if (content == null) {
        return false
    }
    return !content.toString().trim().isEmpty()
}


// Returns a tab-indented Java serialVersionUID declaration with a random non-negative value
// (Math.abs of a random long). Not called anywhere in this script.
static String genSerialID() {
    long magnitude = Math.abs(new Random().nextLong())
    "\tprivate static final long serialVersionUID = ${magnitude}L;"
}

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值