网上已经有很多雪花算法（Snowflake）的实现，这里也记录一下我自己的 Java 实现。
/**
 * Snowflake-style 64-bit ID generator.
 *
 * <p>Bit layout, high to low: 1 sign bit (always 0, so IDs are non-negative) | 41-bit timestamp
 * in milliseconds since the Unix epoch | 10-bit worker id | 12-bit per-millisecond sequence.
 * Because the timestamp is measured from the Unix epoch, 41 bits last until about 2039-09-07 UTC.
 *
 * <p>Thread-safe: all mutable state is guarded by the instance monitor.
 */
@Slf4j
class SnowFlake {
    /**
     * Largest backward clock jump (ms), e.g. from an NTP adjustment, that {@link #next()} absorbs
     * by waiting; a larger jump makes it throw {@code ClockBackException}.
     */
    private static final long MAX_CLOCK_BACK_MILLIS = 3_000;
    /** Timestamp bits: milliseconds since the Unix epoch. */
    private static final byte TIMESTAMP_BITS = 41;
    private static final byte WORK_ID_BITS = 10;
    /** Sequence bits: at most 4096 IDs per millisecond per worker. */
    private static final byte SYNC_BITS = 12;

    private static final long MAX_TIMESTAMP = (1L << TIMESTAMP_BITS) - 1;
    private static final int MAX_WORKER_ID = (1 << WORK_ID_BITS) - 1;
    private static final int MAX_SYNC_ID = (1 << SYNC_BITS) - 1;

    /** Timestamp (ms) of the most recently issued ID; guarded by {@code this}. */
    private long lastTime;
    /** Next sequence number to hand out within {@link #lastTime}; guarded by {@code this}. */
    private int syncId;
    private final int workerId;

    /**
     * Creates a generator.
     *
     * @param workerId worker id, must be in {@code [0, 1023]}
     * @param lastTime starting timestamp (ms since the epoch); presumably the last timestamp
     *                 used by a previous run. That millisecond is treated as already used, so
     *                 every ID gets a strictly later timestamp.
     * @throws IllegalStateException if an argument is null or {@code workerId} is out of range
     */
    SnowFlake(Integer workerId, Long lastTime) {
        // 溢出检查 (range checks)
        if (workerId == null) {
            throw new IllegalStateException("workerId 空");
        }
        if (workerId < 0 || workerId > MAX_WORKER_ID) {
            throw new IllegalStateException("workerId 溢出");
        }
        if (lastTime == null) {
            throw new IllegalStateException("lastTime 空");
        }
        this.workerId = workerId;
        this.lastTime = lastTime;
        // Mark the starting millisecond as exhausted so no ID can collide with ones that may
        // already have been issued for it.
        this.syncId = MAX_SYNC_ID + 1;
    }

    /**
     * Returns the next unique ID.
     *
     * <p>If the wall clock is behind the last issued timestamp by at most
     * {@link #MAX_CLOCK_BACK_MILLIS}, this blocks until the clock catches up. If all sequence
     * numbers of the current millisecond are used, it blocks until the next millisecond. The
     * waiting happens while holding the monitor, so concurrent callers wait as well.
     *
     * @return a positive ID, unique for this generator instance
     * @throws ClockBackException if the clock moved back by more than {@link #MAX_CLOCK_BACK_MILLIS}
     * @throws IllegalStateException if the timestamp no longer fits into 41 bits
     */
    final synchronized long next() throws ClockBackException {
        long now = Instant.now().toEpochMilli();
        if (now < lastTime) {
            // The clock went backwards (possibly NTP): wait it out if the gap is small.
            if (lastTime - now > MAX_CLOCK_BACK_MILLIS) {
                log.error("时钟回拨, lastTime = {}, now = {}", this.lastTime, now);
                throw new ClockBackException();
            }
            now = awaitMillis(lastTime);
        }
        if (now > lastTime) {
            // New millisecond: restart the sequence.
            lastTime = now;
            syncId = 0;
        } else if (syncId > MAX_SYNC_ID) {
            // Sequence for this millisecond is used up: move on to the next millisecond.
            lastTime = awaitMillis(lastTime + 1);
            syncId = 0;
        }
        return generateId(lastTime, syncId++);
    }

    /**
     * Blocks until the wall clock reads at least {@code target} ms since the epoch and returns
     * the time observed. Interrupts do not cut the wait short, because returning early could
     * produce a duplicate ID. The thread's interrupt status is restored before returning.
     */
    private static long awaitMillis(long target) {
        boolean interrupted = false;
        long now;
        while ((now = Instant.now().toEpochMilli()) < target) {
            try {
                TimeUnit.MILLISECONDS.sleep(target - now);
            } catch (InterruptedException e) {
                interrupted = true;
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
        return now;
    }

    /**
     * Packs the ID fields: sign bit (0) | timestamp | workerId | sequence.
     *
     * @throws IllegalStateException if {@code timestamp} or {@code sequence} is out of range
     */
    private long generateId(long timestamp, int sequence) {
        if (timestamp < 0 || timestamp > MAX_TIMESTAMP) {
            throw new IllegalStateException("时间戳溢出");
        }
        if (sequence < 0 || sequence > MAX_SYNC_ID) {
            throw new IllegalStateException("syncId 溢出");
        }
        return (timestamp << (WORK_ID_BITS + SYNC_BITS))
                | ((long) workerId << SYNC_BITS)
                | sequence;
    }
}
下面是一个简单的并发测试：
/**
 * Concurrency smoke test for {@link SnowFlake}: IDs generated from many threads must be unique.
 */
public class SnowFlakeTest {
    /** Number of IDs to generate: one full millisecond's worth of sequence numbers. */
    private static final int ID_COUNT = 1 << 12;

    private SnowFlake instance;

    @Before
    public void setUp() {
        this.instance = new SnowFlake(1, Instant.now().toEpochMilli());
    }

    /**
     * Generates {@link #ID_COUNT} IDs in parallel and asserts that none repeats. Any exception
     * from the generator propagates and fails the test instead of being silently dropped.
     */
    @Test
    public void next() {
        Instant start = Instant.now();
        // range() calls next() exactly ID_COUNT times; generate().parallel().limit() may call
        // the supplier more often than requested and discard the extra results.
        List<Long> ids = LongStream.range(0, ID_COUNT)
                .parallel()
                .map(i -> instance.next())
                .boxed()
                .collect(Collectors.toList());
        long elapseTime = start.until(Instant.now(), ChronoUnit.MILLIS);
        System.out.println("耗时(ms) " + elapseTime);
        HashSet<Long> uniqueIds = new HashSet<>(ids);
        System.out.println(String.format("生产 %d , 实际 %d ", ids.size(), uniqueIds.size()));
        Assert.assertEquals(ID_COUNT, ids.size());
        Assert.assertEquals(ids.size(), uniqueIds.size());
    }
}