Java實現多線程下載,支持斷點續傳


完整代碼:https://github.com/yuanyb/Downloader

多線程下載及斷點續傳的實現是使用 HTTP/1.1 引入的 Range 請求參數,可以訪問Web資源的指定區間的內容。雖然實現了多線程及斷點續傳,但還有很多不完善的地方。

包含四個類:

  • Downloader: 主類,負責分配任務給各個子線程,及檢測進度
  • DownloadFile: 表示要下載的哪個文件,為了能寫輸入到文件的指定位置,使用 RandomAccessFile 類操作文件,多個線程寫同一個文件需要保證線程安全,這里直接調用 getChannel 方法,獲取一個文件通道,FileChannel是線程安全的。
  • DownloadTask: 實際執行下載的線程,獲取 [lowerBound, upperBound] 區間的數據,當下載過程中出現異常時要通知其他線程(使用 AtomicBoolean),結束下載
  • Logger: 實時記錄下載進度,以便續傳時知道從哪開始。感覺這里做的比較差,為了能實時寫出日志及方便地使用Properties類的load/store方法格式化輸入輸出,每次都是打開后再關閉。

 

演示:

隨便找一個文件下載:

強行結束程序並重新運行:

 

 

 

 日志文件:

斷點續傳的關鍵是記錄各個線程的下載進度,這里細節比較多,花了很久。只需要記錄每個線程請求的Range區間極客,每次成功寫數據到文件時,就更新一次下載區間。下面是下載完成后的日志內容。

 

 

代碼:

Downloader.java

 

package downloader;

import java.io.*;
import java.net.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.concurrent.atomic.AtomicBoolean;

public class Downloader {
    private static final int DEFAULT_THREAD_COUNT = 4;  // 默認線程數量
    private AtomicBoolean canceled; // 取消狀態,如果有一個子線程出現異常,則取消整個下載任務
    private DownloadFile file; // 下載的文件對象
    private String storageLocation;
    private final int threadCount; // 線程數量
    private long fileSize; // 文件大小
    private final String url;
    private long beginTime; // 開始時間
    private Logger logger;

    public Downloader(String url) {
        this(url, DEFAULT_THREAD_COUNT);
    }

    public Downloader(String url, int threadCount) {
        this.url = url;
        this.threadCount = threadCount;
        this.canceled = new AtomicBoolean(false);
        this.storageLocation = url.substring(url.lastIndexOf('/')+1);
        this.logger = new Logger(storageLocation + ".log", url, threadCount);
    }

    public void start() {
        boolean reStart = Files.exists(Path.of(storageLocation + ".log"));
        if (reStart) {
            logger = new Logger(storageLocation + ".log");
            System.out.printf("* 繼續上次下載進度[已下載:%.2fMB]:%s\n", logger.getWroteSize() / 1014.0 / 1024, url);
        } else {
            System.out.println("* 開始下載:" + url);
        }
        if (-1 == (this.fileSize = getFileSize()))
            return;
        System.out.printf("* 文件大小:%.2fMB\n", fileSize / 1024.0 / 1024);

        this.beginTime = System.currentTimeMillis();
        try {
            this.file = new DownloadFile(storageLocation, fileSize, logger);
            if (reStart) {
                file.setWroteSize(logger.getWroteSize());
            }
            // 分配線程下載
            dispatcher(reStart);
            // 循環打印進度
            printDownloadProgress();
        } catch (IOException e) {
            System.err.println("x 創建文件失敗[" + e.getMessage() + "]");
        }
    }

    /**
     * 分配器,決定每個線程下載哪個區間的數據
     */
    private void dispatcher(boolean reStart) {
        long blockSize = fileSize / threadCount; // 每個線程要下載的數據量
        long lowerBound = 0, upperBound = 0;
        long[][] bounds = null;
        int threadID = 0;
        if (reStart) {
           bounds = logger.getBounds();
        }
        for (int i = 0; i < threadCount; i++) {
            if (reStart) {
                threadID = (int)(bounds[i][0]);
                lowerBound = bounds[i][1];
                upperBound = bounds[i][2];
            } else {
                threadID = i;
                lowerBound = i * blockSize;
                // fileSize-1 !!!!! fu.ck,找了一下午的錯
                upperBound = (i == threadCount - 1) ? fileSize-1 : lowerBound + blockSize;
            }
            new DownloadTask(url, lowerBound, upperBound, file, canceled, threadID).start();
        }
    }

