废话不多说,直接上代码
1、自定义注解方式
/**
 * Marks an entity field whose primary-key value should be generated and
 * filled in automatically before an INSERT statement is executed.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface AutoId {

    /**
     * Strategy used to generate the id value.
     */
    enum IdType {
        /**
         * Random UUID with the "-" separators removed.
         */
        UUID,
        /**
         * Twitter snowflake id.
         */
        SNOWFLAKE
    }

    /**
     * @return the id generation strategy (snowflake by default)
     */
    IdType value() default IdType.SNOWFLAKE;
}
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import java.lang.reflect.Field;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
/**
 * MyBatis interceptor plugin that auto-fills primary-key fields annotated
 * with {@link AutoId} before an INSERT statement executes.
 * Supports UUID strings and snowflake ids (as Long or as String).
 *
 * @author csw
 * date 2022-02-14
 */
@Intercepts({
        @Signature(type = Executor.class, method = "update", args = {
                MappedStatement.class, Object.class
        }),
})
public class AutoIdInterceptor implements Interceptor {

    private Long workerId;
    private Long dataCenterId;

    /**
     * Cache of entity class -> handlers for its {@link AutoId}-annotated fields.
     * Avoids re-scanning a class by reflection on every insert. An empty list is
     * cached for classes without {@link AutoId} fields, so those classes are
     * scanned exactly once as well (the old code re-scanned them every time).
     */
    private final Map<Class<?>, List<Handler>> handlerMap = new ConcurrentHashMap<>();

    /**
     * @param workerId     snowflake worker id; null -> ip-derived default
     * @param dataCenterId snowflake datacenter id; null -> ip-derived default
     */
    public AutoIdInterceptor(Long workerId, Long dataCenterId) {
        this.workerId = workerId == null ? SnowflakeIdWorker.getWorkerId() : workerId;
        this.dataCenterId = dataCenterId == null ? SnowflakeIdWorker.getDataCenterId() : dataCenterId;
    }

    public synchronized Long getWorkerId() {
        return workerId;
    }

    /**
     * Updates the worker id and propagates it to every handler built so far.
     */
    public synchronized void setWorkerId(Long workerId) {
        this.workerId = workerId;
        handlerMap.forEach((key, handlers) -> {
            if (handlers != null) {
                for (Handler handler : handlers) {
                    handler.setWorkerId(workerId);
                }
            }
        });
    }

    public synchronized Long getDataCenterId() {
        return dataCenterId;
    }

