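Implementing Multithreaded Downloads in Java

Multithreaded download

Put simply, multithreaded downloading splits the file to be downloaded into several blocks and assigns each block to a different thread.

Key techniques

  • RandomAccessFile
    The Java class for reading and writing a file at arbitrary positions
  • HTTP Range request header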
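
To illustrate why RandomAccessFile suits this job, here is a minimal sketch, assuming a scratch temp file and a hypothetical class name `RandomAccessDemo`. It pre-sizes a file and then writes its two halves out of order, which is what the download threads do with their blocks.

```java
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.charset.StandardCharsets;

class RandomAccessDemo {

    // Writes the two halves of a 10-byte message out of order, then reads the whole file back.
    static String writeOutOfOrder() throws IOException {
        File file = File.createTempFile("raf-demo", ".bin");   // scratch file, for illustration only
        try {
            try (RandomAccessFile raf = new RandomAccessFile(file, "rw")) {
                raf.setLength(10);      // pre-allocate the final size
                raf.seek(5);            // the second half is written first
                raf.write("WORLD".getBytes(StandardCharsets.US_ASCII));
                raf.seek(0);            // then the first half
                raf.write("HELLO".getBytes(StandardCharsets.US_ASCII));
            }
            byte[] all = new byte[10];
            try (RandomAccessFile raf = new RandomAccessFile(file, "r")) {
                raf.readFully(all);
            }
            return new String(all, StandardCharsets.US_ASCII);
        } finally {
            file.delete();
        }
    }

    public static void main(String[] args) throws IOException {
        System.out.println(writeOutOfOrder());   // prints HELLOWORLD
    }
}
```

The order of the writes does not matter, because each write lands at the offset chosen with seek().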

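Approach

1. Split the file into blocks. Block size (blockSize) = (file size + thread count - 1) / thread count, that is, the file size divided by the thread count, rounded up.
2. Determine the start and end position of the part of the file that each thread downloads.
Suppose the threads are numbered 0, 1, 2, 3. Then
the first thread downloads bytes 0*blockSize to (0+1)*blockSize - 1,
the second thread downloads bytes 1*blockSize to (1+1)*blockSize - 1,
and in general thread i downloads bytes i*blockSize to (i+1)*blockSize - 1.
So for the thread numbered id, the start position is start = id*blockSize
and the end position is end = (id+1)*blockSize - 1. Because blockSize is rounded up, the end position of the last thread can lie past the last byte of the file (file size - 1).
3. Set the HTTP request header: conn.setRequestProperty("Range", "bytes=" + start + "-" + end);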
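
The arithmetic in the steps above can be sketched as follows. The class and method names (`RangePlanner`, `plan`) are hypothetical. Capping the last range at file size - 1 and skipping threads whose start lies past the end of the file are assumptions added here; the listing in this article does neither.

```java
import java.util.ArrayList;
import java.util.List;

class RangePlanner {

    // Returns one Range header value per thread that actually has bytes to fetch.
    static List<String> plan(long fileSize, int threadCount) {
        long blockSize = (fileSize + threadCount - 1) / threadCount;   // ceiling division
        List<String> ranges = new ArrayList<String>();
        for (int id = 0; id < threadCount; id++) {
            long start = id * blockSize;
            if (start >= fileSize) {
                break;                                   // nothing left for this thread or later ones
            }
            long end = Math.min(start + blockSize, fileSize) - 1;   // inclusive, capped at the last byte
            ranges.add("bytes=" + start + "-" + end);
        }
        return ranges;
    }

    public static void main(String[] args) {
        System.out.println(plan(10, 4));   // prints [bytes=0-2, bytes=3-5, bytes=6-8, bytes=9-9]
    }
}
```

With a 10-byte file and 8 threads, blockSize is 2, so only five ranges are produced and the remaining three threads would have nothing to download.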

Implementation

A simple multithreaded download program in Java looks like this:

package com.ricky.java.test.download;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.RandomAccessFile;
import java.net.HttpURLConnection;
import java.net.URL;

public class Downloader {
    private URL url;    // source URL
    private File file;  // local file
    private static final int THREAD_AMOUNT = 8;                 // number of threads
    private static final String DOWNLOAD_DIR_PATH = "D:/Download";      // download directory
    private int threadLen;                                      // number of bytes each thread downloads

    public Downloader(String address, String filename) throws IOException {     // the download URL is passed in through the constructor
        url = new URL(address);
        File dir = new File(DOWNLOAD_DIR_PATH);
        if(!dir.exists()){
            dir.mkdirs();
        }
        file = new File(dir, filename);
    }

    public void download() throws IOException {
        HttpURLConnection conn = (HttpURLConnection) url.openConnection();
        conn.setConnectTimeout(5000);

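        int totalLen = conn.getContentLength();                             // get the file length (-1 if unknown or larger than Integer.MAX_VALUE)
        conn.disconnect();                                                  // only the response headers were needed
        if (totalLen < 1)
            throw new IOException("Cannot determine file length: " + url);
        threadLen = (totalLen + THREAD_AMOUNT - 1) / THREAD_AMOUNT;         // number of bytes each thread downloads

        System.out.println("totalLen="+totalLen+",threadLen:"+threadLen);

        RandomAccessFile raf = new RandomAccessFile(file, "rws");           // create a local file with the same size as the file on the server
        raf.setLength(totalLen);                                            // set the file size
        raf.close();

        for (int i = 0; i < THREAD_AMOUNT; i++)                             // start THREAD_AMOUNT threads, each downloading part of the data into the local file
            new DownloadThread(i).start();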
    }

    private class DownloadThread extends Thread {
        private int id;
        public DownloadThread(int id) {
            this.id = id;
        }
        public void run() {
            int start = id * threadLen;                     // start position
            int end = id * threadLen + threadLen - 1;       // end position; for the last thread it may lie past the end of the file, in which case the server returns the remaining bytes
            System.out.println("Thread " + id + ": " + start + "-" + end);

            try {
                HttpURLConnection conn = (HttpURLConnection) url.openConnection();
                conn.setConnectTimeout(5000);
                conn.setRequestProperty("Range", "bytes=" + start + "-" + end);     // set the byte range this thread downloads

                // try-with-resources closes the stream and the file even if an exception is thrown
                try (InputStream in = conn.getInputStream();
                     RandomAccessFile raf = new RandomAccessFile(file, "rws")) {
                    raf.seek(start);            // position at which to save the data

                    byte[] buffer = new byte[1024];
                    int len;
                    while ((len = in.read(buffer)) != -1)
                        raf.write(buffer, 0, len);
                }

                System.out.println("Thread " + id + " finished downloading");
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    public static void main(String[] args) throws IOException {

        String address = "http://dldir1.qq.com/qqfile/qq/QQ7.9/16621/QQ7.9.exe";
        new Downloader(address, "QQ7.9.exe").download();

//      String address = "http://api.t.dianping.com/n/api.xml?cityId=2";
//      new Downloader(address, "2.xml").download();
    }
}
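
Note that download() in the listing above returns as soon as the threads have been started, so the caller cannot tell when the file is complete. One way to wait is to keep the Thread objects and join them. The following is a minimal sketch under stated assumptions: the class `JoinDemo` is hypothetical, and each thread fills its slice of an in-memory byte array instead of performing an HTTP request, so the example runs without a network.

```java
class JoinDemo {

    // Starts threadCount threads, each filling its own slice of the array, and waits for all of them.
    static byte[] fillInParallel(final int totalLen, int threadCount) throws InterruptedException {
        final byte[] target = new byte[totalLen];
        final int blockSize = (totalLen + threadCount - 1) / threadCount;   // same ceiling division as above
        Thread[] workers = new Thread[threadCount];
        for (int i = 0; i < threadCount; i++) {
            final int id = i;
            workers[i] = new Thread(new Runnable() {
                public void run() {
                    int start = id * blockSize;
                    int end = Math.min(start + blockSize, totalLen);        // exclusive upper bound
                    for (int p = start; p < end; p++) {
                        target[p] = (byte) (id + 1);                        // stands in for downloaded data
                    }
                }
            });
            workers[i].start();
        }
        for (Thread t : workers) {
            t.join();                // blocks until that thread has finished
        }
        return target;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(java.util.Arrays.toString(fillInParallel(10, 4)));
        // prints [1, 1, 1, 2, 2, 2, 3, 3, 3, 4]
    }
}
```

The encapsulated version later in this article solves the same problem with a thread pool and a CountDownLatch.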

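Encapsulating multithreaded download

File download is a commonly needed module, so it is worth encapsulating for easy reuse. The technologies used are: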

  • JDK 1.7
  • Eclipse Juno
  • Maven 3
  • HttpClient 4.3.6

The project directory structure is as follows:
[Figure: project directory structure]

com.ricky.common.java.download.FileDownloader

package com.ricky.common.java.download;

import org.apache.log4j.Logger;

import com.ricky.common.java.download.config.FileDownloaderConfiguration;

/**
 * Multithreaded file download in Java
 * @author Ricky Fung
 *
 */
public class FileDownloader {

    protected Logger mLogger = Logger.getLogger("devLog");

    private volatile static FileDownloader fileDownloader;

    private FileDownloaderEngine downloaderEngine;

    private FileDownloaderConfiguration configuration;

    public static FileDownloader getInstance(){

        if(fileDownloader==null){
            synchronized (FileDownloader.class) {
                if(fileDownloader==null){
                    fileDownloader = new FileDownloader();
                }
            }
        }

        return fileDownloader;
    }

    protected FileDownloader(){
    }

    public synchronized void init(FileDownloaderConfiguration configuration){
        if (configuration == null) {
            throw new IllegalArgumentException("FileDownloader configuration can not be initialized with null");
        }
        if (this.configuration == null) {
            mLogger.info("init FileDownloader");
            downloaderEngine = new FileDownloaderEngine(configuration);
            this.configuration = configuration;
        }else{
            mLogger.warn("Try to initialize FileDownloader which had already been initialized before.");
        }

    }

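    public boolean download(String url, String filename){

        if (downloaderEngine == null) {
            // fail with a clear message instead of a NullPointerException
            throw new IllegalStateException("FileDownloader is not initialized or has been destroyed");
        }
        return downloaderEngine.download(url, filename);
    }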

    public boolean isInited() {
        return configuration != null;
    }

    public void destroy() {
        if(downloaderEngine!=null){
            downloaderEngine.close();
            downloaderEngine = null;
        }
    }
}

com.ricky.common.java.download.config.FileDownloaderConfiguration

package com.ricky.common.java.download.config;

import java.io.File;

public class FileDownloaderConfiguration {
    private final int connectTimeout;
    private final int socketTimeout;
    private final int maxRetryCount;
    private final int coreThreadNum;  
    private final long requestBytesSize;
    private final File downloadDestinationDir;  

    private FileDownloaderConfiguration(Builder builder) {  
        this.connectTimeout = builder.connectTimeout;  
        this.socketTimeout = builder.socketTimeout;  
        this.maxRetryCount = builder.maxRetryCount;  
        this.coreThreadNum = builder.coreThreadNum;  
        this.requestBytesSize = builder.requestBytesSize;
        this.downloadDestinationDir = builder.downloadDestinationDir;  
    }

    public int getConnectTimeout() {
        return connectTimeout;
    }
    public int getSocketTimeout() {
        return socketTimeout;
    }
    public int getMaxRetryCount() {
        return maxRetryCount;
    }
    public int getCoreThreadNum() {
        return coreThreadNum;
    }
    public long getRequestBytesSize() {
        return requestBytesSize;
    }
    public File getDownloadDestinationDir() {
        return downloadDestinationDir;
    }

    public static FileDownloaderConfiguration.Builder custom() {  
        return new Builder();  
    }

    public static class Builder {  
        private int connectTimeout;  
        private int socketTimeout;  
        private int maxRetryCount;  
        private int coreThreadNum;  
        private long requestBytesSize;  
        private File downloadDestinationDir;  

        public Builder connectTimeout(int connectTimeout) {  
            this.connectTimeout = connectTimeout;  
            return this;  
        }  
        public Builder socketTimeout(int socketTimeout) {  
            this.socketTimeout = socketTimeout;  
            return this;
        }  
        public Builder coreThreadNum(int coreThreadNum) {  
            this.coreThreadNum = coreThreadNum;  
            return this;  
        }  
        public Builder maxRetryCount(int maxRetryCount) {  
            this.maxRetryCount = maxRetryCount;  
            return this;  
        }  
        public Builder requestBytesSize(long requestBytesSize) {  
            this.requestBytesSize = requestBytesSize;  
            return this;  
        }  
        public Builder downloadDestinationDir(File downloadDestinationDir) {  
            this.downloadDestinationDir = downloadDestinationDir;  
            return this;  
        }

        public FileDownloaderConfiguration build() {  

            initDefaultValue(this);  

            return new FileDownloaderConfiguration(this);  
        }  

        private void initDefaultValue(Builder builder) {  

            if(builder.connectTimeout<1){  
                builder.connectTimeout = 6*1000;  
            }  

            if(builder.socketTimeout<1){  
                builder.socketTimeout = 6*1000;  
            }
            if(builder.maxRetryCount<1){
                builder.maxRetryCount = 1;  
            }  
            if(builder.coreThreadNum<1){  
                builder.coreThreadNum = 3;  
            }
            if(builder.requestBytesSize<1){  
                builder.requestBytesSize = 1024*128;  
            }
            if(builder.downloadDestinationDir==null){  
                builder.downloadDestinationDir = new File("./");  
            }
        }  
    }  
}

com.ricky.common.java.download.FileDownloaderEngine

package com.ricky.common.java.download;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.util.BitSet;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.apache.log4j.Logger;

import com.ricky.common.java.download.config.FileDownloaderConfiguration;
import com.ricky.common.java.download.job.DownloadWorker;
import com.ricky.common.java.download.job.Worker.DownloadListener;

public class FileDownloaderEngine {

    protected Logger mLogger = Logger.getLogger("devLog");

    private FileDownloaderConfiguration configuration;

    private ExecutorService pool;

    private HttpRequestImpl httpRequestImpl;

    private File downloadDestinationDir;

    private int coreThreadNum;

    public FileDownloaderEngine(FileDownloaderConfiguration configuration){

        this.configuration = configuration;

        this.coreThreadNum = configuration.getCoreThreadNum();
        this.httpRequestImpl = new HttpRequestImpl(this.configuration);
        this.pool = Executors.newFixedThreadPool(this.configuration.getCoreThreadNum());

        this.downloadDestinationDir = this.configuration.getDownloadDestinationDir();
        if(!this.downloadDestinationDir.exists()){
            this.downloadDestinationDir.mkdirs();
        }
    }

    public boolean download(String url, String filename){

        long start_time = System.currentTimeMillis();
        mLogger.info("Download started, url:"+url+",filename:"+filename);

        long total_file_len = httpRequestImpl.getFileSize(url);                     // get the file length

        if(total_file_len<1){
            mLogger.warn("Failed to get file size, url:"+url+",filename:"+filename);
            return false;
        }

        final BitSet downloadIndicatorBitSet = new BitSet(coreThreadNum);   // one bit per thread, set when that thread's download succeeded

        File file = null;
        try {

            file = new File(downloadDestinationDir, filename);

            RandomAccessFile raf = new RandomAccessFile(file, "rws");           // create a local file with the same size as the file on the server
            raf.setLength(total_file_len);                                          // set the file size
            raf.close();

            mLogger.info("create new file:"+file);

        } catch (FileNotFoundException e) {
            mLogger.error("create new file error", e);
        } catch (IOException e) {
            mLogger.error("create new file error", e);
        }

        if(file==null || !file.exists()){
            mLogger.warn("Failed to create file, url:"+url+",filename:"+filename);
            return false;
        }

        long thread_download_len = (total_file_len + coreThreadNum - 1) / coreThreadNum;            // number of bytes each thread downloads

        mLogger.info("filename:"+filename+",total_file_len="+total_file_len+",coreThreadNum:"+coreThreadNum+",thread_download_len:"+thread_download_len);

        CountDownLatch latch = new CountDownLatch(coreThreadNum);   // the calling thread waits on this latch until every worker has counted down

        for (int i = 0; i < coreThreadNum; i++){

            DownloadWorker worker = new DownloadWorker(i, url, thread_download_len, file, httpRequestImpl, latch);
            worker.addListener(new DownloadListener() {

                @Override
                public void notify(int thread_id, String url, long start, long end,
                        boolean result, String msg) {

                    mLogger.info("thread_id:"+thread_id+" download result:"+result+",url->"+url);

                    if (result) {   // mark only workers that succeeded, so a failed block makes download() return false
                        modifyState(downloadIndicatorBitSet, thread_id);
                    }
                }
            });

            pool.execute(worker);
        }

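        try {
            latch.await();
        } catch (InterruptedException e) {
            mLogger.error("CountDownLatch Interrupt", e);
            Thread.currentThread().interrupt();     // preserve the interrupt status for the caller
        }

        mLogger.info("Download finished, url:"+url+", elapsed:"+((System.currentTimeMillis()-start_time)/1000)+"(s)");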

        return downloadIndicatorBitSet.cardinality()==coreThreadNum;
    }

    private synchronized void modifyState(BitSet bitSet, int index){
        bitSet.set(index);
    }

    /** Release resources */
    public void close(){

        if(httpRequestImpl!=null){
            httpRequestImpl.close();
            httpRequestImpl = null;
        }
        if(pool!=null){
            pool.shutdown();
            pool = null;
        }

    }

}
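
The coordination pattern used by the engine, a fixed thread pool plus a CountDownLatch and a BitSet of per-worker results, can be tried in isolation. This is a minimal sketch, not the project's code: the class `LatchDemo` is hypothetical, and the `outcomes` array simulates whether each worker's download succeeds.

```java
import java.util.BitSet;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class LatchDemo {

    // Runs one task per entry of outcomes; returns true only if every simulated worker succeeded.
    static boolean runAll(final boolean[] outcomes) throws InterruptedException {
        final int n = outcomes.length;
        final BitSet succeeded = new BitSet(n);              // bit i is set when worker i succeeds
        final CountDownLatch latch = new CountDownLatch(n);
        ExecutorService pool = Executors.newFixedThreadPool(n);
        try {
            for (int i = 0; i < n; i++) {
                final int id = i;
                pool.execute(new Runnable() {
                    public void run() {
                        try {
                            if (outcomes[id]) {
                                synchronized (succeeded) {   // BitSet is not thread-safe
                                    succeeded.set(id);
                                }
                            }
                        } finally {
                            latch.countDown();               // count down on success and on failure
                        }
                    }
                });
            }
            latch.await();                                   // wait until all workers have reported
        } finally {
            pool.shutdown();
        }
        synchronized (succeeded) {
            return succeeded.cardinality() == n;
        }
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(runAll(new boolean[] {true, true, true}));    // prints true
        System.out.println(runAll(new boolean[] {true, false, true}));   // prints false
    }
}
```

Counting down in a finally block keeps the waiting thread from blocking forever when a worker fails.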

com.ricky.common.java.download.job.DownloadWorker

package com.ricky.common.java.download.job;

import java.io.File;
import java.util.concurrent.CountDownLatch;

import org.apache.log4j.Logger;

import com.ricky.common.java.download.HttpRequestImpl;
import com.ricky.common.java.download.RetryFailedException;

public class DownloadWorker extends Worker {

    protected Logger mLogger = Logger.getLogger("devLog");

    private int id;
    private String url;
    private File file;
    private long thread_download_len;

    private CountDownLatch latch;

    private HttpRequestImpl httpRequestImpl;

    public DownloadWorker(int id, String url, long thread_download_len, File file, HttpRequestImpl httpRequestImpl, CountDownLatch latch) {
        this.id = id;
        this.url = url;
        this.thread_download_len = thread_download_len;
        this.file = file;
        this.httpRequestImpl = httpRequestImpl;
        this.latch = latch;
    }

    @Override
    public void run() {

        long start = id * thread_download_len;                      // start position
        long end = id * thread_download_len + thread_download_len - 1;      // end position

        mLogger.info("Thread:" + id +" starting download, url:"+url+ ",range:" + start + "-" + end);

        boolean result = false;
        try {
            httpRequestImpl.downloadPartFile(id, url, file, start, end);
            result = true;
            mLogger.info("Thread:" + id + " downloaded "+url+ " range[" + start + "-" + end+"] successfully");

        } catch (RetryFailedException e) {
            mLogger.error("Thread:" + id +" retries exhausted", e);
        }catch (Exception e) {
            mLogger.error("Thread:" + id +" download failed", e);
        }

        try {
            if(listener!=null){
                mLogger.info("notify FileDownloaderEngine download result");
                listener.notify(id, url, start, end, result, "");
            }
        } finally {
            latch.countDown();      // always release the latch, even if the listener throws
        }
    }

}

com.ricky.common.java.download.job.Worker

package com.ricky.common.java.download.job;

public abstract class Worker implements Runnable {

    protected DownloadListener listener;

    public void addListener(DownloadListener listener){
        this.listener = listener;
    }

    public interface DownloadListener{

        public void notify(int thread_id, String url, long start, long end, boolean result, String msg);
    }
}

com.ricky.common.java.download.HttpRequestImpl

package com.ricky.common.java.download;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.RandomAccessFile;

import org.apache.commons.io.IOUtils;
import org.apache.http.HttpEntity;
import org.apache.http.HttpStatus;
import org.apache.http.client.ClientProtocolException;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.log4j.Logger;

import com.ricky.common.java.download.config.FileDownloaderConfiguration;
import com.ricky.common.java.http.HttpClientManager;

public class HttpRequestImpl {

    protected Logger mLogger = Logger.getLogger("devLog");

    private int connectTimeout;  
    private int socketTimeout;  
    private int maxRetryCount;  
    private long requestBytesSize;  

    private CloseableHttpClient httpclient = HttpClientManager.getHttpClient();

    public HttpRequestImpl(FileDownloaderConfiguration configuration){
        connectTimeout = configuration.getConnectTimeout();
        socketTimeout = configuration.getSocketTimeout();
        maxRetryCount = configuration.getMaxRetryCount();
        requestBytesSize = configuration.getRequestBytesSize();
    }

    public void downloadPartFile(int id, String url, File file, long start, long end){

        RandomAccessFile raf = null;
        try {
            raf = new RandomAccessFile(file, "rws");
        } catch (FileNotFoundException e) {
            mLogger.error("file not found:"+file, e);
            throw new IllegalArgumentException(e);
        }

        try {
            // The engine pre-allocates the file to its full size, so cap the block at the last
            // byte of the file; the rounded-up block of the last thread can overshoot it.
            long last = Math.min(end, raf.length() - 1);

            int retry = 0;
            long pos = start;
            while(pos<=last){

                // A Range is inclusive at both ends, so this requests at most requestBytesSize bytes
                long end_index = Math.min(pos + requestBytesSize - 1, last);

                boolean success = false;
                try {
                    success = requestByRange(url, raf, pos, end_index);
                } catch (Exception e) {
                    mLogger.error("download error,start:"+pos+",end:"+end_index, e);
                }

                if(success){
                    pos = end_index + 1;    // continue right after the last byte received
                    retry = 0;
                }else{
                    if(retry < maxRetryCount){
                        retry++;
                        mLogger.warn("Thread:" + id +",url:"+url+",range:"+pos+","+end_index+" download failed, retry #"+retry);
                    }else{
                        mLogger.warn("Thread:" + id +",url:"+url+",range:"+pos+","+end_index+" download failed, giving up!");
                        throw new RetryFailedException("Maximum retry count exceeded");
                    }
                }
            }
        } catch (IOException e) {   // thrown only by raf.length()
            throw new IllegalStateException("cannot read length of " + file, e);
        } finally {
            IOUtils.closeQuietly(raf);      // always release the file handle
        }

    }

    private boolean requestByRange(String url, RandomAccessFile raf, long start, long end) throws ClientProtocolException, IOException {

        HttpGet httpget = new HttpGet(url);
        httpget.setHeader("User-Agent", "Mozilla/5.0 (Windows NT 6.3; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/42.0.2311.152 Safari/537.36");
        httpget.setHeader("Range", "bytes=" + start + "-" + end);

        RequestConfig requestConfig = RequestConfig.custom()
                .setConnectTimeout(connectTimeout)
                .setSocketTimeout(socketTimeout)
                .build();

        httpget.setConfig(requestConfig);

        CloseableHttpResponse response = null;
        try {
            response = httpclient.execute(httpget);

            int code = response.getStatusLine().getStatusCode();

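            // Accept only 206 Partial Content: a 200 response means the server ignored the
            // Range header and is sending the whole file, which must not be written at this offset.
            if(code==HttpStatus.SC_PARTIAL_CONTENT){

                HttpEntity entity = response.getEntity();

                if (entity != null) {

                    InputStream in = entity.getContent();
                    raf.seek(start);// position at which to save the data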

                    byte[] buffer = new byte[1024];
                    int len;
                    while ((len = in.read(buffer)) != -1){
                        raf.write(buffer, 0, len);
                    }

                    return true;
                }else{
                    mLogger.warn("response entity is null,url:"+url);
                }
            }else{
                mLogger.warn("response error, code="+code+",url:"+url);
            }
        }finally {
            IOUtils.closeQuietly(response);
        }

        return false;
    }

    public long getFileSize(String url){

        int retry = 0;
        long filesize = 0;
        while(retry<maxRetryCount){
            try {
                filesize = getContentLength(url);
            } catch (Exception e) {
                mLogger.error("get File Size error", e);
            }

            if(filesize>0){
                break;
            }else{
                retry++;
                mLogger.warn("get File Size failed,retry:"+retry);
            }
        }

        return filesize;
    }

    private long getContentLength(String url) throws ClientProtocolException, IOException{

        HttpGet httpget = new HttpGet(url);
        httpget.setHeader("User-Agent", "Mozilla/5.0 (Windows NT 6.3; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/42.0.2311.152 Safari/537.36");

        RequestConfig requestConfig = RequestConfig.custom()
                .setConnectTimeout(connectTimeout)
                .setSocketTimeout(socketTimeout)
                .build();

        httpget.setConfig(requestConfig);

        CloseableHttpResponse response = null;
        try {
            response = httpclient.execute(httpget);

            int code = response.getStatusLine().getStatusCode();

            if(code==HttpStatus.SC_OK){

                HttpEntity entity = response.getEntity();

                if (entity != null) {
                    return entity.getContentLength();
                }
            }else{
                mLogger.warn("response code="+code);
            }

        }finally {
            IOUtils.closeQuietly(response);
        }

        return -1;
    }

    public void close(){

        if(httpclient!=null){
            try {
                httpclient.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
            httpclient = null;
        }
    }
}
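
Inside its block, each thread does not fetch everything in one request: downloadPartFile issues a series of smaller Range requests of about requestBytesSize bytes each. Here is a minimal sketch of how an inclusive byte range can be cut into non-overlapping sub-ranges. The class `SubRangeDemo` and the method `split` are hypothetical and do not appear in the project.

```java
import java.util.ArrayList;
import java.util.List;

class SubRangeDemo {

    // Splits the inclusive range [start, end] into pieces of at most requestBytesSize bytes.
    static List<long[]> split(long start, long end, long requestBytesSize) {
        if (requestBytesSize < 1) {
            throw new IllegalArgumentException("requestBytesSize must be positive");
        }
        List<long[]> pieces = new ArrayList<long[]>();
        long pos = start;
        while (pos <= end) {                                          // <= so a one-byte range is not skipped
            long last = Math.min(pos + requestBytesSize - 1, end);    // inclusive end of this piece
            pieces.add(new long[] { pos, last });
            pos = last + 1;                                           // the next piece starts right after
        }
        return pieces;
    }

    public static void main(String[] args) {
        for (long[] piece : split(0, 9, 4)) {
            System.out.println("bytes=" + piece[0] + "-" + piece[1]);
        }
        // prints bytes=0-3, bytes=4-7, bytes=8-9 on separate lines
    }
}
```

Because both ends of a Range are inclusive, a piece of n bytes starting at pos ends at pos + n - 1, and the next piece begins one byte later.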

Finally, here is the client code that calls it:

package com.ricky.common.java;

import java.io.File;

import com.ricky.common.java.download.FileDownloader;
import com.ricky.common.java.download.config.FileDownloaderConfiguration;

public class FileDownloaderTest {

    public static void main(String[] args) {

        FileDownloader fileDownloader = FileDownloader.getInstance();
        FileDownloaderConfiguration configuration = FileDownloaderConfiguration
                .custom()
                .coreThreadNum(5)
                .downloadDestinationDir(new File("D:/Download"))
                .build();
        fileDownloader.init(configuration);

        String url = "http://dldir1.qq.com/qqfile/qq/QQ7.9/16621/QQ7.9.exe";
        String filename = "QQ7.9.exe";

        boolean result = fileDownloader.download(url, filename);

        System.out.println("download result:"+result);

        fileDownloader.destroy();   // release resources when the downloader is no longer needed
    }
}


Source code

https://github.com/TiFG/FileDownloader