    /**
     * 循環打印進度,直到下載完畢,或任務被取消
     */
    private void printDownloadProgress() {
        long downloadedSize = file.getWroteSize();
        int i = 0;
        long lastSize = 0; // 三秒前的下載量
        while (!canceled.get() && downloadedSize < fileSize) {
            if (i++ % 4 == 3) { // 每3秒打印一次
                System.out.printf("下載進度:%.2f%%, 已下載:%.2fMB,當前速度:%.2fMB/s\n",
                        downloadedSize / (double)fileSize * 100 ,
                        downloadedSize / 1024.0 / 1024,
                        (downloadedSize - lastSize) / 1024.0 / 1024 / 3);
                lastSize = downloadedSize;
                i = 0;
            }
            try {
                Thread.sleep(1000);
            } catch (InterruptedException ignore) {}
            downloadedSize = file.getWroteSize();
        }
        file.close();
        if (canceled.get()) {
            try {
                Files.delete(Path.of(storageLocation));
            } catch (IOException ignore) {
            }
            System.err.println("x 下載失敗,任務已取消");
        } else {
            System.out.println("* 下載成功,本次用時"+ (System.currentTimeMillis() - beginTime) / 1000 +"秒");
        }
    }

    /**
     * @return 要下載的文件的尺寸
     */
    private long getFileSize() {
        if (fileSize != 0) {
            return fileSize;
        }
        HttpURLConnection conn = null;
        try {
            conn = (HttpURLConnection)new URL(url).openConnection();
            conn.setConnectTimeout(3000);
            conn.setRequestMethod("HEAD");
            conn.connect();
            System.out.println("* 連接服務器成功");
        } catch (MalformedURLException e) {
            throw new RuntimeException("URL錯誤");
        } catch (IOException e) {
            System.err.println("x 連接服務器失敗["+ e.getMessage() +"]");
            return -1;
        }
        return conn.getContentLengthLong();
    }

    public static void main(String[] args) throws IOException {
        new Downloader("http://js.xiazaicc.com//down2/ucliulanqi_downcc.zip").start();
    }
}

 

DownloadTask.java

 

package downloader;

import java.io.*;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.util.concurrent.atomic.AtomicBoolean;

class DownloadTask extends Thread {
    private final String url;
    private long lowerBound; // 下載的文件區間
    private long upperBound;
    private AtomicBoolean canceled;
    private DownloadFile downloadFile;
    private int threadId;

    DownloadTask(String url, long lowerBound, long upperBound, DownloadFile downloadFile,
                        AtomicBoolean canceled, int threadID) {
        this.url = url;
        this.lowerBound = lowerBound;
        this.upperBound = upperBound;
        this.canceled = canceled;
        this.downloadFile = downloadFile;
        this.threadId = threadID;
    }

