今天脑子里闪过用 ThreadLocal 实现计数器的念头，百度了一下，发现没有文章讲到如何聚合所有线程各自的计数值，所以自己实现了一个，代码如下。
import java.util.WeakHashMap;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
public class ThradeLocalTest {
static class Counter {
private static class Entry {
long q0, q1, q2, q3, q4, q5, q6, q7, q8, q9, qa, qb, qc, qd, qe;
long count = 0;
}
private static Lock lock = new ReentrantLock();
private static WeakHashMap<Thread, Entry> map = new WeakHashMap<>();
private static ThreadLocal<Entry> local = ThreadLocal.withInitial(Entry::new);
public void increase() {
Entry entry = local.get();
long count = entry.count;
if (count == 0) {
lock.lock();
try {
map.put(Thread.currentThread(), entry);
} finally {
lock.unlock();
}
}
local.get().count = count + 1;
}
public long getAll() {
return map.entrySet().stream().map(entry->entry.getValue().count).reduce(0L, Long::sum);
}
}
public static void main(String[] args) throws InterruptedException {
Counter counter = new Counter();
int number = 100;
Thread[] threads = new Thread[number];
for (int i = 0; i < number; i++) {
threads[i] = new Thread(()->{
for (int j = 0; j < 100_000_000; j++) {
counter.increase();
}
});
}
for (Thread thread1 : threads) {
thread1.start();
}
System.out.println(counter.getAll());
for (Thread thread2 : threads) {
thread2.join();
}
System.out.println(counter.getAll());
}
}
该代码有个问题：WeakHashMap 以 Thread 作为弱引用键，线程结束并被 GC 回收后，对应的条目会被自动清除，计数也就随之丢失了。所以需要自己实现存储计数用的 map，在条目被清除之前把它的计数累加到总数中。修改后的代码如下。
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
public class ThreadLocalTest {
    /**
     * Thread-local counter that keeps the counts of threads that have died.
     *
     * <p>Each thread increments its own {@link Entry} obtained from a {@link ThreadLocal}. Entries
     * are registered in {@link #map} under a {@link WeakReference} to the owning thread. When a
     * thread is garbage collected its reference is enqueued on {@link #queue}, and
     * {@link #expunge()} moves that entry's count into {@link #removedTotal} so it is not lost.
     */
    static class Counter {
        /** Count cell for one thread; {@code count} is written only by the owning thread. */
        private static class Entry {
            // Presumably padding to keep `count` off neighbouring cache lines (false sharing).
            // NOTE(review): the JVM may reorder fields, so this layout is not guaranteed.
            long q0, q1, q2, q3, q4, q5, q6, q7, q8, q9, qa, qb, qc, qd, qe;
            long count = 0;
        }

        /** Sum of counts of entries whose thread was collected; updated under {@link #lock}. */
        private volatile long removedTotal = 0;
        // Presumably padding separating removedTotal from the fields below.
        long q0, q1, q2, q3, q4, q5, q6, q7, q8, q9, qa, qb, qc, qd, qe;
        /** Guards {@link #map} and {@link #removedTotal}; HashMap is not thread-safe. */
        private final Lock lock = new ReentrantLock();
        private final Map<WeakReference<Thread>, Entry> map = new HashMap<>();
        private final ReferenceQueue<Object> queue = new ReferenceQueue<>();
        private final ThreadLocal<Entry> local = ThreadLocal.withInitial(Entry::new);

        /** Increments the calling thread's count, registering its entry on first use. */
        public void increase() {
            Entry entry = local.get();
            if (entry.count == 0) {
                // First increment from this thread: make the entry visible to getAll().
                lock.lock();
                try {
                    map.put(new WeakReference<>(Thread.currentThread(), queue), entry);
                } finally {
                    lock.unlock();
                }
            }
            expunge();
            entry.count++;
        }

        /**
         * Folds the counts of entries whose thread has been collected into {@link #removedTotal}
         * and drops those entries from {@link #map}.
         *
         * <p>The empty-queue case returns without locking. Otherwise the whole queue is drained
         * under a single lock acquisition. {@code ReferenceQueue.poll()} is itself thread-safe and
         * hands each reference to exactly one caller, so no extra monitor is needed.
         */
        private void expunge() {
            Object ref = queue.poll();
            if (ref == null) {
                return;
            }
            lock.lock();
            try {
                do {
                    Entry dead = map.remove(ref);
                    if (dead != null) {
                        removedTotal += dead.count;
                    }
                } while ((ref = queue.poll()) != null);
            } finally {
                lock.unlock();
            }
        }

