需求:
最近项目里用poi实现了将excel数据导入数据库的功能,单线程的代码已经优化得差不多了。虽然领导没有要求我优化导入接口,但本着钻研技术、学以致用的想法,我尝试改用多线程方式导入excel。
所需pom依赖:
<!-- Apache POI core; the import code below uses it for HSSFWorkbook (.xls files). -->
<dependency>
<groupId>org.apache.poi</groupId>
<artifactId>poi</artifactId>
<version>3.17</version>
</dependency>
<!-- Apache POI OOXML; the import code below uses it for XSSFWorkbook (.xlsx files). -->
<dependency>
<groupId>org.apache.poi</groupId>
<artifactId>poi-ooxml</artifactId>
<version>3.17</version>
</dependency>
导入的service实现类:
/**
 * Imports users from the first sheet of an uploaded Excel file using a pool of worker tasks.
 *
 * <p>Row 0 is treated as the header. The remaining rows are split into chunks of 200 rows and
 * each chunk is submitted as one {@code UserThread} task. The method blocks until every task
 * has finished, then reports how many rows were inserted together with the collected errors.
 *
 * @param file the uploaded {@code .xls} or {@code .xlsx} file
 * @return a map holding {@code code} (200 on completion, 501 for an unsupported file),
 *         {@code data} (the list of {@code ErrorInfo}) and, on completion, {@code msg}
 * @throws Exception if the workbook cannot be read or a worker task fails
 */
@Override
public Map<String,Object> importData(MultipartFile file) throws Exception{
    // Each task handles this many data rows.
    final int rowsPerTask = 200;
    // Upper bound on tasks running at the same time (semaphore permits and pool size).
    final int maxConcurrentTasks = 10;

    final Date now = new Date();
    SimpleDateFormat format = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
    logger.info("{},开始导入数据...", format.format(now));
    Semaphore semaphore = new Semaphore(maxConcurrentTasks);
    Map<String,Object> map = new HashMap<>();
    // Worker tasks append to this list concurrently, so it must be thread-safe.
    List<ErrorInfo> list = Collections.synchronizedList(new ArrayList<ErrorInfo>());

    String filename = file == null ? null : file.getOriginalFilename();
    boolean isXls = hasExtension(filename, ".xls");
    boolean isXlsx = hasExtension(filename, ".xlsx");
    if(!isXls && !isXlsx){
        ErrorInfo errorInfo = new ErrorInfo();
        errorInfo.setErrorMsg("请上传xls或xlsx格式的文件");
        list.add(errorInfo);
        map.put("code",501);
        map.put("data",list);
        return map;
    }

    Workbook workbook;
    // Close the upload stream as soon as the workbook has been loaded from it.
    try (java.io.InputStream in = file.getInputStream()) {
        if(isXls){
            workbook = new HSSFWorkbook(in);
        }else {
            workbook = new XSSFWorkbook(in);
        }
    }

    ExecutorService executorService = null;
    try {
        Sheet sheet = workbook.getSheetAt(0);
        int physicalNumberOfRows = sheet.getPhysicalNumberOfRows();
        logger.info("获取到workbook中的总行数:{}" ,physicalNumberOfRows);
        // The first row is the header; clamp at zero so an empty sheet does not yield -1.
        int rows = Math.max(physicalNumberOfRows - 1, 0);
        // Ceiling division: no extra empty task when rows is an exact multiple of the chunk size.
        int taskCount = (rows + rowsPerTask - 1) / rowsPerTask;
        CountDownLatch countDownLatch = new CountDownLatch(taskCount);
        // Existing usernames, used by the tasks to reject duplicates.
        Set<String> names = this.findAllUser().stream().map(User::getUsername).collect(Collectors.toSet());
        // Threads beyond the semaphore's permit count would only block, so cap the pool size.
        int poolSize = Math.max(1, Math.min(taskCount, maxConcurrentTasks));
        executorService = Executors.newFixedThreadPool(poolSize);
        logger.info("开始创建线程,数据总行数:{},线程数量:{}",rows,poolSize);

        List<Future<Integer>> futures = new ArrayList<>();
        for(int i = 1; i <= taskCount; i++){
            int startRow = (i-1)*rowsPerTask + 1;
            int endRow = Math.min(i*rowsPerTask, rows);
            logger.info("开始执行线程方法,线程ID:<{}>,线程名称:<{}>",Thread.currentThread().getId(),Thread.currentThread().getName());
            Future<Integer> future = executorService.submit(new UserThread(semaphore,workbook, startRow, endRow, list, names,this,countDownLatch));
            futures.add(future);
            logger.info("结束线程执行方法,返回结果:<{}>,当前线程ID:<{}>,当前线程名称:<{}>",JSON.toJSONString(future),Thread.currentThread().getId(),Thread.currentThread().getName());
            // Do not call future.get() here: it blocks, which would serialize the submissions.
        }

        // Collect results only after every task has been submitted. get() has no timeout,
        // so this waits for each task to finish and rethrows a task failure.
        int successCount = 0;
        for(Future<Integer> future : futures){
            successCount += future.get();
        }
        // Every future has completed by now; the bounded await is only a safeguard.
        countDownLatch.await(60,TimeUnit.SECONDS);

        Date endDate = new Date();
        long difference = endDate.getTime() - now.getTime();
        String duration = DurationFormatUtils.formatDuration(difference, "HH:mm:ss");
        logger.info("执行完成,错误信息:{}", JSON.toJSONString(list));
        logger.info("{},结束导入,共{}条数据,导入成功:{},耗时={}", format.format(endDate), rows,successCount,duration);
        map.put("code",200);
        map.put("msg","结束导入,共" + rows + "条数据,导入成功" + successCount + "条,耗时:" +duration);
        map.put("data",list);
        return map;
    } finally {
        // On success all tasks are already done; on failure this cancels whatever is still
        // queued or running so nothing keeps using the workbook after it is closed.
        if(executorService != null){
            executorService.shutdownNow();
        }
        workbook.close();
    }
}

