python--杂识--5--大文件分片上传和断点续传

0 思路

(1) 客户端实现边分片文件边发送。客户端使用生产者消费者模式,生产者用于读取指定大小的二进制文件分片,放入缓存队列中,消费者用于从缓存队列中读取文件分片并post发送到服务器;生产者任务与消费者任务异步进行,使用了协程实现。
(2) 服务端实现边接收文件分片边拼接文件分片。接收并保存文件分片到磁盘用子线程实现,拼接文件分片用子进程实现。需要使用info.ini文件配合实现以上功能,info.ini保存了拼接文件分片的子进程是否存在以及当前需要拼接到合并文件[文件分片合并后的文件]的文件分片id。例如:

[abc.txt]
merge_slice_process_exist=0
curr_merge_slice_id=5

merge_slice_process_exist表示合并分片的子进程是否存在,0表示不存在,1表示存在;curr_merge_slice_id=5表示abc.txt的id为5的文件分片当前需要合并到合并文件中。

1 目录结构

在这里插入图片描述

2 代码

# app.py
import logging
from flask import Flask, request, jsonify
from main import ServerBigFileUpload

app = Flask(__name__)


@app.route('/')
def hello_world():
    return 'Hello World!'


@app.route('/big_file_upload', methods=['GET', 'POST'])
def big_file_upload():
    logging.warning('=' * 20 + 'S:0' + '=' * 20)
    dir_ = "upload"
    return_data = return_data_ = dict()
    if request.method == 'GET':
        # 0 根据文件名判断文件是否在服务器上已经存在一部分

        # 1
        logging.warning('=' * 20 + 'S:1' + '=' * 20)
        file_name = request.args.get("file_name", -1)
        slice_num = int(request.args.get("slice_num", -1))
        if slice_num == -1 or file_name == -1:
            return_data["message"] = "参数错误"
            return_data["status"] = 500
            return jsonify(return_data)

        # 2 获取已有文件切片的id,返回给客户端,以让客户端决定判断还有哪些切片有待发送
        logging.warning('=' * 20 + 'S:2' + '=' * 20)
        return_data_ = ServerBigFileUpload.get_main(file_name, slice_num, dir_)

    elif request.method == 'POST':
        # 3
        logging.warning('=' * 20 + 'S:3' + '=' * 20)
        file_name = request.args.get("file_name")
        slice_id = request.args.get("slice_id")
        slice_size = request.args.get("slice_size")
        slice_content = request.stream.read()

        # 4 接收文件分片并保存到磁盘
        logging.warning('=' * 20 + 'S:4' + '=' * 20)
        return_data_ = ServerBigFileUpload.post_main(file_name, slice_id, slice_size, slice_content, dir_)

    return_data.update(return_data_)
    logging.warning('=' * 20 + 'S:5' + '=' * 20)
    return jsonify(return_data)


if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=8888)
# main.py
import os
import threading
import multiprocessing
import configparser
import subprocess
import time
import logging


