import gc
import subprocess

import gradio as gr
import spaces  # provides the @spaces.GPU decorator on ZeroGPU Spaces
import torch
from diffusers import DiffusionPipeline

# Free pip's download cache to reclaim disk space on the Space.
subprocess.run('pip cache purge', shell=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
torch.set_grad_enabled(False)
|
|
|
|
|
|
# Base models selectable in the UI: official FLUX.1 checkpoints plus
# community merges and fp8 conversions.
models = [
    "camenduru/FLUX.1-dev-diffusers",
    "black-forest-labs/FLUX.1-schnell",
    "sayakpaul/FLUX.1-merged",
    "John6666/blue-pencil-flux1-v001-fp8-flux",
    "John6666/fluxunchained-artfulnsfw-fut516xfp8e4m3fnv11-fp8-flux",
    "John6666/nepotism-fuxdevschnell-v3aio-flux",
]

def clear_cache():
    """Release cached CUDA memory and collect garbage before swapping models."""
    torch.cuda.empty_cache()
    gc.collect()

def get_repo_safetensors(repo_id: str):
    """List the .safetensors files in a repo for a weight-file dropdown."""
    from huggingface_hub import HfApi
    api = HfApi()
    try:
        if " " in repo_id or not api.repo_exists(repo_id):
            return gr.update(value="", choices=[])
        files = api.list_repo_files(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to get {repo_id}'s info.")
        print(e)
        return gr.update(choices=[])
    files = [f for f in files if f.endswith(".safetensors")]
    if len(files) == 0:
        return gr.update(value="", choices=[])
    return gr.update(value=files[0], choices=files)
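
# Usage note (a sketch, not in the original code): the gr.update() payloads
# above carry `value` and `choices`, so this function is meant to be wired as
# an event handler whose output is a gr.Dropdown. With hypothetical component
# names:
#   repo_id_box.change(get_repo_safetensors, inputs=repo_id_box, outputs=weights_dd)
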
def change_base_model(repo_id: str):
    """Replace the global pipeline with the selected base model."""
    from huggingface_hub import HfApi
    global pipe
    api = HfApi()
    try:
        if " " in repo_id or not api.repo_exists(repo_id):
            return
        clear_cache()
        pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
    except Exception as e:
        print(e)
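
# ----------------------------------------------------------------------
# Minimal end-to-end sketch (an assumption, not part of the original
# section): the `spaces` import suggests a ZeroGPU Space, where @spaces.GPU
# requests a GPU for the duration of the decorated call. `infer`, `demo`,
# and all component names below are hypothetical.
# ----------------------------------------------------------------------
@spaces.GPU
def infer(prompt: str, seed: int = 0, steps: int = 4):
    if pipe is None:
        raise gr.Error("No model loaded. Select a base model first.")
    pipe.to(device)  # assumption: move weights to the GPU inside the GPU-scoped call
    generator = torch.Generator(device).manual_seed(seed)
    return pipe(prompt, num_inference_steps=steps, generator=generator).images[0]


with gr.Blocks() as demo:
    model_dd = gr.Dropdown(label="Base model", choices=models, value=models[0])
    prompt_box = gr.Textbox(label="Prompt")
    run_btn = gr.Button("Generate")
    result = gr.Image(label="Result")

    model_dd.change(change_base_model, inputs=model_dd, outputs=None)
    run_btn.click(infer, inputs=prompt_box, outputs=result)

demo.launch()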