    /**
     * Updates the datacenter id and propagates it to every handler built so far.
     */
    public synchronized void setDataCenterId(Long dataCenterId) {
        this.dataCenterId = dataCenterId;
        handlerMap.forEach((key, handlers) -> {
            if (handlers != null) {
                for (Handler handler : handlers) {
                    handler.setDataCenterId(dataCenterId);
                }
            }
        });
    }

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        // args matches the @Signature declaration above: (MappedStatement, parameter object)
        MappedStatement mappedStatement = (MappedStatement) args[0];
        Object entity = args[1];
        // only fill ids for INSERT statements
        if (SqlCommandType.INSERT.equals(mappedStatement.getSqlCommandType())) {
            for (Object object : getEntitySet(entity)) {
                process(object);
            }
        }
        return invocation.proceed();
    }

    /**
     * Unwraps the statement parameter into the set of entities to process.
     * For a single insert the parameter is the entity itself; for a batch
     * insert (or a multi-parameter custom insert) MyBatis passes a Map whose
     * values are entities or collections of entities.
     */
    @SuppressWarnings("unchecked")
    private Set<Object> getEntitySet(Object object) {
        Set<Object> set = new HashSet<>();
        if (object instanceof Map) {
            // batch / multi-parameter insert: flatten every map value
            ((Map<String, Object>) object).forEach((key, value) -> {
                if (value instanceof Collection) {
                    set.addAll((Collection<?>) value);
                } else {
                    set.add(value);
                }
            });
        } else {
            // single entity insert
            set.add(object);
        }
        return set;
    }

    /**
     * Fills every null {@link AutoId} field of the given entity.
     * computeIfAbsent builds the handler list at most once per class and only
     * publishes it fully constructed — the old check-then-put code put an empty
     * list into the map before populating it, letting concurrent threads iterate
     * a partially filled list.
     */
    private void process(Object object) throws Exception {
        List<Handler> handlerList = handlerMap.computeIfAbsent(object.getClass(), this::buildHandlerList);
        for (Handler handler : handlerList) {
            handler.accept(object);
        }
    }

    /**
     * Scans a class (plus up to two superclass levels) for {@link AutoId}
     * fields and builds the matching id handlers.
     */
    private List<Handler> buildHandlerList(Class<?> cls) {
        List<Field> idFields = new ArrayList<>(getIdFields(cls));
        Class<?> superClass = cls.getSuperclass();
        if (superClass != null && !superClass.equals(Object.class)) {
            idFields.addAll(getIdFields(superClass));
            Class<?> doubleSuperClass = superClass.getSuperclass();
            // null check added: getSuperclass() may return null (e.g. interfaces)
            if (doubleSuperClass != null && !doubleSuperClass.equals(Object.class)) {
                idFields.addAll(getIdFields(doubleSuperClass));
            }
        }
        List<Handler> handlers = new ArrayList<>(idFields.size());
        for (Field idField : idFields) {
            AutoId annotation = idField.getAnnotation(AutoId.class);
            if (idField.getType().isAssignableFrom(String.class)) {
                if (annotation.value().equals(AutoId.IdType.UUID)) {
                    // 1. UUID string primary key
                    handlers.add(new UUIDHandler(idField));
                } else if (annotation.value().equals(AutoId.IdType.SNOWFLAKE)) {
                    // 2. String-typed snowflake id.
                    // Each @AutoId field gets its own SnowflakeIdWorker so every
                    // table keeps an independent sequence.
                    handlers.add(new UniqueLongHexHandler(idField, new SnowflakeIdWorker(workerId, dataCenterId)));
                }
            } else if (idField.getType().isAssignableFrom(Long.class)) {
                if (annotation.value().equals(AutoId.IdType.SNOWFLAKE)) {
                    // 3. Long-typed snowflake id (independent sequence per field, as above)
                    handlers.add(new UniqueLongHandler(idField, new SnowflakeIdWorker(workerId, dataCenterId)));
                }
            }
        }
        return handlers;
    }

    /**
     * @return the declared fields of {@code clazz} carrying {@link AutoId}
     */
    private List<Field> getIdFields(Class<?> clazz) {
        return Arrays.stream(clazz.getDeclaredFields())
                .filter(field -> field.getAnnotation(AutoId.class) != null)
                .collect(Collectors.toList());
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Base class for id handlers: fills the bound field only when it is
     * still null, so explicitly assigned ids are never overwritten.
     */
    private static abstract class Handler {
        Field field;

        Handler(Field field) {
            this.field = field;
        }

        abstract void handle(Field field, Object object) throws Exception;

        abstract void setWorkerId(Long workerId);

        abstract void setDataCenterId(Long dataCenterId);

        /**
         * @return true when the field is still unset (null) on this entity
         */
        private boolean checkField(Object object, Field field) throws IllegalAccessException {
            if (!field.isAccessible()) {
                field.setAccessible(true);
            }
            return field.get(object) == null;
        }

        public void accept(Object o) throws Exception {
            if (checkField(o, field)) {
                handle(field, o);
            }
        }
    }

    /**
     * UUID handler.
     */
    private static class UUIDHandler extends Handler {
        UUIDHandler(Field field) {
            super(field);
        }

        /**
         * 1. Sets a random UUID with the "-" separators removed, matching the
         * documented contract of {@link AutoId.IdType#UUID} (the old code left
         * the dashes in).
         */
        @Override
        void handle(Field field, Object object) throws Exception {
            field.set(object, UUID.randomUUID().toString().replace("-", ""));
        }

        @Override
        void setWorkerId(Long workerId) {
        }

        @Override
        void setDataCenterId(Long dataCenterId) {
        }
    }

    /**
     * Long-typed snowflake id handler.
     */
    private static class UniqueLongHandler extends Handler {
        private final SnowflakeIdWorker idWorker;

        UniqueLongHandler(Field field, SnowflakeIdWorker idWorker) {
            super(field);
            this.idWorker = idWorker;
        }

        /**
         * 2. Sets a Long snowflake id.
         */
        @Override
        void handle(Field field, Object object) throws Exception {
            field.set(object, idWorker.nextId());
        }

        @Override
        void setWorkerId(Long workerId) {
            idWorker.setWorkerId(workerId);
        }

        @Override
        void setDataCenterId(Long dataCenterId) {
            idWorker.setDatacenterId(dataCenterId);
        }
    }

    /**
     * String-typed snowflake id handler.
     */
    private static class UniqueLongHexHandler extends Handler {
        private final SnowflakeIdWorker idWorker;

        UniqueLongHexHandler(Field field, SnowflakeIdWorker idWorker) {
            super(field);
            this.idWorker = idWorker;
        }

        /**
         * 3. Sets a snowflake id rendered as a decimal String.
         */
        @Override
        void handle(Field field, Object object) throws Exception {
            field.set(object, String.valueOf(idWorker.nextId()));
        }

        @Override
        void setWorkerId(Long workerId) {
            idWorker.setWorkerId(workerId);
        }

        @Override
        void setDataCenterId(Long dataCenterId) {
            idWorker.setDatacenterId(dataCenterId);
        }
    }
}
2、创建对象方式
package com.ebiz.base.db.plugin;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.StringUtils;
import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.NetworkInterface;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.Random;
/**
* Twitter_Snowflake<br>
* SnowFlake的结构如下(每部分用-分开):<br>
* 0 - 0000000000 0000000000 0000000000 0000000000 0 - 00000 - 00000 - 000000000000 <br>
* 1位标识,由于long基本类型在Java中是带符号的,最高位是符号位,正数是0,负数是1,所以id一般是正数,最高位是0<br>
* 41位时间截(毫秒级),注意,41位时间截不是存储当前时间的时间截,而是存储时间截的差值(当前时间截 - 开始时间截)
* 得到的值),这里的的开始时间截,一般是我们的id生成器开始使用的时间,由我们程序来指定的(如下下面程序IdWorker类的startTime属性)。41位的时间截,可以使用69年,年T = (1L << 41) / (1000L * 60 * 60 * 24 * 365) = 69<br>
* 10位的数据机器位,可以部署在1024个节点,包括5位datacenterId和5位workerId<br>
* 12位序列,毫秒内的计数,12位的计数顺序号支持每个节点每毫秒(同一机器,同一时间截)产生4096个ID序号<br>
* 加起来刚好64位,为一个Long型。<br>
* SnowFlake的优点是,整体上按照时间自增排序,并且整个分布式系统内不会产生ID碰撞(由数据中心ID和机器ID作区分),并且效率较高,经测试,SnowFlake每秒能够产生26万ID左右。
* https://www.cnblogs.com/relucent/p/4955340.html
*/
@Slf4j
public class SnowflakeIdWorker {
// 部分默认值
/**
* 机器id所占的位数
*/
private static final long WORKER_ID_BITS = Long.parseLong(System.getProperty("spring.datasource.druid.worker-id-bits", "8"));
/**
* 机房id所占的位数
*/
private static final long DATACENTER_ID_BITS = 10 - WORKER_ID_BITS;
/**
* 支持的最大机器id,结果是31 (这个移位算法可以很快的计算出几位二进制数所能表示的最大十进制数)
*/
private static final long MAX_WORKER_ID = -1L ^ (-1L << WORKER_ID_BITS);
/**
* 支持的最大数据标识id,结果是31
*/
private static final long MAX_DATACENTER_ID = -1L ^ (-1L << DATACENTER_ID_BITS);
// ==============================Fields===========================================
/**
* 开始时间截 (2015-01-01)
*/
private final long twepoch = 1420041600000L;
/**
* 机器id所占的位数
*/
private final long workerIdBits = WORKER_ID_BITS;
/**
* 机房id所占的位数
*/
private final long datacenterIdBits = DATACENTER_ID_BITS;
/**
* 支持的最大机器id
*/
private final long maxWorkerId = -1L ^ (-1L << workerIdBits);
/**
* 支持的最大数据标识id,结果是31
*/
private final long maxDatacenterId = -1L ^ (-1L << datacenterIdBits);
/**
* 序列在id中占的位数
*/
private final long sequenceBits = 12L;
/**
* 机器ID向左移12位
*/
private final long workerIdShift = sequenceBits;
/**
* 数据标识id向左移17位(12+5)
*/
private final long datacenterIdShift = sequenceBits + workerIdBits;
/**
* 时间截向左移22位(5+5+12)
*/
private final long timestampLeftShift = sequenceBits + workerIdBits + datacenterIdBits;
/**
* 生成序列的掩码,这里为4095 (0b111111111111=0xfff=4095)
*/
private final long sequenceMask = -1L ^ (-1L << sequenceBits);
/**
* 工作机器ID(0~31)
*/
private long workerId;
/**
* 数据中心ID(0~31)
*/
private long datacenterId;
/**
* 毫秒内序列(0~4095)
*/
private long sequence = 0L;
/**
* 上次生成ID的时间截
*/
private long lastTimestamp = -1L;
//==============================Constructors=====================================
/**
* 构造函数
*
* @param wid 工作ID (0~31)
* @param did 数据中心ID (0~31)
*/
public SnowflakeIdWorker(Long wid, Long did) {
long workerId = wid == null ? getWorkerId() : wid;
long datacenterId = did == null ? getDataCenterId() : did;
if (workerId > maxWorkerId || workerId < 0) {
throw new IllegalArgumentException(String.format("worker Id can't be greater than %d or less than 0", maxWorkerId));
}
if (datacenterId > maxDatacenterId || datacenterId < 0) {
throw new IllegalArgumentException(String.format("datacenter Id can't be greater than %d or less than 0", maxDatacenterId));
}
this.workerId = workerId;
this.datacenterId = datacenterId;
log.info("SnowflakeIdWorker已构建。workerId: [{}], datacenterId: [{}]", this.workerId, this.datacenterId);
}
public static long getWorkerId() {
String ip = getLocalFirstIp();
log.info("当前机器的IP地址: {}", ip);
if (StringUtils.isEmpty(ip)) {
// 随机数
Random random = new Random();
return random.nextInt((int) MAX_WORKER_ID);
}
String[] ipArray = ip.split("\\.");
List<Long> ipNums = new ArrayList<>();
for (int i = 0; i < 4; ++i) {
ipNums.add(Long.parseLong(ipArray[i].trim()));
}
long zhongIPNumTotal = ipNums.get(0) * 256L * 256L * 256L
+ ipNums.get(1) * 256L * 256L + ipNums.get(2) * 256L
+ ipNums.get(3);
return zhongIPNumTotal % MAX_WORKER_ID;
}
public void setWorkerId(long workerId) {
this.workerId = workerId;
}
// ==============================Methods==========================================
public static long getDataCenterId() {
String ip = getLocalFirstIp();
if (StringUtils.isEmpty(ip)) {
// 随机数
Random random = new Random();
return random.nextInt((int) MAX_DATACENTER_ID);
}
String[] ipArray = ip.split("\\.");
List<Long> ipNums = new ArrayList<>();
for (int i = 0; i < 4; ++i) {
ipNums.add(Long.parseLong(ipArray[i].trim()));
}
long zhongIPNumTotal = ipNums.get(0) * 256L * 256L * 256L
+ ipNums.get(1) * 256L * 256L + ipNums.get(2) * 256L
+ ipNums.get(3);
return zhongIPNumTotal % MAX_DATACENTER_ID;
}
/**
* 获取当前服务器ip地址
*
* @return
*/
private static String getLocalFirstIp() {
try {
Enumeration<NetworkInterface> allNetInterfaces = NetworkInterface.getNetworkInterfaces();
InetAddress ip = null;
while (allNetInterfaces.hasMoreElements()) {
NetworkInterface netInterface = allNetInterfaces.nextElement();
if (!netInterface.isLoopback() && !netInterface.isVirtual() && netInterface.isUp()) {
Enumeration<InetAddress> addresses = netInterface.getInetAddresses();
while (addresses.hasMoreElements()) {
ip = addresses.nextElement();
if (ip != null && ip instanceof Inet4Address) {
// System.out.println(ip.getHostAddress());
return ip.getHostAddress();
}
}
}
}
} catch (Exception e) {
log.error("IP地址获取失败" + e.toString());
}
return "";
}
public void setDatacenterId(long datacenterId) {
this.datacenterId = datacenterId;
}
/**
* 获得下一个ID (该方法是线程安全的)
*
* @return SnowflakeId
*/
public synchronized long nextId() {
long timestamp = timeGen();
//如果当前时间小于上一次ID生成的时间戳,说明系统时钟回退过这个时候应当抛出异常
if (timestamp < lastTimestamp) {
throw new RuntimeException(
String.format("Clock moved backwards. Refusing to generated id for %d milliseconds", lastTimestamp - timestamp));
}
//如果是同一时间生成的,则进行毫秒内序列
if (lastTimestamp == timestamp) {
sequence = (sequence + 1) & sequenceMask;
//毫秒内序列溢出
if (sequence == 0) {
//阻塞到下一个毫秒,获得新的时间戳
timestamp = tilNextMillis(lastTimestamp);
}
}
//时间戳改变,毫秒内序列重置
else {
sequence = 0L;
}
//上次生成ID的时间截
lastTimestamp = timestamp;
//移位并通过或运算拼到一起组成64位的ID
return ((timestamp - twepoch) << timestampLeftShift) //
| (datacenterId << datacenterIdShift) //
| (workerId << workerIdShift) //
| sequence;
}
/**
* 阻塞到下一个毫秒,直到获得新的时间戳
*
* @param lastTimestamp 上次生成ID的时间截
* @return 当前时间戳
*/
protected long tilNextMillis(long lastTimestamp) {
long timestamp = timeGen();
while (timestamp <= lastTimestamp) {
timestamp = timeGen();
}
return timestamp;
}
/**
* 返回以毫秒为单位的当前时间
*
* @return 当前时间(毫秒)
*/
protected long timeGen() {
return System.currentTimeMillis();
}
//==============================Test=============================================
/**
* 测试
*/
// public static void main(String[] args) {
// SnowflakeIdWorker idWorker = new SnowflakeIdWorker(0L, 0L);
// while (true) {
// System.out.println("输入数量:");
// Scanner sc = new Scanner(System.in);
// // 获取键盘输入的int数字
// int num = sc.nextInt();
// StringBuilder writeMe = new StringBuilder();
// for (int i = 0; i < num; i++) {
// long id = idWorker.nextId();
// // System.out.println(Long.toBinaryString(id));
// writeMe.append(id);
// writeMe.append("\n");
// }
// Clipboard clip = Toolkit.getDefaultToolkit().getSystemClipboard();
// Transferable tText = new StringSelection(writeMe.toString());
// clip.setContents(tText, null);
// }
// }
}