    @Override
    public void run() {
        ReadableByteChannel input = null;
        try {
            ByteBuffer buffer = ByteBuffer.allocate(1024 * 1024 * 2); // 2MB
            input = connect();
            System.out.println("* [線程" + threadId + "]連接成功,開始下載...");

            int len;
            while (!canceled.get() && lowerBound <= upperBound) {
                buffer.clear();
                len = input.read(buffer);
                downloadFile.write(lowerBound, buffer, threadId, upperBound);
                lowerBound += len;
            }
            if (!canceled.get()) {
                System.out.println("* [線程" + threadId + "]下載完成" + ": " + lowerBound + "-" + upperBound);
            }
        } catch (IOException e) {
            canceled.set(true);
            System.err.println("x [線程" + threadId + "]遇到錯誤[" + e.getMessage() + "],結束下載");
        } finally {
            if (input != null) {
                try {
                    input.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    /**
     * 連接WEB服務器,並返回一個數據通道
     * @return 返回通道
     * @throws IOException 網絡連接錯誤
     */
    private ReadableByteChannel connect() throws IOException {
        HttpURLConnection conn = (HttpURLConnection)new URL(url).openConnection();
        conn.setConnectTimeout(3000);
        conn.setRequestMethod("GET");
        conn.setRequestProperty("Range", "bytes=" + lowerBound + "-" + upperBound);
//        System.out.println("thread_"+ threadId +": " + lowerBound + "-" + upperBound);
        conn.connect();

        int statusCode = conn.getResponseCode();
        if (HttpURLConnection.HTTP_PARTIAL != statusCode) {
            conn.disconnect();
            throw new IOException("狀態碼錯誤:" + statusCode);
        }

        return Channels.newChannel(conn.getInputStream());
    }
}

 

DownloadFile.java

 

package downloader;

import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.concurrent.atomic.AtomicLong;

class DownloadFile {
    private final RandomAccessFile file;
    private final FileChannel channel; // 線程安全類
    private AtomicLong wroteSize; // 已寫入的長度
    private Logger logger;

    DownloadFile(String fileName, long fileSize, Logger logger) throws IOException {
        this.wroteSize = new AtomicLong(0);
        this.logger = logger;
        this.file = new RandomAccessFile(fileName, "rw");
        file.setLength(fileSize);
        channel = file.getChannel();
    }

    /**
     * 寫數據
     * @param offset 寫偏移量
     * @param buffer 數據
     * @throws IOException 寫數據出現異常
     */
    void write(long offset, ByteBuffer buffer, int threadID, long upperBound) throws IOException {
        buffer.flip();
        int length = buffer.limit();
        while (buffer.hasRemaining()) {
            channel.write(buffer, offset);
        }
        wroteSize.addAndGet(length);
        logger.updateLog(threadID, length, offset + length, upperBound); // 更新日志
    }

    /**
     * @return 已經下載的數據量,為了知道何時結束整個任務,以及統計信息
     */
    long getWroteSize() {
        return wroteSize.get();
    }

    // 繼續下載時調用
    void setWroteSize(long wroteSize) {
        this.wroteSize.set(wroteSize);
    }

    void close() {
        try {
            file.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

 

Logger.java

package downloader;

import java.io.*;
import java.util.Properties;

class Logger {
    private String logFileName; // 下載的文件的名字
    private Properties log;

     /**
      * 重新開始下載時,使用該構造函數
      * @param logFileName
      */
    Logger(String logFileName) {
        this.logFileName = logFileName;
        log = new Properties();
        FileInputStream fin = null;
        try {
            log.load(new FileInputStream(logFileName));
        } catch (IOException ignore) {
        } finally {
            try {
                fin.close();
            } catch (Exception ignore) {}
        }
    }

    Logger(String logFileName, String url, int threadCount) {
        this.logFileName = logFileName;
        this.log = new Properties();
        log.put("url", url);
        log.put("wroteSize", "0");
        log.put("threadCount", String.valueOf(threadCount));
        for (int i = 0; i < threadCount; i++) {
            log.put("thread_" + i, "0-0");
        }
    }


    synchronized void updateLog(int threadID, long length, long lowerBound, long upperBound) {
        log.put("thread_"+threadID, lowerBound + "-" + upperBound);
        log.put("wroteSize", String.valueOf(length + Long.parseLong(log.getProperty("wroteSize"))));

        FileOutputStream file = null;
        try {
            file = new FileOutputStream(logFileName); // 每次寫時都清空文件
            log.store(file, null);
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (file != null) {
                try {
                    file.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    /**
     * 獲取區間信息
     *      ret[i][0] = threadID, ret[i][1] = lowerBoundID, ret[i][2] = upperBoundID
     * @return
     */
    long[][] getBounds() {
        long[][] bounds = new long[Integer.parseInt(log.get("threadCount").toString())][3];
        int[] index = {0};
        log.forEach((k, v) -> {
            String key = k.toString();
            if (key.startsWith("thread_")) {
                String[] interval = v.toString().split("-");
                bounds[index[0]][0] = Long.parseLong(key.substring(key.indexOf("_") + 1));
                bounds[index[0]][1] = Long.parseLong(interval[0]);
                bounds[index[0]++][2] = Long.parseLong(interval[1]);
            }
        });
       return bounds;
    }
    long getWroteSize() {
        return Long.parseLong(log.getProperty("wroteSize"));
    }
}

  


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM