使用ThreadPoolTaskExecutor和countDownLatch实现异步下载功能

今天主要是一个实践性的文章,在阅读公司代码的时候发现我们这边将多个文件打包成zip下载的方式有点东西(虽说网上有类似教程),但是还是想要自己去动手尝试学习下,感兴趣的同学可以跟着我一步步来。

一、准备程序骨架

我们既然要实现下载文件,首先需要这个大量文件可供我们下载,其次是需要将这些文件的主要信息(这个实践主要是文件id,文件名,文件主要位置)给存储起来,方便查找以及下载将这些信息反馈给用户。那么在这里我们一般能想到的就是通过数据库去存储这些文件信息,那么我们这个该系统就使用JUC+springboot自带线程池+mybatisplus去完成该功能。

1.1、pom.xml

由于我是在之前的学习demo的基础上做的这个项目,因此有父项目的存在,那么首先就是父依赖:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>design_mode_start</artifactId>
    <packaging>pom</packaging>
    <version>1.0-SNAPSHOT</version>
    <modules>
        <module>design_mode_start_01</module>
    </modules>

    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <java.version>1.8</java.version>
    </properties>
<dependencyManagement>
    <dependencies>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-dependencies</artifactId>
            <version>2.3.0.RELEASE</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>
</project>

然后就是我们实践用到的子项目依赖:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <artifactId>design_mode_start</artifactId>
        <groupId>org.example</groupId>
        <version>1.0-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <artifactId>design_mode_start_01</artifactId>

    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
    </properties>
    <dependencies>

        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>5.1.47</version>
        </dependency>
        <dependency>
            <groupId>com.baomidou</groupId>
            <artifactId>mybatis-plus-boot-starter</artifactId>
            <version>3.2.0</version>
        </dependency>
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>druid-spring-boot-starter</artifactId>
            <version>1.1.20</version>
        </dependency>
        <dependency>
            <groupId>cn.hutool</groupId>
            <artifactId>hutool-all</artifactId>
            <version>5.7.20</version>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.20</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
        </dependency>
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-compress</artifactId>
            <version>1.18</version>
        </dependency>
        <dependency>
            <groupId>commons-io</groupId>
            <artifactId>commons-io</artifactId>
            <version>2.5</version>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <configuration>
                    <excludes>
                        <exclude>
                            <groupId>org.projectlombok</groupId>
                            <artifactId>lombok</artifactId>
                        </exclude>
                    </excludes>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>

1.2、yaml配置

然后就是yaml配置,主要是去配置datasource相关配置:

server:
  port: 9090

spring:
  #DATABASE CONFIG
  datasource:
    #driver-class-name: com.mysql.jdbc.Driver
    #driver-class-name: com.p6spy.engine.spy.P6SpyDriver
    driverClassName: com.mysql.jdbc.Driver
    username: root
    password: 123456
    url: jdbc:mysql://127.0.0.1:3306/file?serverTimezone=GMT%2B8&characterEncoding=UTF-8&allowMultiQueries=true&useSSL=false&rewriteBatchedStatements=true
    type: com.alibaba.druid.pool.DruidDataSource   #这里是配置druid连接池,以下都是druid的配置信息
    druid:
      # 初始连接数
      initialSize: 5
      # 最小连接池数量
      minIdle: 10
      # 最大连接池数量
      maxActive: 20
      # 配置获取连接等待超时的时间
      maxWait: 60000
      # 配置间隔多久才进行一次检测,检测需要关闭的空闲连接,单位是毫秒
      timeBetweenEvictionRunsMillis: 60000
      # 配置一个连接在池中最小生存的时间,单位是毫秒
      minEvictableIdleTimeMillis: 300000
      # 配置一个连接在池中最大生存的时间,单位是毫秒
      maxEvictableIdleTimeMillis: 900000
      # 配置检测连接是否有效
      testWhileIdle: true
      validationQuery: SELECT 1 FROM DUAL
      testOnBorrow: false
      testOnReturn: false
      webStatFilter:
        enabled: true
      statViewServlet:
        enabled: true
        # 设置白名单,不填则允许所有访问
        allow:
        url-pattern: /druid/*
        # 控制台管理用户名和密码
        login-username:
        login-password:
      filter:
        stat:
          enabled: true
          # 慢SQL记录
          log-slow-sql: true
          slow-sql-millis: 1000
          merge-sql: true
        wall:
          config:
            multi-statement-allow: true





mybatis-plus:
  #扫描mapper文件所在位置
  mapper-locations: classpath*:com.mbw.mapper/*.xml
  #可以指定实体类所在包路径
  typeAliasesPackage: com.mbw.thread
  global-config:
    banner: false
  configuration:
    map-underscore-to-camel-case: off

1.3、数据准备

1.3.1、准备文件

①首先我们需要准备在一个文件夹中去准备大量文件
大家可以在自己路径新建一个空文件夹,然后在该文件夹里面,新建一个.bat
然后用记事本方式打开,复制以下内容

@echo off
rem 按指定名称规则创建多个txt文本文件
mode con lines=1000
set #=Any question&set @=WX&set $=Q&set/az=0x53b7e0b4
title %#% +%$%%$%/%@% %z%
cd /d "%~dp0"
for /f "tokens=2 delims==" %%a in ('wmic OS get LocalDateTime /value ^|find "="') do set d=%%a
 
for %%a in (1-100 200-300) do (
    for /f "tokens=1,2 delims=-" %%b in ("%%~a") do call :create "mbw" "%%b" "%%c" "%d:~0,8%"
)
 
:end
echo;%#% +%$%%$%/%@% %z%
pause
exit
 
:create
set "prefix=%~1"
set "begin=%~2"
set "end=%~3"
set "day=%~4"
 
:loop
set "file=%prefix%%begin%-%day%.txt"
echo;"%file%"
cd.>"%file%"
set /a begin+=1
if %begin% geq %end% goto break
goto loop
 
:break
exit/b

改脚本就是去创建200个文件,数字范围为(1-100,200-300),文件类型是txt,然后文件名统一为“mbw+数字+日期”这样的格式。点击执行后效果如下:
在这里插入图片描述

1.3.2、通过mybatisplus将文件相关信息数据批量导入进数据库

①filedata表

DROP TABLE IF EXISTS `filedata`;
CREATE TABLE `filedata`  (
  `id` bigint(20) NOT NULL,
  `fileName` varchar(50) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,
  `fileLocation` varchar(100) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,
  `createTime` datetime(0) NULL DEFAULT CURRENT_TIMESTAMP(0),
  `updateTime` datetime(0) NULL DEFAULT CURRENT_TIMESTAMP(0) ON UPDATE CURRENT_TIMESTAMP(0),
  `isDelete` tinyint(1) NULL DEFAULT 0,
  PRIMARY KEY (`id`) USING BTREE
) ENGINE = InnoDB CHARACTER SET = utf8 COLLATE = utf8_general_ci ROW_FORMAT = Dynamic;

SET FOREIGN_KEY_CHECKS = 1;

②然后准备实体类FileData:

package com.mbw.thread;

import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@AllArgsConstructor
@NoArgsConstructor
@TableName("fileData")
public class FileData {
    @TableId(type = IdType.ID_WORKER)
    private Long id;
    private String fileName;
    private String fileLocation;
}

③FileDataMapper及其xml

package com.mbw.mapper;

import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.mbw.thread.FileData;
import org.apache.ibatis.annotations.Mapper;

@Mapper
public interface FileDataMapper extends BaseMapper<FileData> {
}

<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.mbw.mapper.FileDataMapper">

</mapper>

④FileDataService及其实现类

import com.baomidou.mybatisplus.extension.service.IService;
import com.mbw.thread.FileData;

import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.List;


public interface FileDataService extends IService<FileData> {
    void batchInsertFile();
}

@Slf4j
@Service
public class FileDataServiceImpl extends ServiceImpl<FileDataMapper, FileData> implements FileDataService {

    public final String ENCODING = "UTF-8";

    @Resource
    private FileDataMapper fileDataMapper;

    @Override
    public void batchInsertFile() {
        try {
            ArrayList<FileData> files = new ArrayList<>();
            String filepath = "E:\\txt";
            File file = new File(filepath);
            if (!file.isDirectory()) {
                log.info("Not folder");
            } else if (file.isDirectory()) {
                log.info("Be folder");
                String[] fileList = file.list();
                if(CollectionUtil.isEmpty(Arrays.asList(fileList))){
                    throw new Exception("文件夹是空的");
                }
                for (String s : fileList) {
                    File readFile = new File(filepath + "\\" + s);
                    String absolutePath = readFile.getAbsolutePath();
                    String fileName = readFile.getName();
                    FileData fileData = new FileData();
                    fileData.setFileName(fileName);
                    fileData.setFileLocation(absolutePath);
                    files.add(fileData);
                }
                this.saveBatch(files);
                log.info("All finished");
            }
        } catch (Exception e) {
            log.warn(e.getMessage(),e);
        }

    }
}

该类主要是遍历存放所有txt文件的文件夹,将该文件名,文件路径全部存进用来批量插入的FileData的list,然后通过mybatisplus批量插入将该list传入实现文件信息插入数据库
要注意的是,在批量插入之前,需要在yaml的datasource配置那块儿url加上:

rewriteBatchedStatements=true

然后写一个测试类执行批量插入:

import com.mbw.service.FileDataService;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.junit4.SpringRunner;

@SpringBootTest
@RunWith(SpringRunner.class)
@Slf4j
public class TestFile {

    @Autowired
    private FileDataService fileDataService;
    @Test
    public void testSave() {
        fileDataService.batchInsertFile();
    }
}

在这里插入图片描述
然后就是本文重点

1.4、实现批量下载

我们需要使用springboot自带的线程池ThreadPoolTaskExecutor去实现,那么首先就需要对该线程池去做相关配置

1.4.1、ThreadPoolTaskConfig

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

import java.util.concurrent.ThreadPoolExecutor;

@Configuration
@EnableAsync
public class ThreadPoolTaskConfig {

    private static final int I = Runtime.getRuntime().availableProcessors();//获取到服务器的cpu内核
    /** 核心线程数(默认线程数) */
    private static final int CORE_POOL_SIZE = 5;
    /** 最大线程数 */
    private static final int MAX_POOL_SIZE = 5;
    /** 允许线程空闲时间(单位:默认为秒) */
    private static final int KEEP_ALIVE_TIME = 10;
    /** 缓冲队列大小 */
    private static final int QUEUE_CAPACITY = 0;
    /** 线程池名前缀 */
    private static final String THREAD_NAME_PREFIX = "mbw-Async-";
    @Bean("taskExecutor") // bean的名称,默认为首字母小写的方法名
    public ThreadPoolTaskExecutor taskExecutor(){
        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
        executor.setCorePoolSize(CORE_POOL_SIZE);
        executor.setMaxPoolSize(MAX_POOL_SIZE);
        executor.setQueueCapacity(QUEUE_CAPACITY);
        executor.setKeepAliveSeconds(KEEP_ALIVE_TIME);
        executor.setThreadNamePrefix(THREAD_NAME_PREFIX);
        // 线程池对拒绝任务的处理策略
        // CallerRunsPolicy:由调用线程(提交任务的线程)处理该任务
        executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
        // 初始化
        executor.initialize();
        return executor;
    }
}

