多线程下载如何优雅地控制任务起停(改进版)

前言:

        小编之前写过一个Java版本的多线程HTTP下载服务器,实现了多线程下载、断点续传、限速、限流等功能,但控制任务起停方面的代码写得不够好,而且代码耦合度太高不方便扩展。历经半年修炼,我又重新把这个项目写了一个版本,新版本的优化有这些:

  • 代码通过TDD进行了解耦,每个模块可以单独测试,单测覆盖率达到了75%以上,保证了代码的准确性和可扩展性。
  • 重写下载模块和Redis模块所有接口,使用了自定义线程池异步操作。
  • 在控制任务暂停方面,原来是轮询Redis的任务状态,当检测到status = "canceled" 时来关闭线程从而中断下载;新版本中向线程池提交任务会返回一个Future对象,一个待下载文件会有若干个分片,因此可以得到Future List,将Future列表和任务ID存放在ConcurrentHashMap中,当调用pause接口时会执行Future的cancel方法从而中断线程取消下载。
  • 前端框架改用Vue3 + Vite + TailwindUI

新版本项目地址(欢迎star):  https://github.com/Lemon001017/HTTP-download-server-Java

直接上代码:

多线程下载:

@Service
public class TaskServiceImpl implements TaskService {
    @Autowired
    private SettingsMapper settingsMapper;

    @Autowired
    private TaskMapper taskMapper;

    @Autowired
    private RedisService redisService;

    @Autowired
    private SseService sseService;

    private static final Logger log = LoggerFactory.getLogger(TaskServiceImpl.class);

    private static final Object lock = new Object();

    private final ConcurrentHashMap<String, List<Future<?>>> chunkFutures = new ConcurrentHashMap<>();

    private final ThreadPoolExecutor downloadExecutor = new ThreadPoolExecutor(4, 8,
            60L, TimeUnit.SECONDS, new LinkedBlockingQueue<>());

    @Override
    public Result<String> submit(String url) {
        Result<String> result = new Result<>();
        String taskId = UUIDUtils.generateId();
        result.setCode(Constants.HTTP_STATUS_OK);
        result.setData(taskId);

        // Asynchronous processing download
        CompletableFuture.runAsync(() -> {
            try {
                Task task = initOneTask(taskId, url);
                redisService.initializeScoreboard(taskId, task.getChunkNum());
                processDownload(task);
            } catch (IOException | URISyntaxException e) {
                log.error("Submit task error id:{} err:{}", taskId, e.getMessage(), e);
                result.setCode(Constants.HTTP_STATUS_SERVER_ERROR);
                result.setMessage(Constants.ERR_SUBMIT_TASK);
            }
        });
        return result;
    }

    @SuppressWarnings("UnstableApiUsage")
    private void processDownload(Task task) throws IOException, URISyntaxException {
        long startTime = System.currentTimeMillis();
        task.setStatus(Constants.TASK_STATUS_DOWNLOADING);
        File outputFile = new File(task.getSavePath());
        List<Future<?>> futures = new ArrayList<>();
        List<Integer> scoreboard = redisService.getScoreboard(task.getId());
        RateLimiter limiter = RateLimiter.create(settingsMapper.selectOne(null).getMaxDownloadSpeed() * 1000 * 1000);

        // Submit the fragment for download
        for (int i = 0; i < task.getChunkNum(); i++) {
            int start = i * task.getChunkSize();
            int end = (int) Math.min(task.getSize(), start + task.getChunkSize()) - 1;
            int chunkIndex = i;
            if (scoreboard.contains(chunkIndex)) {
                Future<?> future = downloadExecutor.submit(() -> downloadChunk(task, start, end, outputFile, startTime, chunkIndex, limiter));
                futures.add(future);
            }
        }
        chunkFutures.put(task.getId(), futures);
    }

    @SuppressWarnings("UnstableApiUsage")
    private void downloadChunk(Task task, int start, int end, File file, long startTime, int index, RateLimiter limiter) {
        try {
            HttpURLConnection conn = getConn(task.getUrl());
            conn.setRequestProperty("Range", "bytes=" + start + "-" + end);
            BufferedInputStream in = new BufferedInputStream(conn.getInputStream());
            RandomAccessFile raf = new RandomAccessFile(file, "rw");

            synchronized (lock) {
                raf.seek(start);
            }

            byte[] buffer = new byte[4096];
            int bytesRead;

            long lastMessageTime = System.currentTimeMillis();

            while ((bytesRead = in.read(buffer)) != -1) {
                // Check whether the current thread is interrupted
                if (Thread.currentThread().isInterrupted()) {
                    log.info("Download paused for task id:{} threadId:{}", task.getId(), Thread.currentThread().threadId());
                    in.close();
                    raf.close();
                    conn.disconnect();
                    return;
                }
                limiter.acquire(bytesRead);
                raf.write(buffer, 0, bytesRead);
                synchronized (lock) {
                    task.setTotalDownloaded(task.getTotalDownloaded() + bytesRead);
                    if (System.currentTimeMillis() - lastMessageTime >= Constants.MessageInterval) {
                        // Calculate download data
                        long elapsedTime = System.currentTimeMillis() - startTime;
                        double speed = Math.round((task.getTotalDownloaded() / (elapsedTime / 1000.0) / 1024 / 1024) * 100.0) / 100.0;
                        double progress = Math.round((task.getTotalDownloaded() * 1.0 * 100 / task.getSize()) * 100.0) / 100.0;
                        double remainingTime = Math.round((((task.getSize() - task.getTotalDownloaded()) / 1024.0 / 1024.0) / speed) * 100.0) / 100.0;

                        task.setSpeed(speed);
                        task.setProgress(progress);
                        task.setRemainingTime(remainingTime);

                        sseService.send(task.getId(), task);
                        taskMapper.updateById(task);
                        lastMessageTime = System.currentTimeMillis();
                    }
                }
            }

            synchronized (lock) {
                redisService.updateScoreboard(task.getId(), index);
            }

            if (task.getTotalDownloaded() == task.getSize() || redisService.getScoreboard(task.getId()).isEmpty()) {
                log.info("Download complete id:{} url:{}", task.getId(), task.getUrl());
                redisService.deleteScoreboard(task.getId());
                task.setProgress(100);
                task.setRemainingTime(0);
                task.setStatus(Constants.TASK_STATUS_DOWNLOADED);
                taskMapper.updateById(task);
                sseService.send(task.getId(), task);
            }

            in.close();
            raf.close();
            conn.disconnect();
        } catch (IOException | URISyntaxException e) {
            log.error("Download failed id:{} err:{}", task.getId(), e.getMessage());
            task.setStatus(Constants.TASK_STATUS_FAILED);
            taskMapper.updateById(task);
        }
    }

