大文件分片下载并发

一起来学python / 2023-05-06 / 原文

import requests
from concurrent.futures import ThreadPoolExecutor
import os
from pathlib import Path
from loguru import logger
import traceback
from tqdm import tqdm
from threading import Lock
from functools import partial

# url = "https://scontent.xx.fbcdn.net/m1/v/t6/An_YmP5OIPXun-vu3hkckAZZ2s4lPYoVkiyvCcWiVY21mu1Ng5_1HeCa2CWiSTsskj8HQ8bN013HxNpYDdSC_7jWQq_svcg.tar?ccb=10-5&oh=00_AfBn8XrMhiHu6w1KuS1X8rkuLzzZJnRs8B9jFMvVRfQnfg&oe=64659C28&_nc_sid=fb0754"
#
# tar_path = Path("/Users/chennan/Desktop/sa_000020.tar")
# fetching_path = Path(f"{tar_path.as_posix()}.fetch")
# Source URL and local destination of the file being downloaded.
url = "https://images.pexels.com/photos/15983035/pexels-photo-15983035.jpeg"
tar_path = Path("/Users/chennan/Desktop/pexels-photo-15983035.jpeg")
# Marker file ("<target>.fetch") created before the download starts and
# removed once it has finished successfully.
fetching_path = tar_path.with_name(f"{tar_path.name}.fetch")

# One tqdm progress bar per chunk: appended by the download workers and
# summed up by the main script after all workers have finished.
pbar_threads = []

# NOTE(review): neither name is referenced by the code visible in this file.
lock = Lock()
downloaded = 0

def insert_data(info):
    """Download one byte range of ``url`` and write it into ``tar_path`` in place.

    Args:
        info: A ``(headers, start, end)`` tuple. ``headers`` carries the HTTP
            ``Range`` header for this chunk, ``start`` is the file offset the
            chunk is written at, and ``end`` is the last byte offset requested.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        RuntimeError: If the server does not answer with 206 Partial Content,
            i.e. it ignored the ``Range`` header. Writing a full body at
            ``start`` would corrupt the output file.
    """
    headers, start, end = info
    # HTTP byte ranges are inclusive on both ends, hence the +1.
    # The context manager closes the bar even when the download fails.
    with tqdm(total=end - start + 1) as pbar_thread:
        # Registered so the caller can sum the per-chunk byte counts later.
        pbar_threads.append(pbar_thread)
        # The timeout keeps a stalled connection from blocking a worker forever.
        with requests.get(url, stream=True, headers=headers, timeout=30) as response:
            response.raise_for_status()
            if response.status_code != 206:
                raise RuntimeError(
                    f"expected 206 Partial Content for {headers['Range']}, "
                    f"got {response.status_code}"
                )
            with tar_path.open('rb+') as f:  # tar_path: where the file is saved
                # Start writing at this chunk's offset within the file.
                f.seek(start)
                for chunk in response.iter_content(chunk_size=64 * 1024):
                    if chunk:
                        f.write(chunk)
                        pbar_thread.update(len(chunk))


def get_file_length():
    """Return the size in bytes of the resource at ``url``.

    The size is read from the ``Content-Length`` response header; the body is
    not downloaded.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        KeyError: If the response carries no ``Content-Length`` header.
    """
    # stream=True defers the body download; the with-block then releases the
    # connection instead of leaving it open with an unread body.
    with requests.get(url, stream=True, timeout=30) as req:
        req.raise_for_status()
        return int(req.headers['content-Length'])


def fetch_one(content_length, all_thread=64):
    """Split ``content_length`` bytes into contiguous, non-overlapping ranges.

    Args:
        content_length: Total size of the file in bytes.
        all_thread: Maximum number of chunks to produce (default 64).

    Yields:
        ``(headers, start, end)`` tuples, where ``start`` and ``end`` are the
        inclusive first and last byte offsets of the chunk and ``headers`` is
        ``{'Range': 'bytes=<start>-<end>'}``. Nothing is yielded when
        ``content_length`` is not positive.
    """
    if content_length <= 0:
        return
    # Never produce more chunks than bytes: otherwise the chunk size would be
    # zero and the resulting ranges would be empty or inverted.
    chunk_count = max(1, min(all_thread, content_length))
    part = content_length // chunk_count  # bytes requested per chunk
    for i in range(chunk_count):
        start = part * i
        if i == chunk_count - 1:
            # The last chunk absorbs the remainder and stops at the final byte.
            end = content_length - 1
        else:
            end = start + part - 1
        # Tell the server which inclusive byte range to return.
        headers = {
            'Range': f'bytes={start}-{end}',
        }
        yield headers, start, end


# Script entry point: pre-allocate the target file, download all chunks
# concurrently, then report the total and remove the in-progress marker.
if __name__ == '__main__':
    content_length = get_file_length()
    if not tar_path.exists():
        # Pre-size the output file so every worker can seek to its own offset.
        # truncate() also copes with an empty file, unlike seek(length - 1).
        with tar_path.open('wb') as f:
            f.truncate(content_length)
    if not fetching_path.exists():
        # Marker file signalling that a download is in progress.
        with fetching_path.open("wb") as fs:
            fs.write(b"\0")

    try:
        with ThreadPoolExecutor(max_workers=16) as pool:
            # Executor.map only re-raises a worker's exception when its result
            # is retrieved; consume the iterator so a failed chunk reaches the
            # except-branch instead of being reported as a success.
            list(pool.map(insert_data, fetch_one(content_length)))
        with tqdm(total=content_length) as pbar:
            # Summarize the bytes recorded by the per-chunk progress bars.
            progress = sum(pbar_thread.n for pbar_thread in pbar_threads)
            pbar.update(progress)
        fetching_path.unlink()
        logger.info("下载完成")
    except Exception:
        logger.error(f"{traceback.format_exc()}")