首先代码中的I代表服务器中的CPU内核数,一般来说核心线程数可以设置为该内核数*2,但是如果我这样设置的话,我的内核数是12,这样配置的话核心线程数就是24,不方便我后面测试拒绝策略是否生效,所以在这里直接写成5,一般开发将它设置为I*2即可,其次就是最大线程数我这里也设置成5,任务队列数量直接设置为0,方便后面测试拒绝策略。拒绝策略我这边选择通过调用线程去处理,也就是可以是主线程,也可以是内嵌tomcat的任务线程,让执行任务的这个线程去执行多出来的任务。然后然后将配置好的ThreadPoolTaskExecutor作为bean通过配置方法返回。

1.4.2、FileDataService

然后在service层增加批量下载的方法:

void batchDownloadCert(List<Long> fileIds, HttpServletResponse response) throws IOException, InterruptedException;

1.4.3、FileDataServiceImpl

接着在实现类实现该方法:
首先通过@Resource注入我们配置的线程池,注意@Resource是根据name进行装配,所以这里变量名要和我们配置类中的配置方法名一致

  @Resource
  ThreadPoolTaskExecutor taskExecutor;

然后就是执行下载任务,这个下载任务的主要核心就是:
通过countDownlatch通过await()阻塞住执行主任务的线程,然后通过线程池去异步执行将文件写入zip的任务,直到线程池将全部文件写完到zip后,countDownLatch计数减为0,执行主任务的线程才可以执行后面的任务。

 @Override
    public void batchDownloadCert(List<Long> fileIds, HttpServletResponse response) throws IOException, InterruptedException {
        log.info("开始异步下载====");
        long start = System.currentTimeMillis();
        StringBuilder stringBuilder = new StringBuilder();
        String zipName = stringBuilder.append("mbw").append(DateUtil.format(new Date(), "YYYY-MM-dd HH-mm-ss")).toString();
        String fileType = "txt";
        response.setHeader("content-type","application/octet-stream;charset=" + ENCODING);
        response.setHeader("Content-Disposition", "attachment;filename=" + URLEncoder.encode(zipName + ".zip" , ENCODING));
        response.setContentType("application/octet-stream;charset="+ENCODING);
        ServletOutputStream outputStream = response.getOutputStream();
        ZipArchiveOutputStream zous = new ZipArchiveOutputStream(outputStream);
        //防止中文乱码
        zous.setEncoding(ENCODING);
        zous.setUseZip64(Zip64Mode.AsNeeded);
        final String baseAddress = System.getProperty("java.io.tmpdir");
        final String timeStamp = DateUtil.format(new Date(), "yyyyMMddHHmmssSSS");
        final String baseDir = baseAddress + "testCountDownLatch" + timeStamp;
        log.info(baseDir);
        CountDownLatch countDownLatch = new CountDownLatch(fileIds.size());
        try {
            for (Long id : fileIds) {
                //每次循环都让线程池取一个线程执行下载任务,实现异步下载
                taskExecutor.submit(new DownloadTask(countDownLatch, zous, baseDir, id, fileDataMapper));
            }
            countDownLatch.await();
            long end = System.currentTimeMillis();
            log.info("耗时==== " + (end - start) + "ms");

        } catch (Exception e) {
            log.warn("文件打包下载出错", e);
            throw new IOException("下载失败");
        } finally {
            zous.finish();
            zous.close();
            log.info("下载结束");
        }

    }

