#!/usr/bin/env python
import importlib
import os
import re
import shlex
import sys
# Torch releases that have prebuilt wheels, written as the version string with
# the dots removed (the installer strips dots before comparing).
torch_supported = ['211', '212', '220', '221']
# CUDA build tags (the part of torch.__version__ after '+') that have wheels.
cuda_supported = ['cu118', 'cu121']
# CPython versions as '<major><minor>' that have wheels.
python_supported = ['39', '310', '311']
# Base URL of the project repository.
repo_url = 'https://github.com/chengzeyi/stable-fast'
# GitHub API endpoint describing the release tagged 'nightly'.
api_url = 'https://api.github.com/repos/chengzeyi/stable-fast/releases/tags/nightly'
# Path, relative to repo_url, under which the nightly release assets are served.
path_url = '/releases/download/nightly'
def install_pip(arg: str) -> bool:
    """Run ``pip install -U <arg>`` with the current interpreter.

    The command is executed without a shell, so shell metacharacters in
    *arg* (``;``, ``|``, ``$(...)`` ...) are passed to pip as literal text
    instead of being interpreted. *arg* may still contain several
    whitespace-separated pip arguments; quotes group words as usual.

    Args:
        arg: Package spec, URL or extra pip arguments to append to the command.

    Returns:
        True if pip exited with status 0, False otherwise (including when
        *arg* cannot be parsed, e.g. because of an unbalanced quote).
    """
    import subprocess

    # Human-readable form of the command, kept for the log line.
    cmd = f'"{sys.executable}" -m pip install -U {arg}'
    print(f'Running: {cmd}')
    if os.name == 'nt':
        # On Windows a command string with shell=False is handed to the
        # process directly (no cmd.exe), which keeps backslash paths intact.
        argv = cmd
    else:
        try:
            extra = shlex.split(arg)
        except ValueError as err:
            print(f'Invalid pip arguments: {err}')
            return False
        argv = [sys.executable, '-m', 'pip', 'install', '-U', *extra]
    result = subprocess.run(argv, shell=False, check=False, env=os.environ)
    return result.returncode == 0
def get_nightly():
    """Return the version string of the current nightly release, if any.

    Queries ``api_url`` and extracts the version from the first release asset
    whose file name contains ``-<version>+``.

    Returns:
        The version string, or None when the request fails, the response is
        not usable JSON, the release has no assets, or no asset name matches.
    """
    import requests

    failure = 'Failed to get nightly version'
    try:
        r = requests.get(api_url, timeout=10)
    except requests.RequestException:
        # Network problems must not abort the install; the caller treats
        # None as "version unknown".
        print(failure)
        return None
    if r.status_code != 200:
        print(failure)
        return None
    try:
        payload = r.json()
    except ValueError:
        print(failure)
        return None
    assets = payload.get('assets', []) if isinstance(payload, dict) else []
    pattern = r"-(.+?)\+"
    for asset in assets:
        name = asset.get('name') or ''
        match = re.search(pattern, name)
        if match:
            ver = match.group(1)
            print(f'Nightly version: {ver}')
            return ver
    print(failure)
    return None
def install_stable_fast():
    """Install stable-fast for the running Python / torch / CUDA combination.

    A prebuilt nightly wheel is used when the installed torch version and its
    CUDA build tag are both in the supported lists and the nightly version can
    be determined; otherwise the package is installed from the git repository.
    The outcome is reported on stdout.

    Raises:
        ValueError: If the running Python version or platform is unsupported.
    """
    import torch

    source_url = 'git+https://github.com/chengzeyi/stable-fast.git@main#egg=stable-fast'

    python_ver = f'{sys.version_info.major}{sys.version_info.minor}'
    if python_ver not in python_supported:
        raise ValueError(f'StableFast unsupported python: {python_ver} required {python_supported}')
    if sys.platform == 'linux':
        bin_url = 'manylinux2014_x86_64.whl'
    elif sys.platform == 'win32':
        bin_url = 'win_amd64.whl'
    else:
        raise ValueError(f'StableFast unsupported platform: {sys.platform}')

    # partition() tolerates versions without a '+<build>' suffix; cuda_ver is
    # then '' and falls through to the source install.
    torch_ver, _, cuda_ver = torch.__version__.partition('+')
    torch_ver = torch_ver.replace('.', '')

    url = source_url
    if torch_ver not in torch_supported:
        print(f'StableFast unsupported torch: {torch_ver} required {torch_supported}')
        print('Installing from source...')
    elif cuda_ver not in cuda_supported:
        print(f'StableFast unsupported CUDA: {cuda_ver} required {cuda_supported}')
        print('Installing from source...')
    else:
        # Only query the release API when a wheel can actually be used.
        sf_ver = get_nightly()
        if sf_ver is None:
            print('StableFast cannot determine version')
            print('Installing from source...')
        else:
            print('Installing wheel...')
            file_url = f'stable_fast-{sf_ver}+torch{torch_ver}{cuda_ver}-cp{python_ver}-cp{python_ver}-{bin_url}'
            # Join the parts with exactly one '/' each; path_url may carry
            # its own leading slash.
            base = repo_url.rstrip('/')
            path = path_url.strip('/')
            url = f'{base}/{path}/{file_url}'

    ok = install_pip(url)
    if ok:
        # The package was installed after interpreter start-up; refresh the
        # import system's caches so it can be found.
        importlib.invalidate_caches()
        import sfast
        print(f'StableFast installed: {sfast.__version__}')
    else:
        print('StableFast install failed')
# Run the installer only when executed as a script, not when imported.
if __name__ == '__main__':
    install_stable_fast()