用 Go 实现下载进度条时，网上能找到的方案都只是在终端上显示进度条，没有可以通过接口查询下载进度的功能；
网上的实现基本都基于 mpb.Progress，于是我自己写了一个；
原理是包装 HTTP 响应体并重写 Read 方法，在每次读取时累加已下载的字节数，从而得到下载进度。话不多说，直接上代码：
package utils
import (
"archive/zip"
"fmt"
"github.com/spf13/cast"
"github.com/vbauerster/mpb/v5"
"io"
"net/http"
"os"
"path"
"path/filepath"
"runtime"
"strconv"
"sync"
"sync/atomic"
)
// Resource describes one file to download: Url is fetched and the body is
// saved under Filename inside the Downloader's TargetDir.
type Resource struct {
	Filename, Url string
}
// Downloader fetches a list of Resources concurrently into TargetDir,
// reporting per-resource progress through the matching ReaderCount.
type Downloader struct {
	// wg tracks in-flight Download goroutines: Start calls Add(1) per
	// resource and each Download calls Done when it returns.
	wg *sync.WaitGroup
	// pool bounds concurrency: Start creates it with capacity Concurrent and
	// each Download sends into it before transferring and receives from it
	// afterwards.
	pool chan *Resource
	// Concurrent is the capacity used for pool, i.e. the intended maximum
	// number of simultaneous downloads (NewDownloader defaults it to the
	// number of CPUs).
	Concurrent int
	// HttpClient is an HTTP client configuration carried by the downloader.
	// NOTE(review): verify that Download actually issues requests through it.
	HttpClient http.Client
	// TargetDir is the directory the downloaded files are written into.
	TargetDir string
	// Resources lists the files to download, in the order they were appended.
	Resources []Resource
	// ReaderCounts[i] receives the progress of Resources[i]; AppendResource
	// appends to both slices together and Start indexes them in lockstep.
	ReaderCounts []*ReaderCount
}
// NewDownloader builds a Downloader that writes into targetDir and, unless
// Concurrent is changed afterwards, runs one download per logical CPU.
func NewDownloader(targetDir string) *Downloader {
	d := &Downloader{TargetDir: targetDir}
	d.Concurrent = runtime.NumCPU()
	d.wg = new(sync.WaitGroup)
	return d
}
// AppendResource queues url to be saved as filename and registers count as
// the progress tracker for that resource. Both lists grow together so the
// resource and its tracker share an index.
func (d *Downloader) AppendResource(filename, url string, count *ReaderCount) {
	res := Resource{Url: url, Filename: filename}
	d.Resources = append(d.Resources, res)
	d.ReaderCounts = append(d.ReaderCounts, count)
}
// Download fetches resource.Url into TargetDir/resource.Filename and reports
// progress through count.
//
// The response body is streamed into "<final path>.tmp" and renamed to the
// final name only after it has been fully written and closed, so a partial
// download never appears under the final name. On failure the temporary
// file is removed (best effort).
//
// The call holds one slot of d.pool for its whole duration, which bounds the
// number of concurrent downloads, and always calls d.wg.Done when it returns.
//
// progress is kept for the mpb progress-bar variant (bar.ProxyReader) and is
// not used here. A nil count is replaced by a fresh, unobserved ReaderCount.
//
// It returns an error if the request cannot be built or sent, if the server
// answers with a non-2xx status, or if writing/renaming the file fails.
func (d *Downloader) Download(resource Resource, progress *mpb.Progress, count *ReaderCount) error {
	defer d.wg.Done()
	// Acquire a concurrency slot and give it back on every return path.
	// Releasing it only on success leaked the slot on any error, which
	// eventually blocked all remaining downloads forever.
	d.pool <- &resource
	defer func() { <-d.pool }()

	if count == nil {
		count = NewReader()
	}
	finalPath := filepath.Join(d.TargetDir, resource.Filename)
	tmpPath := finalPath + ".tmp"

	req, err := http.NewRequest(http.MethodGet, resource.Url, nil)
	if err != nil {
		return err
	}
	// The zero-value HttpClient behaves like http.DefaultClient; using the
	// field lets callers configure timeouts, proxies or transports.
	resp, err := d.HttpClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	// Without this check an HTTP error page would be saved as the file.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return fmt.Errorf("unexpected HTTP status %s for %s", resp.Status, resource.Url)
	}
	// A missing or malformed Content-Length yields 0 (size unknown).
	fileSize, _ := strconv.Atoi(resp.Header.Get("Content-Length"))

	// Create the temp file only once the response is usable, so failed
	// requests leave nothing behind.
	target, err := os.Create(tmpPath)
	if err != nil {
		return err
	}

	var total uint64
	count.SetFileSize(fileSize)
	count.setTotal(&total)
	count.setReader(resp.Body)

	_, copyErr := io.Copy(target, count)
	// Close before renaming: a failed Close can mean data never reached disk.
	closeErr := target.Close()
	if copyErr == nil {
		copyErr = closeErr
	}
	if copyErr != nil {
		_ = os.Remove(tmpPath) // best effort; the copy/close error is what matters
		return copyErr
	}
	return os.Rename(tmpPath, finalPath)
}
// Start downloads every appended resource concurrently, with at most
// d.Concurrent transfers in flight (at least 1), and blocks until all of
// them have finished.
//
// It returns the first failed download's error, wrapped with the resource's
// filename, or nil when every download succeeded. A resource without a
// registered ReaderCount gets a fresh one so it can still be downloaded.
func (d *Downloader) Start() error {
	workers := d.Concurrent
	if workers < 1 {
		// An unbuffered pool would block every Download on its first send.
		workers = 1
	}
	d.pool = make(chan *Resource, workers)
	// With WithWaitGroup, p.Wait also waits on d.wg, which each Download
	// marks Done; no separate d.wg.Wait is needed.
	p := mpb.New(mpb.WithWaitGroup(d.wg))

	// Download calls d.wg.Done before its result reaches us, so a separate
	// WaitGroup tracks when each goroutine has stored its error.
	errs := make([]error, len(d.Resources))
	var stored sync.WaitGroup
	for i, resource := range d.Resources {
		count := NewReader()
		if i < len(d.ReaderCounts) && d.ReaderCounts[i] != nil {
			count = d.ReaderCounts[i]
		}
		d.wg.Add(1)
		stored.Add(1)
		go func(i int, res Resource, count *ReaderCount) {
			defer stored.Done()
			errs[i] = d.Download(res, p, count)
		}(i, resource, count)
	}
	p.Wait()
	stored.Wait()

	for i, err := range errs {
		if err != nil {
			return fmt.Errorf("download %q: %w", d.Resources[i].Filename, err)
		}
	}
	return nil
}
// Unzip2 extracts every entry of the zip archive at zipPath into dstDir.
// Extraction stops at the first entry that fails, and that error is returned;
// the archive is closed in all cases once it has been opened.
func Unzip2(zipPath, dstDir string) error {
	archive, err := zip.OpenReader(zipPath)
	if err != nil {
		return err
	}
	defer archive.Close()

	var extractErr error
	for i := 0; i < len(archive.File) && extractErr == nil; i++ {
		extractErr = unzipFile(archive.File[i], dstDir)
	}
	return extractErr
}
// unzipFile extracts a single archive entry into dstDir.
//
// Directory entries are created together with their parents. File entries
// are written to dstDir/<entry name>: parent directories are created as
// needed and an existing file is truncated. Archive contents are untrusted,
// so an entry whose name would resolve outside dstDir (e.g. "../x") is
// rejected with an error instead of being written ("zip slip").
func unzipFile(file *zip.File, dstDir string) error {
	// Zip entry names are slash-separated regardless of OS: clean them with
	// path, then convert to the OS-specific form before joining.
	name := filepath.FromSlash(path.Clean(file.Name))
	filePath := filepath.Join(dstDir, name)
	rel, err := filepath.Rel(filepath.Clean(dstDir), filePath)
	if err != nil || rel == ".." || (len(rel) >= 3 && rel[:2] == ".." && os.IsPathSeparator(rel[2])) {
		return fmt.Errorf("zip entry %q escapes destination directory %q", file.Name, dstDir)
	}

	if file.FileInfo().IsDir() {
		return os.MkdirAll(filePath, os.ModePerm)
	}
	if err := os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil {
		return err
	}

	rc, err := file.Open()
	if err != nil {
		return err
	}
	defer rc.Close()

	w, err := os.Create(filePath)
	if err != nil {
		return err
	}
	if _, err := io.Copy(w, rc); err != nil {
		_ = w.Close() // the copy error is the one worth reporting
		return err
	}
	// Close explicitly: a failed Close can mean the contents were not fully written.
	return w.Close()
}
// ReaderCount wraps an io.ReadCloser and counts the bytes read through it,
// so that download progress can be polled (GetRate, IsOver) from a different
// goroutine than the one driving Read.
type ReaderCount struct {
	// FileSize is the expected total size in bytes; GetRate treats values
	// <= 0 as "unknown".
	FileSize int
	// Total points at the running count of bytes read; *Total is only
	// accessed atomically.
	Total  *uint64
	reader io.ReadCloser
	over   bool

	// mu serialises the setters and the Read-side "over" flag with the
	// polling methods GetRate and IsOver.
	mu sync.Mutex
}

