|
import os |
|
import requests |
|
from tqdm import tqdm |
|
from modelscope import snapshot_download |
|
from urllib.parse import urlparse |
|
|
|
# Fetch the "MuGeminorum/MiVOLO" model snapshot from ModelScope when this
# module is imported. MODEL_DIR holds whatever snapshot_download returns
# (presumably the local directory of the downloaded snapshot).
# NOTE(review): "./mivolo/__pycache__" is a relative path, so the cache
# location depends on the process's current working directory rather than
# on this file's location — confirm that is intended.
MODEL_DIR = snapshot_download(
    "MuGeminorum/MiVOLO",
    cache_dir="./mivolo/__pycache__",
)
|
|
|
|
|
def is_url(s: str) -> bool:
    """Return True if *s* is an absolute URL, i.e. it has both a scheme and a host.

    ``"https://example.com/a.jpg"`` qualifies. A bare host such as
    ``"example.com"`` or a local path such as ``"/tmp/a.jpg"`` does not.
    Inputs that ``urlparse`` rejects (a malformed IPv6 host, non-string
    objects) yield False instead of raising.
    """
    try:
        result = urlparse(s)
    except (ValueError, TypeError, AttributeError):
        # ValueError: e.g. "http://[::1" (unbalanced IPv6 bracket).
        # TypeError / AttributeError: non-string inputs such as ints.
        # Only these are caught, so Ctrl-C and real bugs are not swallowed.
        return False
    return bool(result.scheme and result.netloc)
|
|
|
|
|
def download_file(
    url: str,
    save_path: str,
    max_retries: int = 3,
    timeout: float = 60,
):
    """Download *url* to *save_path*, showing a tqdm progress bar.

    If *save_path* already exists, nothing is downloaded. Otherwise the
    payload is streamed into ``<save_path>.part``, which is renamed to
    *save_path* only after a complete transfer. An interrupted or failed
    download therefore never leaves a truncated file that a later call would
    mistake for a finished one.

    Args:
        url: HTTP(S) URL to fetch.
        save_path: Destination file path. Missing parent directories are
            created.
        max_retries: Extra attempts made when the number of bytes received
            does not match the server's ``Content-Length``.
        timeout: Connect/read timeout in seconds, passed to ``requests.get``.

    Returns:
        *save_path*, including when the file already existed.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.RequestException: On connection errors or timeouts.
        RuntimeError: If every attempt produced an incomplete download.
    """
    if os.path.exists(save_path):
        print("目标已存在,无需下载")
        return save_path

    directory = os.path.dirname(save_path)
    if directory:
        # A bare file name has no parent to create (os.makedirs("") raises).
        create_dir(directory)

    # Write to a side file and rename on success, so the existence check
    # above only ever sees complete downloads.
    tmp_path = save_path + ".part"
    max_retries = max(0, max_retries)

    for attempt in range(max_retries + 1):
        try:
            with requests.get(url, stream=True, timeout=timeout) as response:
                # Without this, an HTTP error page would be saved as the file.
                response.raise_for_status()
                total_size = int(response.headers.get("content-length", 0))
                received = 0
                with open(tmp_path, "wb") as file, tqdm(
                    total=total_size, unit="B", unit_scale=True
                ) as progress_bar:
                    for data in response.iter_content(chunk_size=1024):
                        file.write(data)
                        received += len(data)
                        progress_bar.update(len(data))
        except BaseException:
            # Remove the partial side file, then propagate the original error.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

        # A missing Content-Length (0) cannot be verified, so accept it as-is.
        # NOTE(review): iter_content() yields decoded bytes, so a response
        # served with a compressing Content-Encoding could differ from
        # Content-Length and be retried. Confirm the hosts used here serve
        # these files uncompressed.
        if total_size == 0 or received == total_size:
            os.replace(tmp_path, save_path)
            print("下载完成")
            return save_path

        os.remove(tmp_path)
        if attempt < max_retries:
            print("下载失败,重试中...")

    raise RuntimeError(
        f"Incomplete download of {url!r} after {max_retries + 1} attempt(s)"
    )
|
|
|
|
|
def create_dir(dir_path: str):
    """Create *dir_path*, including any missing parents, unless it already exists.

    An empty path means "the current directory" and is a no-op;
    ``os.makedirs("")`` would otherwise raise FileNotFoundError.

    Raises:
        FileExistsError: If *dir_path* exists but is not a directory.
    """
    if not dir_path:
        return
    # exist_ok=True replaces a separate exists() check, avoiding the race
    # where another process creates the directory between check and create.
    os.makedirs(dir_path, exist_ok=True)
|
|
|
|
|
def get_jpg_files(folder_path: str, extensions=(".jpg",)):
    """Return the paths of matching image files directly inside *folder_path*.

    Suffix matching is case-insensitive, so ``a.JPG`` is included.
    Sub-directories are not searched. Entries that are not regular files
    (e.g. a directory named ``x.jpg``) are skipped. The result is sorted, so
    the order does not depend on what ``os.listdir`` happens to return.

    Args:
        folder_path: Directory to scan.
        extensions: A suffix or tuple of suffixes to accept. Defaults to
            ``(".jpg",)``; pass e.g. ``(".jpg", ".jpeg")`` to also pick up
            ``.jpeg`` files.

    Returns:
        Sorted list of ``os.path.join(folder_path, name)`` paths.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    # str.endswith accepts a tuple; lower-case once so matching ignores case.
    suffixes = tuple(ext.lower() for ext in extensions)
    candidates = (
        os.path.join(folder_path, name)
        for name in os.listdir(folder_path)
        if name.lower().endswith(suffixes)
    )
    return sorted(path for path in candidates if os.path.isfile(path))
|
|