package jartest;
import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.*;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.concurrent.BasicThreadFactory;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Concurrency tests for loading a same-named implementation class from different jar files.
 *
 * <p>Every test hammers {@link #getInstance(String)} from many threads for the versions
 * {@code "one"} and {@code "two"} and logs an error whenever the returned implementation
 * does not report the version that was requested.
 */
public class LoaderJarTest {
    static Logger logger = LoggerFactory.getLogger(LoaderJarTest.class);

    /** Cache of loaded implementations, keyed by version name. Never holds null values. */
    private static final Map<String, JarTestInterface> apiMap = new ConcurrentHashMap<>();

    /** Fully qualified name of the implementation class that every versioned jar contains. */
    private static final String IMPL_CLASS_NAME = "jartest.impl.JarTestImpl";

    // How the test jars are generated:
    // cd \test\test-boot\target\test-classes\
    // jar -cef jartest.JarTestInterface ../../test/resources/version_one.jar jartest/impl
    // jar -cef jartest.JarTestInterface ../../test/resources/version_two.jar jartest/impl

    @Before
    public void before() {
        logger.info("测试加载jar包,加载异常版本问题");
    }

    /**
     * Barrier test: every worker waits on a {@link CyclicBarrier} so that all of them start
     * the jar lookup at the same moment.
     *
     * <p>The test thread waits for all workers before returning; otherwise the test could
     * finish before any worker had run.
     */
    @Test
    public void testCyclicBarrier() {
        final var begin = System.currentTimeMillis();
        try {
            final int count = 1000;
            final CyclicBarrier barrier = new CyclicBarrier(count, () -> System.out.println("开始执行"));
            final CountDownLatch done = new CountDownLatch(count);
            for (int i = 0; i < count; i++) {
                new Thread(() -> {
                    try {
                        try {
                            barrier.await();
                        } catch (InterruptedException e) {
                            // Restore the flag so the interruption stays observable.
                            Thread.currentThread().interrupt();
                            logger.error("Interrupted while waiting at the barrier", e);
                        } catch (BrokenBarrierException e) {
                            logger.error("Barrier was broken before all workers arrived", e);
                        }
                        // As before, the check still runs when the barrier wait failed.
                        verifyVersions();
                    } finally {
                        done.countDown();
                    }
                }).start();
            }
            awaitWorkers(done);
        } finally {
            System.out.println("总共用时:" + (System.currentTimeMillis() - begin));
        }
    }

    /**
     * Latch test: the test thread blocks on a {@link CountDownLatch} until every worker has
     * finished its lookup.
     *
     * @throws Exception if the test thread is interrupted while waiting for the workers
     */
    @Test
    public void testLatch() throws Exception {
        var begin = System.currentTimeMillis();
        try {
            final int count = 1000;
            CountDownLatch countDownLatch = new CountDownLatch(count);
            for (var i = 0; i < count; i++) {
                new Thread(() -> {
                    try {
                        verifyVersions();
                    } finally {
                        // Count down even when the check throws, so await() cannot hang forever.
                        countDownLatch.countDown();
                    }
                }).start();
            }
            countDownLatch.await();
        } finally {
            System.out.println("总共用时:" + (System.currentTimeMillis() - begin));
        }
    }

    /**
     * Semaphore test: each worker acquires {@code n} permits before the lookup and releases
     * them afterwards.
     *
     * <p>The test thread waits for all workers before returning.
     */
    @Test
    public void testSemaphore() {
        final var begin = System.currentTimeMillis();
        try {
            final int count = 1000;
            final int n = 1;
            final Semaphore semaphore = new Semaphore(count);
            final CountDownLatch done = new CountDownLatch(count);
            for (int i = 0; i < count; i++) {
                new Thread(() -> {
                    try {
                        semaphore.acquire(n);
                        try {
                            verifyVersions();
                        } finally {
                            // Give the permits back even if the check throws.
                            semaphore.release(n);
                        }
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        logger.error("Interrupted while acquiring a permit", e);
                    } finally {
                        done.countDown();
                    }
                }).start();
            }
            awaitWorkers(done);
        } finally {
            System.out.println("总共用时:" + (System.currentTimeMillis() - begin));
        }
    }

    /**
     * Thread-pool test: submits 1000 lookups to a pool of 50 daemon threads and waits for
     * the pool to drain.
     *
     * @throws Exception if the test thread is interrupted while waiting for termination
     */
    @Test
    public void testPool() throws Exception {
        var begin = System.currentTimeMillis();
        ScheduledExecutorService pool = new ScheduledThreadPoolExecutor(50,
                new BasicThreadFactory.Builder().namingPattern("loadJar-pool-%d").daemon(Boolean.TRUE).build());
        for (var i = 0; i < 1000; i++) {
            pool.submit(LoaderJarTest::verifyVersions);
        }
        pool.shutdown();
        if (!pool.awaitTermination(1, TimeUnit.HOURS)) {
            logger.error("Thread pool did not terminate within the timeout");
        }
        System.out.println("总共用时:" + (System.currentTimeMillis() - begin));
    }

    /**
     * Returns the implementation loaded from the jar of the given version, loading and
     * caching it on first use.
     *
     * <p>Loading goes through {@link Map#computeIfAbsent}, so on the backing
     * {@link ConcurrentHashMap} at most one load runs per version at a time and a failed
     * load (a {@code null} result) is simply not cached; the next call retries. To re-read
     * the jar on every call, bypass the cache and call {@code loadFromJar} directly.
     *
     * @param version version name, e.g. {@code "one"} or {@code "two"}
     * @return the loaded implementation, or {@code null} if it could not be loaded
     */
    public static JarTestInterface getInstance(String version) {
        return apiMap.computeIfAbsent(version, LoaderJarTest::loadFromJar);
    }

    /**
     * Returns the path of the jar file for the given version.
     *
     * <p>The path is derived from the classpath root of this class by replacing the
     * {@code target/test-classes} segment with {@code test/resources}, using the platform
     * separator. If no jar exists there but one exists directly beside the compiled test
     * classes, that one is returned instead.
     *
     * @param version version name used in the file name {@code version_<version>.jar}
     * @return path of the jar file
     * @throws IllegalStateException if the classpath root cannot be located
     */
    public static String getJarPath(String version) {
        URL root = LoaderJarTest.class.getClassLoader().getResource("");
        if (root == null) {
            throw new IllegalStateException("Cannot locate the classpath root of " + LoaderJarTest.class.getName());
        }
        File classesRoot;
        try {
            // Going through the URI decodes escaped characters such as %20.
            classesRoot = new File(root.toURI());
        } catch (URISyntaxException | IllegalArgumentException e) {
            // Not a plain file: URI; fall back to the raw URL path.
            classesRoot = new File(root.getFile());
        }
        String jarName = "version_".concat(version).concat(".jar");
        String classesDir = classesRoot.getPath() + File.separator;
        String resourcesDir = classesDir.replace(
                "target" + File.separator + "test-classes",
                "test" + File.separator + "resources");
        String preferred = resourcesDir + jarName;
        String besideClasses = classesDir + jarName;
        if (!new File(preferred).isFile() && new File(besideClasses).isFile()) {
            return besideClasses;
        }
        return preferred;
    }

    /** Looks up both versions and logs an error for each one that is missing or mismatched. */
    private static void verifyVersions() {
        JarTestInterface one = getInstance("one");
        JarTestInterface two = getInstance("two");
        checkVersion("one", one);
        checkVersion("two", two);
    }

    /** Logs an error when {@code api} is absent or reports a version other than {@code expected}. */
    private static void checkVersion(String expected, JarTestInterface api) {
        if (api == null) {
            logger.error("No implementation could be loaded for version {}", expected);
            return;
        }
        if (!StringUtils.equals(expected, api.version())) {
            logger.error("获取版本错误!");
        }
    }

    /**
     * Loads the implementation class from the jar of the given version and instantiates it.
     *
     * @return a new instance, or {@code null} if loading or instantiation failed
     */
    private static JarTestInterface loadFromJar(String version) {
        URLClassLoader loader = null;
        try {
            var jarFile = new File(getJarPath(version));
            loader = new URLClassLoader(new URL[]{jarFile.toURI().toURL()},
                    Thread.currentThread().getContextClassLoader());
            var clazz = loader.loadClass(IMPL_CLASS_NAME);
            // On success the loader is left open on purpose: the instance may still
            // need it to resolve further classes lazily.
            return (JarTestInterface) clazz.getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            logger.error("Failed to load {} for version {}", IMPL_CLASS_NAME, version, e);
            closeQuietly(loader);
            return null;
        }
    }

    /** Closes the loader of a failed load attempt, logging instead of propagating I/O errors. */
    private static void closeQuietly(URLClassLoader loader) {
        if (loader == null) {
            return;
        }
        try {
            loader.close();
        } catch (IOException e) {
            logger.warn("Failed to close class loader", e);
        }
    }

    /** Blocks until all workers have counted down, restoring the interrupt flag if interrupted. */
    private static void awaitWorkers(CountDownLatch done) {
        try {
            done.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            logger.error("Interrupted while waiting for worker threads", e);
        }
    }
}