/**
 * Tells whether {@code filename} ends with {@code extension}, ignoring case.
 *
 * @param filename the name to test; may be {@code null}
 * @param extension the extension including the leading dot, e.g. {@code ".xls"}
 * @return {@code true} when the name ends with the extension
 */
private static boolean hasExtension(String filename, String extension){
    // regionMatches returns false for a negative offset, so short names are handled.
    return filename != null
            && filename.regionMatches(true, filename.length() - extension.length(), extension, 0, extension.length());
}
导入线程类:
package com.thread.demo.thread;
import com.thread.demo.common.ErrorInfo;
import com.thread.demo.entity.User;
import com.thread.demo.service.UserService;
import org.apache.poi.ss.usermodel.Cell;
import org.apache.poi.ss.usermodel.Row;
import org.apache.poi.ss.usermodel.Sheet;
import org.apache.poi.ss.usermodel.Workbook;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;
/**
 * Worker task that imports one contiguous range of sheet rows as users.
 *
 * <p>For each row in {@code [startRow, endRow]} of the first sheet it reads the username
 * (column 0) and real name (column 1), rejects usernames already present in {@code names},
 * and inserts the rest through {@code UserService.addUser}. Problems are appended to the
 * shared error list instead of aborting the task.
 *
 * @Author Honey
 * @Date 2019/11/15 10:31
 **/
public class UserThread implements Callable<Integer>{
    private static final Logger logger = LoggerFactory.getLogger(UserThread.class);

    private final Workbook workbook;
    private final Integer startRow;
    private final Integer endRow;
    private final List<ErrorInfo> errorInfoList;
    private final Set<String> names;
    private final UserService userService;
    private final Semaphore semaphore;
    private final CountDownLatch latch;

    /**
     * @param semaphore limits how many tasks process rows at the same time
     * @param workbook workbook whose first sheet is read
     * @param startRow first sheet row index to import (inclusive)
     * @param endRow last sheet row index to import (inclusive)
     * @param errorInfoList shared list that receives per-row errors
     * @param names usernames that already exist and must be rejected
     * @param userService service used to insert each user
     * @param latch counted down once when this task finishes
     */
    public UserThread(Semaphore semaphore,Workbook workbook,Integer startRow,Integer endRow,List<ErrorInfo> errorInfoList,Set<String> names,UserService userService,CountDownLatch latch){
        this.workbook = workbook;
        this.startRow = startRow;
        this.endRow = endRow;
        this.errorInfoList = errorInfoList;
        this.names = names;
        this.userService = userService;
        this.semaphore = semaphore;
        this.latch = latch;
    }

