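This post modifies and refactors the code from annegu's article on multithreaded resumable downloading. The original article is at: http://annegu.iteye.com/blog/427397
Main features: 1. Segmented download with any number of threads 2. Resuming an interrupted download 3. Merging the temp files into the target file and deleting the temp files
I. First, a constants class that defines the download states and the download directory: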
import java.io.File;
/**
* Constants related to downloading
*
* @author jaychang
*
*/
public class DownloadConstant {
/** Default download path: the Downloads directory of the current system user */
public final static String DOWNLOAD_DIRECTORY = System.getProperty("user.home")+File.separator+"Downloads";
/** The download has finished */
public final static String DOWNLOAD_HAS_FINISHED = "DOWNLOAD_HAS_FINISHED";
/** The download hit an error */
public final static String DOWNLOAD_ERROR = "DOWNLOAD_ERROR";
}
II. Setting request headers
import java.net.URLConnection;
/**
* Utility class for setting request headers
*
* @author jaychang
*
*/
public class RequestHeaderUtil {
/**
* Makes the HTTP request look like one sent by a browser
*
* @param conn
* URLConnection
*/
public static void setHeader(URLConnection conn) {
conn.setRequestProperty("User-Agent",
"Mozilla/5.0 (X11; U; Linux i686;en-US; rv:1.9.0.3) Gecko/2008092510 Ubuntu/8.04 (hardy) Firefox/3.0.3");
conn.setRequestProperty("Accept-Language",
"en-us,en;q=0.7,zh-cn;q=0.3");
// "identity" asks the server not to compress the body, so byte offsets match the file on disk
conn.setRequestProperty("Accept-Encoding", "identity");
conn.setRequestProperty("Accept-Charset",
"ISO-8859-1,utf-8;q=0.7,*;q=0.7");
conn.setRequestProperty("Keep-Alive", "300");
conn.setRequestProperty("Connection", "keep-alive");
conn.setRequestProperty("If-Modified-Since",
"Fri, 02 Jan 2009 17:00:05 GMT");
conn.setRequestProperty("If-None-Match", "\"1261d8-4290-df64d224\"");
conn.setRequestProperty("Cache-Control", "max-age=0");
}
}
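The method takes a URLConnection parameter. Some sites filter incoming HTTP connections for security reasons, so to get past such filters we dress the request headers up to look like a browser's. The setHeader method above shows some very common header settings of this kind. The important ones are: User-Agent, which here imitates a request from Firefox on Ubuntu; Referer, which imitates the page the browser came from (for example, when downloading software from the skycn site, setting Referer to the skycn home page domain is enough); and Range, which is the byte range of the stream this connection fetches. Of these three, setHeader itself sets only User-Agent. Range is set per segment in the download thread class, and Referer is not set anywhere in this code.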
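Since Referer and Range are described above but not set by setHeader, here is a minimal sketch of how the two headers could be set on a connection. The class name RangeHeaderSketch, both method names and the example.com URLs are hypothetical and only for illustration; opening the connection object does not contact the server.

```java
import java.io.IOException;
import java.net.URL;
import java.net.URLConnection;

class RangeHeaderSketch {
    /** Builds a Range header value for the inclusive byte range [start, end]. */
    static String rangeValue(long start, long end) {
        if (start < 0 || end < start) {
            throw new IllegalArgumentException("invalid range " + start + "-" + end);
        }
        return "bytes=" + start + "-" + end;
    }

    /** Creates a connection object (not yet connected) and sets Referer and Range on it. */
    static URLConnection prepare(String fileUrl, String referer, long start, long end)
            throws IOException {
        URLConnection conn = new URL(fileUrl).openConnection();
        // Hypothetical referer: the page a browser would have come from
        conn.setRequestProperty("Referer", referer);
        // Ask the server for only this byte range of the file
        conn.setRequestProperty("Range", rangeValue(start, end));
        return conn;
    }

    public static void main(String[] args) throws IOException {
        URLConnection conn = prepare("http://example.com/files/setup.exe",
                "http://example.com/", 0, 1023);
        System.out.println("Range: " + conn.getRequestProperty("Range"));
    }
}
```

A server that supports range requests answers such a request with status 206 (Partial Content); a plain 200 means the Range header was ignored and the whole file is being sent.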
III. Download manager class
import java.io.IOException;
import java.net.URL;
import java.net.URLConnection;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import cn.com.servyou.mutithreaddown.contants.DownloadConstant;
import cn.com.servyou.mutithreaddown.util.MergeFileUtil;
import cn.com.servyou.mutithreaddown.util.RequestHeaderUtil;
/**
* Download manager class
*
* @author jaychang
*
*/
public class DownloadManager {
/** Default number of threads; this could of course be read from a configuration file instead */
public final static int THREAD_TOTAL_NUM = 10;
/** Total size of the file */
private long contentLength;
/** Start index array: the start position of each thread's segment */
private long[] startPoints = new long[THREAD_TOTAL_NUM];
/** End index array: the end position of each thread's segment */
private long[] endPoints = new long[THREAD_TOTAL_NUM];
/** URL */
private String urlStr;
/** Encoding */
@SuppressWarnings("unused")
private final static String DEFAULT_ENCODING = "GBK";
public String getUrlStr() {
return urlStr;
}
public void setUrlStr(String urlStr) {
this.urlStr = urlStr;
}
public DownloadManager(String urlStr){
this.urlStr = urlStr;
}
public void download() throws IOException {
URL url = new URL(urlStr);
URLConnection conn = url.openConnection();
RequestHeaderUtil.setHeader(conn);
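// Name of the file being downloaded
String fileAllName = urlStr.substring(urlStr.lastIndexOf("/") + 1);
// Size of the file being downloaded; getContentLengthLong() (Java 7+) avoids int overflow above 2 GB
contentLength = conn.getContentLengthLong();
if (contentLength <= 0) {
throw new IOException("Server did not report a usable Content-Length for " + urlStr);
}
// Number of bytes each thread has to download
long contentLengthPerThread = contentLength / THREAD_TOTAL_NUM;
System.out.println("Total bytes is " + contentLength);
System.out.println("Every thread need read bytes is "
+ contentLengthPerThread);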
ExecutorService exec = Executors.newCachedThreadPool();
CountDownLatch latch = new CountDownLatch(THREAD_TOTAL_NUM);
for (int i = 0; i < THREAD_TOTAL_NUM; i++) {
startPoints[i] = contentLengthPerThread * i;
endPoints[i] = (i == THREAD_TOTAL_NUM - 1) ? contentLength - 1
: contentLengthPerThread * (i + 1) - 1;
DownloadThread downloadThread = new DownloadThread();
downloadThread.setContentLength(contentLength);
// Set the start position of the segment
downloadThread.setStartPoint(startPoints[i]);
// Set the end position of the segment
downloadThread.setEndPoint(endPoints[i]);
// Set the full name of the file being downloaded
downloadThread.setFileAllName(fileAllName);
downloadThread.setUrlStr(urlStr);
downloadThread.setThreadIndex(i + 1);
downloadThread.setLatch(latch);
// Start running the thread's task
exec.execute(downloadThread);
}
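try {
// Wait for the CountDownLatch to reach 0, which means every worker has finished;
// await() returns false if the timeout elapses first
boolean allFinished = latch.await(100000, TimeUnit.MILLISECONDS);
exec.shutdown();
if (!allFinished) {
// Do not merge incomplete segments; the temp files stay on disk so the download can be resumed
throw new IOException("Download did not finish within the timeout: " + urlStr);
}
// Write the contents of the downloaded temp files into the target file
MergeFileUtil.merge(DownloadConstant.DOWNLOAD_DIRECTORY + "/"
+ fileAllName, fileAllName);
} catch (InterruptedException e) {
// Restore the interrupt flag so callers can see the interruption
Thread.currentThread().interrupt();
e.printStackTrace();
}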
}
}
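Here is an explanation of the code above.
First, the constants and fields of the class. THREAD_TOTAL_NUM needs no further comment. contentLength is the total size of the file in bytes. startPoints, an array of length THREAD_TOTAL_NUM, holds the start position of the segment each thread is responsible for, and endPoints holds the matching end positions. urlStr must be an HTTP URL that ends in a file name plus extension (fileName.type). The encoding constant is not used yet and is kept for later.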
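To make the segment arithmetic concrete, the following sketch repeats the same start and end calculation as the loop in download(), but as a standalone method. The class name SegmentSketch and the method split are hypothetical names chosen for this illustration.

```java
class SegmentSketch {
    /**
     * Splits contentLength bytes into threadCount segments.
     * Each row holds {start, end}, both inclusive; the last segment absorbs the remainder.
     */
    static long[][] split(long contentLength, int threadCount) {
        if (contentLength <= 0 || threadCount <= 0) {
            throw new IllegalArgumentException("contentLength and threadCount must be positive");
        }
        long perThread = contentLength / threadCount;
        long[][] segments = new long[threadCount][2];
        for (int i = 0; i < threadCount; i++) {
            segments[i][0] = perThread * i;
            segments[i][1] = (i == threadCount - 1)
                    ? contentLength - 1            // last thread reads to the end of the file
                    : perThread * (i + 1) - 1;     // others stop one byte before the next start
        }
        return segments;
    }

    public static void main(String[] args) {
        // Example: a 105-byte file split across 10 threads
        for (long[] s : split(105, 10)) {
            System.out.println(s[0] + "-" + s[1]);
        }
    }
}
```

With 105 bytes and 10 threads, the first nine segments are 10 bytes each (0-9, 10-19, and so on) and the last one covers 90-104, so no byte is skipped or downloaded twice.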
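Next, CountDownLatch. A CountDownLatch is a counter that works like a gate: calling await() closes the gate so the calling thread cannot go any further, and only when the counter has dropped to 0 does the gate open and the waiting thread carry on. CountDownLatch has many uses, and segmented downloading is a good example of one.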
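The gate behaviour can be seen in isolation with a small sketch. It is not the download code itself: the class name LatchSketch and the method runWorkers are hypothetical, and incrementing a counter stands in for downloading one segment.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

class LatchSketch {
    /**
     * Runs workerCount tasks and blocks until all have counted down.
     * Returns the number of completed tasks, or -1 on timeout or interruption.
     */
    static int runWorkers(int workerCount, long timeoutMillis) {
        ExecutorService exec = Executors.newFixedThreadPool(workerCount);
        CountDownLatch latch = new CountDownLatch(workerCount);
        AtomicInteger completed = new AtomicInteger();
        for (int i = 0; i < workerCount; i++) {
            exec.execute(() -> {
                try {
                    completed.incrementAndGet(); // stand-in for downloading one segment
                } finally {
                    latch.countDown();           // open the gate a little further
                }
            });
        }
        try {
            // The calling thread waits here until the count reaches 0 or the timeout elapses
            boolean allDone = latch.await(timeoutMillis, TimeUnit.MILLISECONDS);
            return allDone ? completed.get() : -1;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return -1;
        } finally {
            exec.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println("completed workers: " + runWorkers(4, 5000));
    }
}
```

In this sketch the count-down sits in a finally block, so the waiting thread is released even if a worker fails. That is a design choice of the sketch: the caller then has to check each worker's result before treating the whole job as successful.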
IV. Download thread class
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.concurrent.CountDownLatch;
import cn.com.servyou.mutithreaddown.contants.DownloadConstant;
import cn.com.servyou.mutithreaddown.util.RequestHeaderUtil;
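/**
* Download thread class
*
* @author jaychang
*
*/
public class DownloadThread extends Thread {
/** Index of this thread, used in the temp file name */
private int threadIndex;
/** Start position of the file segment this thread downloads */
private long startPoint;
/** End position of the file segment this thread downloads */
private long endPoint;
/** Total size of the file being downloaded */
private long contentLength;
/** Full file name */
private String fileAllName;
/** URL of the file */
private String urlStr;
/** Buffer size */
private final static int BUFFER_READ_SIZE = 8096;
/** Temp file extension */
private final static String FILE_TYPE = "tmp";
/** Initial download status */
private String status = DownloadConstant.DOWNLOAD_ERROR;
/** CountDownLatch used to synchronize with the main thread: once all threads have finished, the main thread starts assembling the temp files */
private CountDownLatch latch;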
public DownloadThread() {
super();
}
/**
* Constructor
*
* @param threadIndex
* index of the thread
* @param startPoint
* start position of the download
* @param endPoint
* end position of the download
* @param contentLength
* total size of the file
* @param fileAllName
* full file name
* @param urlStr
* URL string
* @param latch
* CountDownLatch
*/
public DownloadThread(int threadIndex, long startPoint, long endPoint,
long contentLength, String fileAllName, String urlStr,
CountDownLatch latch) {
super();
this.threadIndex = threadIndex;
this.startPoint = startPoint;
this.endPoint = endPoint;
this.contentLength = contentLength;
this.fileAllName = fileAllName;
this.urlStr = urlStr;
this.latch = latch;
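// Not started here: the caller decides how to run it (DownloadManager hands it to an ExecutorService),
// and starting a thread from its own constructor would let run() see a partly constructed object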
}
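/**
* Downloads part of the file and stores it in a temp file
*/
public void run() {
int indexOfPoint = fileAllName.lastIndexOf(".");
// Fall back to the whole name when there is no extension, instead of failing on substring(0, -1)
String fileName = indexOfPoint > 0 ? fileAllName.substring(0, indexOfPoint) : fileAllName;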
String tempFileName = fileName + "_" + threadIndex + "." + FILE_TYPE;
File downloadDir = new File(DownloadConstant.DOWNLOAD_DIRECTORY + "/"
+ fileAllName);
if (!downloadDir.exists()) {
downloadDir.mkdirs();
}
File tempFile = new File(downloadDir, tempFileName);
boolean isExist = tempFile.exists();
if (isExist) {
long localContentLength = tempFile.length();
processDownload(tempFileName, localContentLength);
} else {
processDownload(tempFileName, 0);
}
}
/**
* Handles this thread's download task
*
* @param tempFileName
* name of the temp file
* @param localContentLength
* size of the temp file, i.e. bytes already downloaded
*/
public void processDownload(String tempFileName, long localContentLength) {
HttpURLConnection conn = null;
BufferedInputStream in = null;
BufferedOutputStream out = null;
// This thread's byte segment has not been fully downloaded yet
if (localContentLength < endPoint - startPoint + 1) {
try {
conn = (HttpURLConnection) (new URL(urlStr)).openConnection();
conn.setAllowUserInteraction(true);
// Set the connect timeout to 10000 ms
conn.setConnectTimeout(10000);
// Set the read timeout to 100000 ms
conn.setReadTimeout(100000);
RequestHeaderUtil.setHeader(conn);
// Set the byte range in the request header, i.e. start from the point where the last run stopped
long startPos = startPoint + localContentLength;
conn.setRequestProperty("Range", "bytes=" + startPos + "-"
+ endPoint);
System.out.println("Thread" + threadIndex + ": " + " startPos="
+ startPos + " endPos=" + endPoint + " NeedReadBytes="
+ (endPoint - startPos + 1));
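int responseCode = conn.getResponseCode();
if (HttpURLConnection.HTTP_OK == responseCode) {
System.out.println("HTTP_OK");
} else if (HttpURLConnection.HTTP_PARTIAL == responseCode) {
System.out.println("HTTP_PARTIAL");
} else if (HttpURLConnection.HTTP_CLIENT_TIMEOUT == responseCode) {
System.out.println("HTTP_CLIENT_TIMEOUT");
}
// Only 206 means the server sent the requested range; anything else would put wrong bytes in the temp file
if (HttpURLConnection.HTTP_PARTIAL != responseCode) {
throw new IOException("Server did not honour the Range request, response code "
+ responseCode);
}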
File directory = new File(DownloadConstant.DOWNLOAD_DIRECTORY
+ "/" + fileAllName);
in = new BufferedInputStream(conn.getInputStream());
out = new BufferedOutputStream(new FileOutputStream(new File(
directory, tempFileName), true));
long count = 0;
byte[] b = new byte[BUFFER_READ_SIZE];
int len = -1;
// Number of bytes this thread still has to read
long needReadBytes = endPoint - startPos + 1;
while (count < needReadBytes && (len = in.read(b)) != -1) {
// Never write past the end of this segment, even if the server sends more
int toWrite = (int) Math.min(len, needReadBytes - count);
out.write(b, 0, toWrite);
count += toWrite;
}
System.out.println("Thread " + threadIndex + " has read " + count
+ " bytes!");
// Set the final download status of this thread
this.status = count >= needReadBytes ? DownloadConstant.DOWNLOAD_HAS_FINISHED
: DownloadConstant.DOWNLOAD_ERROR;
// Decrement the latch count to signal that this thread's task has ended
latch.countDown();
} catch (IOException e) {
e.printStackTrace();
} finally {
if (in != null)
try {
in.close();
} catch (IOException e1) {
e1.printStackTrace();
}
if (out != null)
try {
out.close();
} catch (IOException e) {
e.printStackTrace();
}
}
} else {
// The temp file already holds the whole segment, so there is nothing left to download
System.out.println("Thread " + threadIndex + ","
+ "segment already complete, bytes = " + (endPoint - startPoint + 1));
this.status = DownloadConstant.DOWNLOAD_HAS_FINISHED;
latch.countDown();
}
}
public long getStartPoint() {
return startPoint;
}
public void setStartPoint(long startPoint) {
this.startPoint = startPoint;
}
public long getEndPoint() {
return endPoint;
}
public void setEndPoint(long endPoint) {
this.endPoint = endPoint;
}
public long getContentLength() {
return contentLength;
}
public void setContentLength(long contentLength) {
this.contentLength = contentLength;
}
public String getFileAllName() {
return fileAllName;
}
public void setFileAllName(String fileAllName) {
this.fileAllName = fileAllName;
}
public int getThreadIndex() {
return threadIndex;
}
public void setThreadIndex(int threadIndex) {
this.threadIndex = threadIndex;
}
public String getUrlStr() {
return urlStr;
}
public void setUrlStr(String urlStr) {
this.urlStr = urlStr;
}
public String getStatus() {
return status;
}
public void setStatus(String status) {
this.status = status;
}
public CountDownLatch getLatch() {
return latch;
}
public void setLatch(CountDownLatch latch) {
this.latch = latch;
}
}
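The class has a threadIndex field, which is used to name the temp file: the temp file is called "fileName_threadIndex.tmp". When a download is restarted, run() treats the length of an existing temp file as the number of bytes already downloaded, and processDownload() requests the range starting at startPoint plus that length.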
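The naming rule and the resume calculation can be sketched on their own as follows. The class name ResumeSketch and its two methods are hypothetical helpers written for this illustration, and returning -1 for a finished segment is a convention of the sketch only.

```java
class ResumeSketch {
    /** Builds the temp file name "fileName_threadIndex.tmp" from the full file name. */
    static String tempFileName(String fileAllName, int threadIndex) {
        int dot = fileAllName.lastIndexOf('.');
        // Use the whole name when the file has no extension
        String base = dot > 0 ? fileAllName.substring(0, dot) : fileAllName;
        return base + "_" + threadIndex + ".tmp";
    }

    /**
     * Returns the first byte that still has to be requested for a segment,
     * or -1 if the temp file already holds the whole segment.
     */
    static long resumePosition(long startPoint, long endPoint, long localLength) {
        long segmentLength = endPoint - startPoint + 1;
        if (localLength >= segmentLength) {
            return -1; // nothing left to download
        }
        return startPoint + localLength;
    }

    public static void main(String[] args) {
        System.out.println(tempFileName("setup.exe", 3));
        // Segment 100-199 with 40 bytes already on disk resumes at byte 140
        System.out.println(resumePosition(100, 199, 40));
    }
}
```

The Range header for the resumed request would then be "bytes=" followed by the resume position, a hyphen and the segment's end point, as in processDownload().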
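V. Utility class for merging the temp files
At this point the download thread of every segment has finished and created its temp file. Next, the main thread merges the temp files, writes them into the target file, and finally deletes the temp files.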
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
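/**
* Utility class for merging temp files
*
* @author jaychang
*
*/
public class MergeFileUtil {
/** Underscore separator */
public final static String UNDER_LINE = "_";
/** Dot separator */
public final static String POINT = ".";
/**
* Merges the temp files
*
* @param filePath
* path of the directory that holds the temp files
* @param fileAllName
* full name of the target file
* @throws IOException
*/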
public static void merge(String filePath, String fileAllName)
throws IOException {
File dir = new File(filePath);
BufferedOutputStream out = null;
BufferedInputStream in = null;
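// List only the temp files, so a leftover target file cannot break the numeric parsing in the comparator
File[] fileList = dir.listFiles((d, name) -> name.endsWith(".tmp"));
if (fileList == null) {
throw new IOException("Cannot list temp files in " + filePath);
}
// Sort the temp files so they are written to the output stream in order; the sort key is the threadIndex in the file name (fileName_'threadIndex'.tmp)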
Arrays.sort(fileList, new Comparator<File>() {
public int compare(File fileOne, File fileAnother) {
String fileNameOne = fileOne.getName();
String fileNameAnother = fileAnother.getName();
int lastIndexOfUnderLineOne = fileNameOne
.lastIndexOf(UNDER_LINE);
int lastIndexOfPointOne = fileNameOne.lastIndexOf(POINT);
int lastIndexOfUnderLineAnother = fileNameAnother
.lastIndexOf(UNDER_LINE);
int lastIndexOfPointAnother = fileNameAnother
.lastIndexOf(POINT);
int one = Integer.parseInt(fileNameOne.substring(
lastIndexOfUnderLineOne + 1, lastIndexOfPointOne));
int another = Integer.parseInt(fileNameAnother.substring(
lastIndexOfUnderLineAnother+1, lastIndexOfPointAnother));
return one - another;
}
});
// The final file assembled from the temp files
File destFile = new File(dir, fileAllName);
// Open in overwrite mode (this also creates the file) so a file left by an earlier run is not appended to
out = new BufferedOutputStream(new FileOutputStream(destFile, false));
for (File file : fileList) {
// Skip anything that is not a temp file
if (!file.getName().endsWith(".tmp"))
continue;
// Read the temp file and write it to the final file; files are visited in order of file number (the index of the thread that produced the file)
in = new BufferedInputStream(new FileInputStream(file));
byte[] b = new byte[8196];
int len = -1;
while ((len = in.read(b)) != -1) {
out.write(b, 0, len);
}
if (in != null) {
in.close();
}
// Delete the temp file
file.delete();
}
if (out != null) {
out.close();
}
}
}
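The merge goes through the temp files of the download and sorts the array of temp files by the threadIndex mentioned earlier, since the temp files were named "fileName_threadIndex.tmp".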
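The reason for sorting by the parsed number rather than by the name itself is that plain string order puts "setup_10.tmp" before "setup_2.tmp". A minimal sketch of the numeric ordering, working on file names only; the class name TempFileOrderSketch and its methods are hypothetical.

```java
import java.util.Arrays;
import java.util.Comparator;

class TempFileOrderSketch {
    /** Extracts the thread index from a name of the form "fileName_threadIndex.tmp". */
    static int threadIndexOf(String tempFileName) {
        int underscore = tempFileName.lastIndexOf('_');
        int dot = tempFileName.lastIndexOf('.');
        return Integer.parseInt(tempFileName.substring(underscore + 1, dot));
    }

    /** Returns a copy of the names ordered by their numeric thread index. */
    static String[] sortByThreadIndex(String[] names) {
        String[] sorted = names.clone();
        Arrays.sort(sorted, Comparator.comparingInt(TempFileOrderSketch::threadIndexOf));
        return sorted;
    }

    public static void main(String[] args) {
        String[] names = {"setup_10.tmp", "setup_2.tmp", "setup_1.tmp"};
        // Numeric order: setup_1.tmp, setup_2.tmp, setup_10.tmp
        System.out.println(Arrays.toString(sortByThreadIndex(names)));
    }
}
```

Comparator.comparingInt also avoids the subtraction idiom (one - another), which is fine for small thread indexes but can overflow for arbitrary ints.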