        /**
         * Returns the total of all counts: live entries plus those already folded in.
         *
         * <p>The sum is taken under {@link #lock}, so an entry cannot be moved from
         * {@link #map} to {@link #removedTotal} mid-read (which would drop or double-count it),
         * and concurrent registrations cannot break the iteration. Counts of still-running
         * threads are read without synchronization, so the value is exact only after they are
         * joined.
         *
         * @return the aggregated count
         */
        public long getAll() {
            expunge();
            lock.lock();
            try {
                long sum = removedTotal;
                for (Entry entry : map.values()) {
                    sum += entry.count;
                }
                return sum;
            } finally {
                lock.unlock();
            }
        }
    }

    /** Demo/benchmark: two rounds of 4 threads, then a single-threaded empty-loop baseline. */
    public static void main(String[] args) throws InterruptedException {
        Counter counter = new Counter();
        int number = 4;
        Thread[] threads = new Thread[number];
        // Warm-up round.
        for (int i = 0; i < number; i++) {
            threads[i] = new Thread(() -> {
                for (int j = 0; j < 100_000_000; j++) {
                    counter.increase();
                }
            });
        }
        for (Thread thread1 : threads) {
            thread1.start();
        }
        for (Thread thread2 : threads) {
            thread2.join();
        }
        // Timed round with fresh threads (the first-round threads become collectable).
        for (int i = 0; i < number; i++) {
            threads[i] = new Thread(() -> {
                for (int j = 0; j < 100_000_000; j++) {
                    counter.increase();
                }
            });
        }
        long start = System.currentTimeMillis();
        for (Thread thread1 : threads) {
            thread1.start();
        }
        for (Thread thread2 : threads) {
            thread2.join();
        }
        System.out.println(System.currentTimeMillis() - start);
        System.out.println(counter.getAll());
        // Baseline: an empty loop of the same total length. NOTE: the JIT may eliminate it.
        start = System.currentTimeMillis();
        for (long i = 0; i < 800_000_000L; i++) {}
        System.out.println(System.currentTimeMillis() - start);
    }
}
极端情况还没有测试，上面的测试也没有触发 GC。
又修改了一个版本：不再使用 ThreadLocal，而是为每个线程创建一个 SubCounter，专门用于该线程的计数。
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
public class ThreadLocalTest {
    /**
     * Aggregating counter whose increments go through per-thread {@link SubCounter} handles.
     *
     * <p>Each handle owns an {@link Entry}. Entries are registered in {@link #map} under a
     * {@link WeakReference} to the handle. When a handle is garbage collected its reference is
     * enqueued on {@link #queue}, and {@link #expunge()} folds the entry's count into
     * {@link #removedTotal} so the count survives while the map stays small.
     */
    static class Counter {
        /** Count cell for one SubCounter. */
        private static class Entry {
            long count = 0;
        }

        /** Increment handle; not thread-safe, so each instance should be used by one thread. */
        public interface SubCounter {
            void increase();
        }

        // Static nested class: it only needs its Entry, not a hidden reference to the Counter.
        private static final class SubCounterImpl implements SubCounter {
            private final Entry entry;

            SubCounterImpl(Entry entry) {
                this.entry = entry;
            }

            @Override
            public void increase() {
                entry.count++;
            }
        }

        /** Sum of counts of entries whose handle was collected; updated under {@link #lock}. */
        private volatile long removedTotal = 0;
        // Presumably padding separating removedTotal from the fields below.
        long q0, q1, q2, q3, q4, q5, q6, q7, q8, q9, qa, qb, qc, qd, qe;
        /** Guards {@link #map} and {@link #removedTotal}; HashMap is not thread-safe. */
        private final Lock lock = new ReentrantLock();
        private final Map<WeakReference<SubCounter>, Entry> map = new HashMap<>();
        private final ReferenceQueue<Object> queue = new ReferenceQueue<>();