    /**
     * Imports the configured row range.
     *
     * @return the sum of the values returned by {@code addUser} for the rows that were inserted
     * @throws Exception if waiting for a permit is interrupted or inserting a user fails
     */
    @Override
    public Integer call() throws Exception {
        logger.info("线程ID:<{}>开始运行,startRow:{},endRow:{}",Thread.currentThread().getId(),startRow,endRow);
        try {
            semaphore.acquire();
            try {
                logger.info("消耗了一个信号量,剩余信号量为:{}",semaphore.availablePermits());
                return importRows(workbook.getSheetAt(0));
            } finally {
                // Always hand the permit back, even when a row throws.
                semaphore.release();
            }
        } finally {
            // Signal completion after the work, and also when the task fails.
            latch.countDown();
        }
    }

    /** Imports every row of this task's range from {@code sheet} and returns the inserted count. */
    private int importRows(Sheet sheet) {
        int count = 0;
        for(int i = startRow; i <= endRow; i++){
            Row row = sheet.getRow(i);
            Cell usernameCell = row == null ? null : row.getCell(0);
            Cell realNameCell = row == null ? null : row.getCell(1);
            if(usernameCell == null || realNameCell == null){
                // A blank row or cell is reported and skipped rather than failing with an NPE.
                addError(i, usernameCell == null ? 1 : 2, "第" + i + "行数据不完整");
                continue;
            }
            String username = usernameCell.getStringCellValue();
            if(names.contains(username)){
                // Report the row that is actually duplicated and keep going with the rest.
                addError(i, 1, "第" + i + "行用户账号已存在");
                continue;
            }
            User user = new User();
            user.setUsername(username);
            user.setPassword("123456");
            user.setRealName(realNameCell.getStringCellValue());
            count += userService.addUser(user);
        }
        return count;
    }

    /** Appends one error entry for the given sheet row and 1-based column. */
    private void addError(int rowIndex, int column, String message) {
        ErrorInfo errorInfo = new ErrorInfo();
        errorInfo.setRow(rowIndex);
        errorInfo.setColumn(column);
        errorInfo.setErrorMsg(message);
        errorInfoList.add(errorInfo);
    }
}
controller也贴一下吧。没什么东西
package com.thread.demo.controller;
import com.thread.demo.service.UserService;
import org.apache.poi.ss.usermodel.Workbook;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;
import javax.servlet.http.HttpServletResponse;
import java.util.HashMap;
import java.util.Map;
/**
 * REST endpoints for importing users from an Excel upload and exporting users to Excel.
 *
 * @Author Honey
 * @Date 2019/11/15 10:27
 **/
@RestController
public class UserController {
    private static final Logger logger = LoggerFactory.getLogger(UserController.class);

    @Autowired
    private UserService userService;

    /**
     * Imports the uploaded Excel file using the multi-threaded importer.
     *
     * @param file the uploaded Excel file
     * @return the service's result map, or a map with {@code code} 501 when the import throws
     */
    @PostMapping("/importManyThread")
    public Map<String, Object> importData(MultipartFile file){
        try {
            return userService.importData(file);
        } catch (Exception e) {
            logger.error("多线程导入失败", e);
            return errorResult();
        }
    }

    /**
     * Imports the uploaded Excel file using the single-threaded importer.
     *
     * @param file the uploaded Excel file
     * @return the service's result map, or a map with {@code code} 501 when the import throws
     */
    @PostMapping("/importSingleThread")
    public Map<String, Object> importData2(MultipartFile file){
        try {
            return userService.importDataYiBan(file);
        } catch (Exception e) {
            logger.error("单线程导入失败", e);
            return errorResult();
        }
    }

    /**
     * Writes the workbook produced by the service to the response as a file download.
     *
     * @param response the HTTP response that receives the workbook bytes
     * @throws Exception if building or writing the workbook fails
     */
    @GetMapping("/export")
    public void exportData(HttpServletResponse response) throws Exception{
        // Close the workbook once it has been written so its resources are released.
        try (Workbook workbook = userService.exportData()) {
            response.setContentType("application/vnd.ms-excel;charset=utf-8");
            response.setCharacterEncoding("UTF-8");
            // user.xlsx is the name shown in the download dialog; a non-ASCII name must be encoded first.
            response.setHeader("Content-Disposition", "attachment;filename=user.xlsx");
            workbook.write(response.getOutputStream());
        }
    }

    /**
     * Builds the generic failure payload. A fresh map is required because the service call
     * threw before returning one.
     */
    private static Map<String, Object> errorResult(){
        Map<String, Object> map = new HashMap<>();
        map.put("code",501);
        map.put("msg","数据出错");
        return map;
    }
}
执行结果:
使用多线程方式导入5000条数据花费时间14秒,而单线程导入则需1分钟14秒。
可见多线程方式运行程序是可以达到空间换时间的目的的。