class ServerBigFileUpload:
    @staticmethod
    def merge_slice(file_name, slice_num, dir_, curr_merge_slice_id=0, timeout=10, waiting_times=3):
        # 2.5.1
        logging.warning('=' * 20 + 'S:2.5.1' + '=' * 20)
        save_dir = dir_ + "/" + "tmp" + "/" + file_name
        info_ini_file = dir_ + '/info.ini'
        merge_file = dir_ + "/" + file_name
        logging.warning(merge_file)
        if not os.path.isfile(merge_file):
            open(merge_file, 'a').close()

        # 2.5.2 合并后台收到的分片
        logging.warning('=' * 20 + 'S:2.5.2' + '=' * 20)
        merge_file_open = open(merge_file, "wb+")
        waiting_times_ = waiting_times
        # # 对于每一个所要合并的分片最多循环等待三次,防止所要合并分片还在传送的路上或还没写入磁盘
        while waiting_times_ and curr_merge_slice_id < slice_num:
            curr_merge_slice_file = save_dir + "/" + str(curr_merge_slice_id)
            if str(curr_merge_slice_id) in os.listdir(save_dir):
                with open(curr_merge_slice_file, r"rb") as f:
                    curr_merge_slice_content = f.read()
                merge_file_open.write(curr_merge_slice_content)
                os.remove(curr_merge_slice_file)  # 调试时可以注释掉
                curr_merge_slice_id += 1
                waiting_times_ = waiting_times
            else:
                # 防止所要合并分片还在传送的路上或还没写入磁盘

                waiting_times_ -= 1
                time.sleep(timeout)
        merge_file_open.close()

        # 2.5.3 修改info.ini
        logging.warning('=' * 20 + 'S:2.5.3' + '=' * 20)
        config = configparser.ConfigParser()
        config.read(info_ini_file, encoding="utf-8")
        if curr_merge_slice_id == slice_num:
            # 表示已将所有分片合并时

            config.remove_section(file_name)
            rm_cmd = "sudo rm -rf %s" % save_dir  # 调试时可以注释掉
            subprocess.check_call(rm_cmd, shell=True)  # 调试时可以注释掉
        if waiting_times_ == 0:
            # 表示由于前台未发送或网络原因导致部分分片没有到达后台时

            config.set(file_name, "merge_slice_process_exist", "0")
            config.set(file_name, "curr_merge_slice_id", str(curr_merge_slice_id))
        with open(info_ini_file, "w", encoding="utf-8") as f:
            config.write(f)

        # 2.5.4
        logging.warning('=' * 20 + 'S:2.5.4' + '=' * 20)
        return None

    @staticmethod
    def get_main(file_name, slice_num, dir_):
        # 2.1
        logging.warning('=' * 20 + 'S:2.1' + '=' * 20)
        return_data = dict()

        # 2.4 判断是否有关于进程id和当前需要合并的分片id的info.ini
        logging.warning('=' * 20 + 'S:2.4' + '=' * 20)
        info_ini_file = dir_ + '/info.ini'
        if not os.path.isfile(info_ini_file):
            open(info_ini_file, "a").close()

        # 2.2 判断所上传的文件是否之前上传过部分
        logging.warning('=' * 20 + 'S:2.2' + '=' * 20)
        upload_file = dir_ + "/" + file_name
        config = configparser.ConfigParser()
        config.read(info_ini_file)
        if os.path.isfile(upload_file) and not config.has_section(file_name):
            # 表示之前文件整体全部上传完成
            return_data["is_uploaded"] = 1
            return return_data
        else:
            # 表示之前没有上传过文件或之前上传了文件部分分片
            return_data["is_uploaded"] = 0

        # 2.3 获取所上传文件当前已存取的分片id
        logging.warning('=' * 20 + 'S:2.3' + '=' * 20)
        save_dir = dir_ + "/" + "tmp" + "/" + file_name
        return_data["file_name"] = file_name
        return_data["status"] = 200
        if not os.path.isdir(save_dir):
            return_data["slice_ids_uploaded"] = []
            os.makedirs(save_dir, exist_ok=True)
        else:
            uploaded_not_merge_slice_ids = os.listdir(save_dir)  # 调试时可以注释掉
            try:
                curr_merge_slice_id = int(config.get(file_name, "curr_merge_slice_id"))
            except Exception as e:
                curr_merge_slice_id = 0
            already_merge_slice_ids = [x for x in range(curr_merge_slice_id)]  # 调试时可以注释掉
            return_data["slice_ids_uploaded"] = uploaded_not_merge_slice_ids + already_merge_slice_ids  # 调试时可以注释掉
            # return_data["slice_ids_uploaded"] = os.listdir(save_dir)  # 调试时取消注释

        # 2.5 判断当前是否合并分片merge_slice进程已启动
        logging.warning('=' * 20 + 'S:2.5' + '=' * 20)
        config = configparser.ConfigParser()
        config.read(info_ini_file, encoding="utf-8")
        if not config.has_section(file_name):
            # info.ini没有 file_name 部分时
            merge_slice = multiprocessing.Process(target=ServerBigFileUpload.merge_slice,
                                                  args=(file_name, slice_num, dir_, 0))
            merge_slice.start()
            config.add_section(file_name)
            config.set(file_name, "merge_slice_process_exist", "1")
            with open(info_ini_file, "w", encoding="utf-8") as f:
                config.write(f)
        elif int(config.get(file_name, "merge_slice_process_exist")) == 0:
            # 当info.ini中有 file_name 部分 但merge_slice进程没有启动时
            curr_merge_slice_id = int(config.get(file_name, "curr_merge_slice_id"))
            merge_slice = multiprocessing.Process(target=ServerBigFileUpload.merge_slice,
                                                  args=(file_name, slice_num, dir_, curr_merge_slice_id))
            merge_slice.start()
            config.set(file_name, "merge_slice_process_exist", "1")
            with open(info_ini_file, "w", encoding="utf-8") as f:
                config.write(f)

        # 2.6
        logging.warning('=' * 20 + 'S:2.6' + '=' * 20)
        return return_data

    @staticmethod
    def save_slice(file_name, slice_id, slice_content, dir_):
        # 4.2.1 保存文件分片
        logging.warning('=' * 20 + 'S:4.2.1' + '=' * 20)
        slice_save_file = dir_ + "/" + "tmp" + "/" + file_name + "/" + slice_id
        with open(slice_save_file, "wb") as f:
            f.write(slice_content)

        # 4.2.2
        logging.warning('=' * 20 + 'S:4.2.2' + '=' * 20)
        return None

    @staticmethod
    def post_main(file_name, slice_id, slice_size, slice_content, dir_):
        # 4.1 使用正则匹配获取真正传送的数据 slice_content
        logging.warning('=' * 20 + 'S:4.1' + '=' * 20)
        return_data = dict()
        slice_content = slice_content[205+len(slice_size):-40]

        # 4.2 创建保存当前分片的子线程
        logging.warning('=' * 20 + 'S:4.2' + '=' * 20)
        save_slice_threading = threading.Thread(target=ServerBigFileUpload.save_slice, args=(file_name, slice_id, slice_content, dir_))
        save_slice_threading.start()

        return_data["file_name"] = file_name
        return_data["slice_id"] = slice_id
        return_data["status"] = 200

        # 4.3
        logging.warning('=' * 20 + 'S:4.3' + '=' * 20)
        return return_data