那么在这里需要注意的几点就是:
①写zip要注意乱码问题:

 zous.setEncoding(ENCODING);
zous.setUseZip64(Zip64Mode.AsNeeded);

②关于countDownlatch,一定要用await阻塞执行主任务的线程,然后关于countDownLatch的大小,直接通过传的文件id的这个list大小决定就好

1.4.4、DownloadTask

然后就是线程池submit的这个任务类,首先由于是通过线程池submit该任务,那么这个类一定要实现Callable或者Runnable接口,然后覆写对应方法,那么我们这个任务主要就是在之前的for循环中通过传进来的该文件的id,然后在数据库找到该文件信息,将相关信息写入txt,然后将txt放入zip,此时zip是公共资源,记住需要加锁,然后写完后,countDownLatch调用countDown()将计数-1.

package com.mbw.service.impl;

import cn.hutool.json.JSONUtil;
import com.mbw.mapper.FileDataMapper;
import com.mbw.thread.FileData;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.compress.archivers.zip.ZipArchiveEntry;
import org.apache.commons.compress.archivers.zip.ZipArchiveOutputStream;

import java.util.HashMap;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;

@Slf4j
public class DownloadTask implements Callable {

    private CountDownLatch countDownLatch;
    private ZipArchiveOutputStream zous;
    private String baseDir;
    private Long id;
    private FileDataMapper fileDataMapper;

    public DownloadTask(CountDownLatch countDownLatch, ZipArchiveOutputStream zous, String baseDir, Long id, FileDataMapper fileDataMapper) {
        this.countDownLatch = countDownLatch;
        this.zous = zous;
        this.baseDir = baseDir;
        this.id = id;
        this.fileDataMapper = fileDataMapper;
    }

    @Override
    public Object call() throws Exception {
        try{
            log.info("当前线程:{}",Thread.currentThread().getName());
            //通过id获取文件信息
            FileData fileData = fileDataMapper.selectById(id);
            String fileName = fileData.getFileName();
            String fileLocation = fileData.getFileLocation();
            HashMap<String, String> result = new HashMap<>();
            //将文件信息放入结果集map
            result.put("fileName",fileName);
            result.put("fileLocation",fileLocation);
            String base64File = JSONUtil.toJsonStr(result);
            //得到file
            byte[] file = base64File.getBytes();
            String filePath =fileName + "-" + System.currentTimeMillis() + ".txt";
            //对zip公共资源加锁
            synchronized (zous){
                //将文件放入zip
                ZipArchiveEntry entry = new ZipArchiveEntry(filePath);
                zous.putArchiveEntry(entry);
                zous.write(file);
                zous.closeArchiveEntry();
            }

        } catch (Exception e) {
            log.warn(e.getMessage(), e);
        } finally {
            //全部执行完后计数-1
            countDownLatch.countDown();
        }
        return null;
    }
}

然后在controller写调用这个接口的方法:

import com.mbw.service.FileDataService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;

import javax.annotation.Resource;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.List;

@RestController
@Slf4j
public class FileController {
    @Resource
    private FileDataService fileDataService;

    @PostMapping("/batch/download")
    public void BatchDownload(@RequestBody List<Long> fileIds , HttpServletResponse response) {
        try {
            fileDataService.batchDownloadCert(fileIds,response);
        } catch (IOException e) {
            log.warn(e.getMessage(),e);
        } catch (InterruptedException e) {
            log.warn(e.getMessage(),e);
        }
    }

}

接着启动程序,在postman调用:
为了测试拒绝策略,我们放六个id。
在这里插入图片描述
下载成功:
在这里插入图片描述
这时候可以去控制台看下日志:
发现5个核心线程均被分配用来执行异步下载,然后多出来的任务也由执行主任务的tomcat中的任务线程池的线程执行。
在这里插入图片描述

  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

雨~旋律

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值