基本思想
首次下载准备
- 建立连接,获取文件大小,建立等大的本地RandomAccessFile。
- 根据配置的线程数,划分各个下载区间。
- 建立临时文件,记录各个下载区间的下载信息(完成与否)。
继续下载准备
- 读取临时文件,获取各个下载区间的信息
开始下载
- 遍历各个下载区间的信息,为未完成的区间开启下载线程。每个下载线程持有一个RandomAccessFile,下载时将数据写入本地文件,并在线程终止时回调下载结果。
- 每有下载线程回调时,更新下载区间信息
- 所有区间都下载成功后,删除临时文件;若有区间下载失败,则更新临时文件。若下载任务非正常结束,最好主动调用updateTmpFile()更新临时文件。
Code(修改自GitHub上的一个类)
package com.persist.download;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import java.io.*;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
/**
*多线程文件下载器
*下载线程数 默认为4
* 建议不要配置过多,可以考虑CPU*2
* */
public class Download {
private final static int DEFAULT_THREAD_NUM = 4;
//the uri of the file to be downloaded
private String uri;
//the path to save the local file
private String savePath;
//the true num of thread to execute download task
private int threadNum;
//the file size
private long fileSize;
//the size of each part to be downloaded
private long partSize;
//the download thread array
private DownloadThread[] downloadThread;
//the num of the thread which is really started
private int realThreadNum;
//the num of the finished thread
private int overThreadNum;
//the tail of the tmp file which records the download info of each part
//the download info will be store with json.
// when download is finished,the tmp file will be deleted.
private final static String tmpTail = "-tmp";
//the download info of each part
private List<DownloadInfo> downloadInfos;
public Download(String uri, String savePath, int threadNum) {
this.uri = uri;
this.savePath = savePath;
this.threadNum = threadNum;
downloadThread = new DownloadThread[threadNum];
}
public Download(String uri, String savePath) {
//default thread num is 4
this(uri, savePath, DEFAULT_THREAD_NUM);
}
/**
* execute download task
* */
public void download() throws IOException{
///如果第一次下载,
// 则建立临时文件,并写入分段下载信息
//或读取临时文件,读入分段下载信息
if(downloadInfos == null)
{
File file = new File(savePath + tmpTail);
Gson gson = new Gson();
if (!file.exists() || !file.isFile())
{
// System.out.println("首次下载");
URL url = new URL(uri);
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
conn.setConnectTimeout(5 * 1000);
conn.setRequestMethod("GET");
conn.setRequestProperty("Accept", "image/gif,image/jpeg,image/pjpeg," +
"image/gif, image/jpeg, image/pjpeg, image/pjpeg, "
+ "application/x-shockwave-flash, application/xaml+xml, "
+ "application/vnd.ms-xpsdocument, application/x-ms-xbap, "
+ "application/x-ms-application, application/vnd.ms-excel, "
+ "application/vnd.ms-powerpoint, application/msword, */*");
conn.setRequestProperty("Accept-Language", "zh-CN");
conn.setRequestProperty("Charset", "UTF-8");
conn.setRequestProperty("Connection", "Keep-Alive");
fileSize = conn.getContentLengthLong();
conn.disconnect();
RandomAccessFile file1 = new RandomAccessFile(savePath, "rw");
file1.setLength(fileSize);
file1.close();
//这样划分,至多可能多读取threadNum-1字节,但不会丢失
partSize = fileSize / threadNum + 1;
downloadInfos = new ArrayList<>();
realThreadNum = threadNum;
overThreadNum = 0;
for(int i = 0; i < threadNum; ++i)
{
long startPos = i * partSize;
RandomAccessFile currentPart = new RandomAccessFile(savePath,"rw");
currentPart.seek(startPos);
downloadThread[i] = new DownloadThread(uri, currentPart, startPos, partSize);
downloadThread[i].setCallback(callback);
downloadThread[i].start();
downloadInfos.add(new DownloadInfo(startPos, partSize, false));
}
return;
} else
{
// System.out.println("继续下载");
FileInputStream is = new FileInputStream(file);
InputStreamReader reader = new InputStreamReader(is);
downloadInfos = gson.fromJson(reader,
new TypeToken<List<DownloadInfo>>() {}.getType());
partSize = downloadInfos.get(0).partSize;
fileSize = partSize * downloadInfos.size();
reader.close();
}
}
DownloadInfo info;
realThreadNum = 0;
overThreadNum = 0;
for(int i = 0; i < threadNum; i++)
{
info = downloadInfos.get(i);
if(!info.finish)
{
realThreadNum++;
RandomAccessFile currentPart = new RandomAccessFile(savePath,"rw");
currentPart.seek(info.startPos);
downloadThread[i] = new DownloadThread(uri, currentPart, info.startPos, info.partSize);
downloadThread[i].setCallback(callback);
downloadThread[i].start();
}
}
}
/**
* the callback entity.
* When a thread is over, the {@link #downloadInfos} will be update.
* If all thread are executed successfully, the tmp file will be deleted.
* */
private Callback callback = new Callback() {
@Override
public void onTerminate(boolean finish, int index) {
overThreadNum++;
if(finish)
{
System.out.println("下载第"+index+"部分完成");
File file = new File(savePath + tmpTail);
//下载完成,删除临时文件
if(getCompleteRate() >= 1f && file.exists() && file.isFile())
file.delete();
else
{
downloadInfos.get(index).finish = true;
}
}
else
{
System.out.println("下载第"+index+"部分失败");
if(overThreadNum == realThreadNum)
updateTmpFile();
}
}
};
public void updateTmpFile()
{
PrintWriter writer = null;
try {
File file = new File(savePath + tmpTail);
writer = new PrintWriter(new FileOutputStream(file));
Gson gson = new Gson();
writer.print(gson.toJson(downloadInfos));
} catch (FileNotFoundException e) {
e.printStackTrace();
} finally {
if (writer != null)
writer.close();
}
}
/**
* get the complete rate of the download task
* */
public float getCompleteRate(){
long doneLength = 0;
for(int i = 0; i < threadNum; ++i){
if(downloadThread[i] != null)
doneLength += downloadThread[i].doneLength;
else
doneLength += partSize;
}
if(doneLength > fileSize)
return 1.0f;
return (float) ((doneLength * 1.0) / fileSize);
}
/**
* record the download part info
* to avoid starting needless download thread
* */
private class DownloadInfo {
public long startPos;
public long partSize;
public boolean finish;
public DownloadInfo(long startPos, long partSize, boolean finish)
{
this.startPos = startPos;
this.partSize = partSize;
this.finish = finish;
}
}
/**
* when a thread is over, a callback will be invoked with the result
* */
private interface Callback
{
void onTerminate(boolean finish, int index);
}
/**
* download a part of the file and write data to the local file.
* When it is over, the callback will be invoked.
* */
private class DownloadThread extends Thread{
private String uri;
private RandomAccessFile file;
//文件开始位置
private long startPos;
//当前线程下载的块的大小
private long partSize;
//已完成长度
private long doneLength = 0;
private Callback callback;
public DownloadThread(String uri, RandomAccessFile file, long startPos,long partSize) {
super();
this.uri = uri;
this.file = file;
this.startPos = startPos;
this.partSize = partSize;
}
public void setCallback(Callback callback)
{
this.callback = callback;
}
@Override
public void run(){
try {
URL url = new URL(uri);
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
conn.setConnectTimeout(5 * 1000);
conn.setRequestMethod("GET");
conn.addRequestProperty("Accept", "image/gif,image/jpeg,image/pjpeg," +
"image/gif, image/jpeg, image/pjpeg, image/pjpeg, "
+ "application/x-shockwave-flash, application/xaml+xml, "
+ "application/vnd.ms-xpsdocument, application/x-ms-xbap, "
+ "application/x-ms-application, application/vnd.ms-excel, "
+ "application/vnd.ms-powerpoint, application/msword, */*");
conn.setRequestProperty("Accept-Language", "zh-CN");
conn.setRequestProperty("Charset", "UTF-8");
InputStream is = conn.getInputStream();
is.skip(startPos);
byte buffer[] = new byte[1024];
int len;
while( doneLength < partSize
&& (len = is.read(buffer)) != -1){
file.write(buffer, 0, len);
doneLength += len;
}
file.close();
is.close();
if(callback != null)
{
callback.onTerminate(true, (int)(startPos/partSize));
}
} catch (MalformedURLException e) {
// TODO Auto-generated catch block
e.printStackTrace();
if(callback != null)
{
callback.onTerminate(false, (int)(startPos/partSize));
}
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
if(callback != null)
{
callback.onTerminate(false, (int)(startPos/partSize));
}
}
}
}
}
注意
保存临时文件时采用的是JSON格式,使用了Gson框架。也可以采用其他方式,只要能方便地记录各个区间的下载信息即可。