"""Print the SHA-256 checksum of the model file referenced by a URI.

Usage: pass the URI as the single command-line argument.

The checksum is written to stdout; all diagnostics go to stderr.

Exit codes used by this script:
    1 -- missing argument, or a direct download failed with a non-404 status
    2 -- file not found (HTTP 404) or an error while querying Hugging Face Hub
    5 -- Hugging Face reports the repository as containing unsafe files
"""
import hashlib
import os
import sys

import requests
from huggingface_hub import hf_hub_download, get_paths_info

if len(sys.argv) < 2:
    # Fail with a readable message instead of a bare IndexError traceback.
    print('Error: missing argument; expected a model file URI', file=sys.stderr)
    sys.exit(1)

uri = sys.argv[1]
# Last path component of the URI; used as the file name both for the
# Hugging Face lookup and for the temporary local download.
file_name = uri.split('/')[-1]


# Function to parse the URI and determine download method
def parse_uri(uri):
    """Classify *uri* and extract what is needed to fetch it.

    Returns a ``(download_type, target)`` tuple:

    * ``('huggingface', repo_id)`` for ``huggingface://<repo_id>/<file>``
      URIs and for ``huggingface.co`` URLs containing ``/resolve/``.
    * ``('direct', uri)`` for everything else; the URI is returned unchanged.
    """
    if uri.startswith('huggingface://'):
        repo_id = uri.split('://')[1]
        # Drop the trailing file component, keeping only the repository id.
        return 'huggingface', repo_id.rsplit('/', 1)[0]
    elif 'huggingface.co' in uri:
        parts = uri.split('/resolve/')
        if len(parts) > 1:
            # Everything between the host and '/resolve/' is the repo id.
            repo_path = parts[0].split('https://huggingface.co/')[-1]
            return 'huggingface', repo_path
    return 'direct', uri


def calculate_sha256(file_path):
    """Return the hex SHA-256 digest of the file at *file_path*.

    The file is read in 4096-byte blocks so it is never held in memory
    as a whole.
    """
    sha256_hash = hashlib.sha256()
    with open(file_path, 'rb') as f:
        for byte_block in iter(lambda: f.read(4096), b''):
            sha256_hash.update(byte_block)
    return sha256_hash.hexdigest()


def manual_safety_check_hf(repo_id):
    """Query the Hugging Face scan endpoint for *repo_id*.

    Returns the decoded scan response when it contains a truthy
    ``hasUnsafeFile`` entry, otherwise ``None``.
    """
    scanResponse = requests.get('https://huggingface.co/api/models/' + repo_id + "/scan")
    scan = scanResponse.json()
    # Check if 'hasUnsafeFile' exists in the response
    if 'hasUnsafeFile' in scan and scan['hasUnsafeFile']:
        return scan
    return None


download_type, repo_id_or_url = parse_uri(uri)

new_checksum = None
file_path = None

# Decide download method based on URI type
if download_type == 'huggingface':
    # Check if the repo is flagged as dangerous by HF
    hazard = manual_safety_check_hf(repo_id_or_url)
    if hazard is not None:
        # Report on stderr like every other error so that stdout only ever
        # carries the checksum. (The previous `filename=` keyword is not
        # accepted by print() and raised TypeError before exit code 5 could
        # be returned.)
        print(f'Error: HuggingFace has detected security problems for {repo_id_or_url}: {str(hazard)}', file=sys.stderr)
        sys.exit(5)
    # Use HF API to pull sha
    for file in get_paths_info(repo_id_or_url, [file_name], repo_type='model'):
        try:
            new_checksum = file.lfs.sha256
            break
        except Exception as e:
            # NOTE(review): if `lfs` is None for non-LFS files (confirm
            # against huggingface_hub), this path exits with 2 instead of
            # falling back to the download below.
            print(f'Error from Hugging Face Hub: {str(e)}', file=sys.stderr)
            sys.exit(2)
    if new_checksum is None:
        # No checksum available from the API: download the file so it can
        # be hashed locally.
        try:
            file_path = hf_hub_download(repo_id=repo_id_or_url, filename=file_name)
        except Exception as e:
            print(f'Error from Hugging Face Hub: {str(e)}', file=sys.stderr)
            sys.exit(2)
else:
    response = requests.get(repo_id_or_url)
    if response.status_code == 200:
        with open(file_name, 'wb') as f:
            f.write(response.content)
        file_path = file_name
    elif response.status_code == 404:
        print(f'File not found: {response.status_code}', file=sys.stderr)
        sys.exit(2)
    else:
        print(f'Error downloading file: {response.status_code}', file=sys.stderr)
        sys.exit(1)

if new_checksum is None:
    # Hash the downloaded copy, then remove it.
    new_checksum = calculate_sha256(file_path)
    print(new_checksum)
    os.remove(file_path)
else:
    print(new_checksum)