test / cli /install-sf.py
bilegentile's picture
Upload folder using huggingface_hub
c19ca42 verified
#!/usr/bin/env python
import os
import re
import sys
# Build tags accepted for a prebuilt wheel. Each list is matched against a
# string derived at install time from the running interpreter / torch build.
torch_supported = [
    '211',
    '212',
    '220',
    '221',
]
cuda_supported = [
    'cu118',
    'cu121',
]
python_supported = [
    '39',
    '310',
    '311',
]

# Upstream project page, the API endpoint describing its nightly release, and
# the path under the project page where nightly assets are downloaded from.
repo_url = 'https://github.com/chengzeyi/stable-fast'
api_url = 'https://api.github.com/repos/chengzeyi/stable-fast/releases/tags/nightly'
path_url = '/releases/download/nightly'
def install_pip(arg: str):
    """Run ``pip install -U`` for *arg* with the current interpreter.

    The command is executed without a shell: ``arg`` is tokenized and passed
    to pip as an argument list. Shell metacharacters inside ``arg`` are
    therefore handed to pip verbatim instead of being interpreted, so they
    can neither run extra commands nor mask a failing pip exit code.

    Args:
        arg: Package specifier, URL, or several whitespace-separated pip
            arguments (quotes may be used to group an argument).

    Returns:
        True if pip exited with return code 0, False otherwise (including
        when ``arg`` cannot be tokenized, e.g. an unbalanced quote).
    """
    import shlex
    import subprocess

    try:
        # Non-POSIX tokenizing on Windows keeps backslashes in paths intact.
        pip_args = shlex.split(arg, posix=(os.name != 'nt'))
    except ValueError as err:
        print(f'Invalid pip arguments: {err}')
        return False
    cmd = [sys.executable, '-m', 'pip', 'install', '-U', *pip_args]
    print(f'Running: "{sys.executable}" -m pip install -U {arg}')
    result = subprocess.run(cmd, check=False, env=os.environ)
    return result.returncode == 0
def get_nightly():
    """Look up the version string of the current nightly release.

    Requests ``api_url`` and extracts the version from the name of the first
    listed asset: the text between the first ``-`` and the next ``+``.

    Returns:
        The version string, or None when the request fails, the response is
        not usable, or the asset name does not match the expected pattern.
        A message is printed in either case.
    """
    import requests

    def fail():
        print('Failed to get nightly version')
        return None

    try:
        r = requests.get(api_url, timeout=10)
        if r.status_code != 200:
            return fail()
        release = r.json()
    except (requests.RequestException, ValueError):
        # Connection problems, timeouts and malformed JSON are reported as
        # "no version" so the caller can fall back instead of crashing.
        return fail()

    assets = release.get('assets') if isinstance(release, dict) else None
    if not assets:
        return fail()
    first = assets[0]
    asset = (first.get('name') or '') if isinstance(first, dict) else ''
    match = re.search(r"-(.+?)\+", asset)
    if match is None:
        return fail()
    ver = match.group(1)
    print(f'Nightly version: {ver}')
    return ver
def install_stable_fast():
    """Install stable-fast for the running Python / torch / CUDA combination.

    A prebuilt nightly wheel is installed when the torch version and CUDA tag
    are in the supported lists and the nightly version can be determined;
    otherwise the package is installed from the git source. The outcome is
    printed.

    Raises:
        ValueError: If the Python version or the platform is not supported.
    """
    import torch

    python_ver = f'{sys.version_info.major}{sys.version_info.minor}'
    if python_ver not in python_supported:
        raise ValueError(f'StableFast unsupported python: {python_ver} required {python_supported}')
    if sys.platform == 'linux':
        bin_url = 'manylinux2014_x86_64.whl'
    elif sys.platform == 'win32':
        bin_url = 'win_amd64.whl'
    else:
        raise ValueError(f'StableFast unsupported platform: {sys.platform}')

    # partition() instead of split(): a torch version without a "+<tag>"
    # suffix yields an empty CUDA tag and falls through to the source install
    # rather than failing on tuple unpacking.
    torch_ver, _, cuda_ver = torch.__version__.partition('+')
    torch_ver = torch_ver.replace('.', '')

    source_reason = None
    sf_ver = None
    if torch_ver not in torch_supported:
        source_reason = f'StableFast unsupported torch: {torch_ver} required {torch_supported}'
    elif cuda_ver not in cuda_supported:
        source_reason = f'StableFast unsupported CUDA: {cuda_ver} required {cuda_supported}'
    else:
        # Only query the release API when a wheel could actually be used.
        sf_ver = get_nightly()
        if sf_ver is None:
            source_reason = 'StableFast cannot determine version'

    if source_reason is not None:
        print(source_reason)
        print('Installing from source...')
        url = 'git+https://github.com/chengzeyi/stable-fast.git@main#egg=stable-fast'
    else:
        print('Installing wheel...')
        file_url = f'stable_fast-{sf_ver}+torch{torch_ver}{cuda_ver}-cp{python_ver}-cp{python_ver}-{bin_url}'
        # Strip slashes at the seams so the joined URL has no empty "//"
        # path segment.
        url = f"{repo_url.rstrip('/')}/{path_url.strip('/')}/{file_url}"

    ok = install_pip(url)
    if ok:
        import sfast
        print(f'StableFast installed: {sfast.__version__}')
    else:
        print('StableFast install failed')
# Run the installer only when this file is executed as a script.
if __name__ == "__main__":
    install_stable_fast()