# client.py
import logging
import os
import asyncio
# import uvloop
import aiohttp
import aiofiles
import requests
import time


# asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


class ClientBigFileUpload:

    @staticmethod
    def file_name(file_path):
        # 1.1
        logging.warning('=' * 20 + 'C:1.1' + '=' * 20)
        file_name_ = os.path.split(file_path)[-1]

        # 1.2
        logging.warning('=' * 20 + 'C:1.2' + '=' * 20)
        return file_name_

    @staticmethod
    def compute_slice_num(file_path, slice_size):
        # 2-3.1
        logging.warning('=' * 20 + 'C:2-3.1' + '=' * 20)
        file_size = os.path.getsize(file_path)
        slice_num = file_size // slice_size
        if file_size % slice_size:
            slice_num += 1
        # 2-3.2
        logging.warning('=' * 20 + 'C:2-3.2' + '=' * 20)
        return int(slice_num)

    @staticmethod
    async def file_slice(file_path, file_name, slice_ids, slice_size, queue):
        # 7.1
        logging.warning('=' * 20 + 'C:7.1' + '=' * 20)
        try:
            async with aiofiles.open(file_path, 'rb') as f:
                for slice_id in slice_ids:
                    await f.seek(int(slice_id * slice_size), 0)
                    slice_content = await f.read(slice_size)
                    logging.warning(len(slice_content))
                    await queue.put({"file_name": file_name, "slice_id": slice_id, "slice_size": len(slice_content), "slice_content": slice_content})
        except Exception as e:
            # 7.2
            logging.warning('=' * 20 + 'C:7.2' + '=' * 20)
            logging.error(e)
            return False
        # 7.3
        logging.warning('=' * 20 + 'C:7.3' + '=' * 20)
        return True

    @staticmethod
    async def post_slice_upload(session, queue, url, timeout=5):
        # 2.1.1
        logging.warning('=' * 20 + 'C:2.1.1' + '=' * 20)
        # headers = {"Content-Type": "application/json"}
        headers = {}

        # 2.1.2 防止queue为空且从磁盘读取的数据还没放到queue中导致协程直接结束
        logging.warning('=' * 20 + 'C:2.1.2' + '=' * 20)
        if queue.empty():
            await asyncio.sleep(timeout)

        while not queue.empty():
            slice_ = await queue.get()
            params = {"file_name": slice_["file_name"], "slice_id": slice_["slice_id"], "slice_size": slice_["slice_size"]}
            slice_content = {"slic_content": slice_["slice_content"]}
            print("slice_id: %d" % slice_["slice_id"])
            try:
                async with session.post(url, data=slice_content, headers=headers, params=params) as response:
                    await response.json()
                    # time.sleep(2)
            except Exception as e:
                logging.error(e)
            queue.task_done()
            if queue.empty():
                await asyncio.sleep(timeout)
        # 2.1.3
        logging.warning('=' * 20 + 'C:2.1.3' + '=' * 20)
        return None

    @staticmethod
    async def upload_session(queue, url, consumer_num, timeout):
        # 2.1
        logging.warning('=' * 20 + 'C:2.1' + '=' * 20)
        try:
            async with aiohttp.ClientSession() as session:
                tasks = [asyncio.create_task(ClientBigFileUpload.post_slice_upload(session, queue, url, timeout)) for x in range(consumer_num)]
                await asyncio.wait(tasks)
        except Exception as e:
            # 2.2
            logging.warning('=' * 20 + 'C:2.2' + '=' * 20)
            logging.error(e)
            return False

        # 2.3
        logging.warning('=' * 20 + 'C:2.3' + '=' * 20)
        return True

    @staticmethod
    def get_slice_ids_uploaded(url, file_name, slice_num):
        # 4.1
        logging.warning('=' * 20 + 'C:4.1' + '=' * 20)
        headers = {}
        params = {"file_name": file_name, "slice_num": slice_num}
        try:
            response = requests.get(url, params=params, headers=headers).json()
            logging.warning(response)
        except Exception as e:
            logging.error(e)
            return None, None, None, None

        is_uploaded = response.get("is_uploaded", None)
        slice_ids_uploaded = response.get("slice_ids_uploaded", None)
        status = response.get("status", None)

        # 4.2
        logging.warning('=' * 20 + 'C:4.2' + '=' * 20)
        logging.warning(type(file_name))
        logging.warning(type(is_uploaded))
        logging.warning(type(slice_ids_uploaded))
        logging.warning(type(status))
        print(file_name, is_uploaded, slice_ids_uploaded, status)
        return file_name, is_uploaded, slice_ids_uploaded, status

    @staticmethod
    async def main(file_path, queue_size, consumer_num, slice_size, url, timeout):
        """
        客户端分片并发送文件,分片与发送异步进行
        :param file_path: 所发送文件路径
        :param queue_size: 缓存队列queue大小
        :param consumer_num: 消费者个数
        :param slice_size: 分片大小
        :param url: 接收文件分片的url
        :param timeout: 防止缓存队列queue为空且从磁盘读取的数据还没放到queue中导致协程直接结束
        :return: None
        """
        # 1 获取发送文件名
        logging.warning('=' * 20 + 'C:1' + '=' * 20)
        file_name = ClientBigFileUpload.file_name(file_path)

        # 2 异步发送post请求,给服务器发送切片
        logging.warning('=' * 20 + 'C:2' + '=' * 20)
        queue = asyncio.Queue(queue_size)
        task = asyncio.create_task(ClientBigFileUpload.upload_session(queue, url, consumer_num, timeout))

        # 2-3
        logging.warning('=' * 20 + 'C:2-3' + '=' * 20)
        slice_num = ClientBigFileUpload.compute_slice_num(file_path, slice_size)
        slice_ids = {x for x in range(slice_num)}

        # 3 获取服务器所需文件分片ids并根据ids对文件分片
        logging.warning('=' * 20 + 'C:3' + '=' * 20)
        while True:
            # 4 发生时get请求获取需要发送的切片
            logging.warning('=' * 20 + 'C:4' + '=' * 20)
            _, is_uploaded, slice_ids_uploaded, _ = ClientBigFileUpload.get_slice_ids_uploaded(url, file_name, slice_num)

            slice_ids_uploaded = set(map(int, slice_ids_uploaded))
            print("slice_ids_uploaded: ", slice_ids_uploaded)
            slice_ids -= slice_ids_uploaded
            print("slice_ids: ", slice_ids)

            # 5 判断是否文件之前上传过
            logging.warning('=' * 20 + 'C:5' + '=' * 20)
            if is_uploaded:
                logging.warning("该文件之前上传过")
                break

            # 6 判断是否所有的分片已完成发送
            logging.warning('=' * 20 + 'C:6' + '=' * 20)
            if not slice_ids:
                break

            # 7 对客户端所需发送的文件分片进行分片
            logging.warning('=' * 20 + 'C:7' + '=' * 20)
            await ClientBigFileUpload.file_slice(file_path, file_name, list(slice_ids), slice_size, queue)

            # 8 等待队列为空
            logging.warning('=' * 20 + 'C:8' + '=' * 20)
            await queue.join()

        # 9 任务完成销毁所有的消费者
        logging.warning('=' * 20 + 'C:9' + '=' * 20)
        task.cancel()

        # 10
        logging.warning('=' * 20 + 'C:10' + '=' * 20)
        return None