// Read reads from the wrapped reader and adds the number of bytes read to
// *Total. Once the wrapped reader returns any error, including io.EOF, the
// stream is marked as over so pollers stop waiting even when the transfer
// fails part-way.
func (r *ReaderCount) Read(p []byte) (n int, err error) {
	n, err = r.reader.Read(p)
	// Count the bytes before flagging completion so a poller that observes
	// IsOver() == true also sees the final byte count.
	atomic.AddUint64(r.Total, cast.ToUint64(n))
	if err != nil {
		r.mu.Lock()
		r.over = true
		r.mu.Unlock()
	}
	return n, err
}

// Close closes the wrapped reader.
func (r *ReaderCount) Close() error {
	return r.reader.Close()
}

// SetFileSize records the expected total size in bytes.
func (r *ReaderCount) SetFileSize(filesize int) {
	r.mu.Lock()
	r.FileSize = filesize
	r.mu.Unlock()
}

// setReader sets the reader whose bytes are counted.
func (r *ReaderCount) setReader(reader io.ReadCloser) {
	r.reader = reader
}

// setTotal sets the counter that Read increments.
func (r *ReaderCount) setTotal(total *uint64) {
	r.mu.Lock()
	r.Total = total
	r.mu.Unlock()
}

// GetRate returns the completed percentage formatted as e.g. "42.00%".
// Before a counter has been set, or when the expected size is unknown
// (FileSize <= 0), it returns "0.00%" instead of "NaN%" / "+Inf%".
func (r *ReaderCount) GetRate() string {
	r.mu.Lock()
	total, size := r.Total, r.FileSize
	r.mu.Unlock()
	if total == nil || size <= 0 {
		return "0.00%"
	}
	done := atomic.LoadUint64(total)
	return fmt.Sprintf("%0.2f%%", float64(done)/float64(size)*100)
}