        /**
         * Creates and registers a new handle with a zero count.
         *
         * @return a handle whose increments are included in {@link #getAll()}
         */
        public SubCounter createSubCounter() {
            expunge();
            Entry entry = new Entry();
            SubCounter subCounter = new SubCounterImpl(entry);
            lock.lock();
            try {
                map.put(new WeakReference<>(subCounter, queue), entry);
            } finally {
                lock.unlock();
            }
            return subCounter;
        }

        /**
         * Folds the counts of entries whose handle has been collected into {@link #removedTotal}
         * and drops those entries from {@link #map}.
         *
         * <p>The empty-queue case returns without locking. Otherwise the whole queue is drained
         * under a single lock acquisition. {@code ReferenceQueue.poll()} is itself thread-safe and
         * hands each reference to exactly one caller, so no extra monitor is needed.
         */
        private void expunge() {
            Object ref = queue.poll();
            if (ref == null) {
                return;
            }
            lock.lock();
            try {
                do {
                    Entry dead = map.remove(ref);
                    if (dead != null) {
                        removedTotal += dead.count;
                    }
                } while ((ref = queue.poll()) != null);
            } finally {
                lock.unlock();
            }
        }

        /**
         * Returns the total of all counts: live entries plus those already folded in.
         *
         * <p>The sum is taken under {@link #lock}, so a concurrent {@link #createSubCounter()}
         * or {@link #expunge()} cannot break the iteration. It also cannot move an entry from
         * {@link #map} to {@link #removedTotal} mid-read. Counts of handles still in use by other
         * threads are read without synchronization, so the value is exact only after those
         * threads are joined.
         *
         * @return the aggregated count
         */
        public long getAll() {
            expunge();
            lock.lock();
            try {
                long sum = removedTotal;
                for (Entry entry : map.values()) {
                    sum += entry.count;
                }
                return sum;
            } finally {
                lock.unlock();
            }
        }
    }

    /**
     * Demo/benchmark: two rounds of 4 threads, an int-loop baseline, then a GC stress test that
     * creates 400,000 short-lived SubCounters.
     */
    public static void main(String[] args) throws InterruptedException {
        Counter counter = new Counter();
        int number = 4;
        Thread[] threads = new Thread[number];
        // Warm-up round.
        for (int i = 0; i < number; i++) {
            threads[i] = new Thread(new Runnable() {
                // Created on the main thread when the Runnable is constructed.
                Counter.SubCounter subCounter = counter.createSubCounter();

                @Override
                public void run() {
                    for (int j = 0; j < 100_000_000; j++) {
                        subCounter.increase();
                    }
                }
            });
        }
        for (Thread thread1 : threads) {
            thread1.start();
        }
        for (Thread thread2 : threads) {
            thread2.join();
        }
        // Timed round with fresh threads and fresh SubCounters.
        for (int i = 0; i < number; i++) {
            threads[i] = new Thread(new Runnable() {
                Counter.SubCounter subCounter = counter.createSubCounter();

                @Override
                public void run() {
                    for (int j = 0; j < 100_000_000; j++) {
                        subCounter.increase();
                    }
                }
            });
        }
        long start = System.currentTimeMillis();
        for (Thread thread1 : threads) {
            thread1.start();
        }
        for (Thread thread2 : threads) {
            thread2.join();
        }
        System.out.printf("Time:%d\n", System.currentTimeMillis() - start);
        System.out.printf("Count:%d\n", counter.getAll());
        // Baseline: a plain local-variable loop of the same total length.
        start = System.currentTimeMillis();
        int intCount = 0;
        for (long i = 0; i < 800_000_000L; i++) {
            intCount++;
        }
        System.out.printf("Time:%d\n", System.currentTimeMillis() - start);
        System.out.printf("Count:%d\n", intCount);
        // Garbage-collection test: each SubCounter is dropped right after one increment.
        Counter counter2 = new Counter();
        start = System.currentTimeMillis();
        for (int i = 0; i < 2; i++) {
            for (int j = 0; j < 200_000; j++) {
                counter2.createSubCounter().increase();
            }
        }
        System.out.printf("Time:%d\n", System.currentTimeMillis() - start);
        System.out.printf("Count:%d\n", counter2.getAll());
        System.out.printf("counter size:%d\n", counter.map.size());
        System.out.printf("counter2 size:%d\n", counter2.map.size());
    }
}
将最大运行内存配置为 2M，运行结果如下：
Time:29
Count:800000000
Time:1193
Count:800000000
Time:5945
Count:400000
counter size:5
counter2 size:504