    private int getChunkSize(long fileSize) {
        int chunkSize;
        if (fileSize < 10 * 1024 * 1024) {
            chunkSize = Constants.MIN_CHUNK_SIZE;
        } else if (fileSize < 100 * 1024 * 1024) {
            chunkSize = Constants.MID_CHUNK_SIZE;
        } else {
            chunkSize = Constants.MAX_CHUNK_SIZE;
        }
        return chunkSize;
    }

    private Task initOneTask(String id, String urlString) throws IOException, URISyntaxException {
        String downloadPath = settingsMapper.selectOne(null).getDownloadPath();

        HttpURLConnection conn = getConn(urlString);

        long fileSize = conn.getContentLength();
        int chunkSize = getChunkSize(fileSize);
        int chunkNums = (int) ((fileSize + chunkSize - 1) / chunkSize);

        String fileName = extractFileName(conn, urlString);
        String ext = fileName.substring(fileName.lastIndexOf("."));
        String outputPath = downloadPath + "/" + fileName;

        log.info("Init a task, id:{} fileSize:{} savePath:{} chunkSize:{} chunkNums:{}", id, fileSize, outputPath, chunkSize, chunkNums);

        Task task = new Task(
                id,
                fileName,
                ext,
                fileSize,
                urlString,
                outputPath,
                Constants.TASK_STATUS_PENDING,
                Constants.DEFAULT_THREADS,
                chunkNums,
                chunkSize,
                LocalDateTime.now()
        );

        taskMapper.insert(task);
        return task;
    }

    private HttpURLConnection getConn(String urlStr) throws IOException, URISyntaxException {
        URI uri = new URI(urlStr);
        URL url = uri.toURL();
        return (HttpURLConnection) url.openConnection();
    }

    private String extractFileName(HttpURLConnection connection, String urlString) {
        String fileName = null;

        Map<String, List<String>> headers = connection.getHeaderFields();
        List<String> contentDisposition = headers.get("Content-Disposition");

        if (contentDisposition != null && !contentDisposition.isEmpty()) {
            String disposition = contentDisposition.getFirst();
            int index = disposition.indexOf("filename=");
            if (index > 0) {
                fileName = disposition.substring(index + 10, disposition.length() - 1);
            }
        }

        if (fileName == null) {
            fileName = urlString.substring(urlString.lastIndexOf("/") + 1);
        }

        return fileName;
    }

 暂停任务下载:

@Override
    public Result<List<String>> pause(List<String> ids) {
        Result<List<String>> result = new Result<>();
        List<Task> tasks = taskMapper.selectBatchIds(ids);
        for (Task task : tasks) {
            if (task.getStatus().equals(Constants.TASK_STATUS_DOWNLOADING)) {
                task.setStatus(Constants.TASK_STATUS_CANCELED);
                taskMapper.updateById(task);

                List<Future<?>> futures = chunkFutures.get(task.getId());
                if (futures != null) {
                    for (Future<?> future : futures) {
                        future.cancel(true);
                    }
                } else {
                    log.error("The task futures is null id:{}", task.getId());
                }
            } else {
                log.error("The task status is not downloading id:{} status:{}", task.getId(), task.getStatus());
                result.setCode(Constants.HTTP_STATUS_BAD_REQUEST);
                result.setMessage("Task status is not downloading");
                return result;
            }
        }
        result.setData(ids);
        result.setCode(Constants.HTTP_STATUS_OK);
        return result;
    }

总结:

        这个项目我前前后后一共写了三遍(Go一个版本,Java两个版本),4月份的时候第一次写完,现在回过头来看自己之前写的代码简直一坨hhh,所以最近花了10天重构了一下。做完这个项目可以学习到多线程的相关知识、如何使用Redis的Pub/Sub来发布和订阅消息、Java中的锁如何使用、如何实现一个限流算法以及Java中如何使用SSE来实现客户端与服务端的通信等。

        如果有小伙伴对这个项目感兴趣的欢迎私聊我,也可以互相学习互相进步!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值