if __name__ == '__main__':
    file_path = r"xxxxxxxxxxxxxxxxxxxxxx"  # 所上传文件路径

    queue_size = 10  # 队列长度
    consumer_num = 1  # 消费者个数
    slice_size = int(0.5 * 1024 * 1024)  # 分片大小
    url = r"http://xx.xx.xx.xx:8888/big_file_upload" # 文件上传的url
    timeout = 0.5
    asyncio.run(ClientBigFileUpload.main(file_path, queue_size, consumer_num, slice_size, url, timeout))

3 启动

(1) 在服务端进入file_upload目录执行

uwsgi --socket 0.0.0.0:8888 --protocol=http -p 1 -w app:app

(2) 在客户端执行client.py脚本即可

4 bug

目前里边存在一个问题未解决:
   客户端报错: ERROR:root:[WinError 10053] 你的主机中的软件中止了一个已建立的连接。

问题复原demo

# server.py
import socket


socketserver = socket.socket()
host = "0.0.0.0"
port = 9999
socketserver.bind((host, port))
socketserver.listen(5)
client, addr = socketserver.accept()
data = client.recv(1024).decode("utf-8")
print(data, type(data), len(data))
# client.py
import socket, time

socket_client = socket.socket()
host = "127.0.0.1"
port = 9999
data = "222"
socket_client.connect((host, port))
print("开始发送")
socket_client.send(data.encode("utf-8"))
time.sleep(3)
socket_client.send(data.encode("utf-8"))
socket_client.recv(1024)

运行结果client报错
在这里插入图片描述

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值