// IsOver reports whether the wrapped reader has returned io.EOF or any other
// error. It says nothing about what the consumer does with the data after
// reading it.
func (r *ReaderCount) IsOver() bool {
	r.mu.Lock()
	defer r.mu.Unlock()
	return r.over
}
// NewReader returns an empty ReaderCount. Download fills in its reader,
// byte counter and expected size when the transfer starts.
func NewReader() *ReaderCount {
	return new(ReaderCount)
}
// UpgradeVersion downloads url into downloadDir/name in the background and
// then extracts that archive into toDir. It returns immediately with the
// ReaderCount that tracks the download so callers can poll its progress.
//
// Errors are not reported to the caller: when the download fails the archive
// is not extracted, and extraction errors are dropped.
// NOTE(review): if the request fails before any body bytes are read,
// the returned reader's IsOver never becomes true; pollers need a timeout.
func UpgradeVersion(url, name, downloadDir, toDir string) *ReaderCount {
	downloader := NewDownloader(downloadDir)
	reader := NewReader()
	downloader.AppendResource(name, url, reader)
	// Only one resource is fetched, so a single worker is enough (the
	// default would be one per CPU core).
	downloader.Concurrent = 1
	go func() {
		if err := downloader.Start(); err != nil {
			// Don't try to extract a missing or incomplete archive.
			return
		}
		// The archive is saved under downloadDir, not the working directory.
		_ = Unzip2(filepath.Join(downloadDir, name), toDir)
	}()
	return reader
}
测试用例:
// Test_upgradeVersion starts an upgrade download and prints the progress
// percentage until the download reports completion, then prints "完成"
// ("done").
// NOTE(review): this polls in a tight loop with no delay or timeout, so it
// floods stdout and never ends if IsOver never becomes true; consider
// sleeping between polls and adding a deadline.
func Test_upgradeVersion(t *testing.T) {
	progress := utils.UpgradeVersion("xxx", "xx.zip", "./", "./")
	for !progress.IsOver() {
		fmt.Println(progress.GetRate())
	}
	fmt.Println("完成")
}
结果: