# HuggingSpaces / models_server.py
# khulnasoft — Update models_server.py
# commit 26ec8ac (verified)
import argparse
import logging
import random
import uuid
import numpy as np
from transformers import pipeline
from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image, export_to_video
from transformers import (
SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5ForSpeechToSpeech,
BlipProcessor, BlipForConditionalGeneration, TrOCRProcessor, VisionEncoderDecoderModel,
ViTImageProcessor, AutoTokenizer, AutoImageProcessor, TimesformerForVideoClassification,
MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation, DPTForDepthEstimation, DPTFeatureExtractor
)
from datasets import load_dataset
from PIL import Image
from torchvision import transforms
import torch
import torchaudio
from speechbrain.pretrained import WaveformEnhancement
import joblib
from huggingface_hub import hf_hub_url, cached_download
from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector, CannyDetector, MidasDetector
import warnings
import time
from espnet2.bin.tts_inference import Text2Speech
import soundfile as sf
from asteroid.models import BaseModel
import traceback
import os
import yaml
# Silence third-party warnings (transformers/diffusers emit many deprecation
# and FutureWarning messages during model loading).
warnings.filterwarnings("ignore")
def setup_logger():
    """Create and return the module logger, emitting INFO+ records to stderr.

    Returns:
        logging.Logger: logger named after this module, with a single
        StreamHandler using a timestamped format.

    Fix: ``logging.getLogger`` returns a per-name singleton, so calling this
    function more than once (e.g. on module re-import) previously attached a
    duplicate handler each time, causing every record to be printed multiple
    times. The handler is now attached only once.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    if not logger.handlers:  # attach the handler only on the first call
        handler = logging.StreamHandler()
        handler.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
# Shared module-level logger used throughout this file.
logger = setup_logger()
def load_config(config_path):
    """Load a YAML configuration file and return the parsed object.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The deserialized config (typically a dict).

    Fix: use ``yaml.safe_load`` instead of ``yaml.load(..., FullLoader)``.
    ``safe_load`` constructs only plain Python types, so a tampered or
    corrupted config file cannot instantiate arbitrary objects. Also open
    the file with an explicit UTF-8 encoding for platform independence.
    """
    with open(config_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)
def parse_args(argv=None):
    """Parse command-line arguments for the model server.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse falls back to ``sys.argv[1:]`` — existing
            callers of ``parse_args()`` are therefore unaffected.

    Returns:
        argparse.Namespace with a single ``config`` attribute: the path to
        the YAML config file (default ``"config.yaml"``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="config.yaml")
    return parser.parse_args(argv)
# --- Module-level configuration ---------------------------------------------
args = parse_args()

# When imported (e.g. by the Gradio front-end) instead of run directly,
# always use the Gradio-specific config regardless of CLI arguments.
if __name__ != "__main__":
    args.config = "config.gradio.yaml"

config = load_config(args.config)

# "huggingface" inference mode means nothing is deployed locally.
local_deployment = config["local_deployment"]
if config["inference_mode"] == "huggingface":
    local_deployment = "none"

# Optional HTTPS proxy for outbound requests; None disables proxying.
if config["proxy"]:
    PROXY = {"https": config["proxy"]}
else:
    PROXY = None

start = time.time()  # start timestamp for measuring model-loading duration

# Prefix prepended to model ids below; empty string means the ids are used
# as-is (i.e. resolved against the Hugging Face Hub).
local_models = ""
def load_pipes(local_deployment):
standard_pipes = {}
other_pipes = {}
controlnet_sd_pipes = {}
if local_deployment in ["full"]:
other_pipes = {
"damo-vilab/text-to-video-ms-1.7b": {
"model": DiffusionPipeline.from_pretrained(f"{local_models}damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"),
"device": "cuda:0"
},
"JorisCos/DCCRNet_Libri1Mix_enhsingle_16k": {
"model": BaseModel.from_pretrained("JorisCos/DCCRNet_Libri1Mix_enhsingle_16k"),
"device": "cuda:0"
},
"microsoft/speecht5_vc": {
"processor": SpeechT5Processor.from_pretrained(f"{local_models}microsoft/speecht5_vc"),
"model": SpeechT5ForSpeechToSpeech.from_pretrained(f"{local_models}microsoft/speecht5_vc"),
"vocoder": SpeechT5HifiGan.from_pretrained(f"{local_models}microsoft/speecht5_hifigan"),
"embeddings_dataset": load_dataset(f"{local_models}Matthijs/cmu-arctic-xvectors", split="validation"),
"device": "cuda:0"
},
"facebook/maskformer-swin-base-coco": {
"feature_extractor": MaskFormerFeatureExtractor.from_pretrained(f"{local_models}facebook/maskformer-swin-base-coco"),
"model": MaskFormerForInstanceSegmentation.from_pretrained(f"{local_models}facebook/maskformer-swin-base-coco"),
"device": "cuda:0"
},
"Intel/dpt-hybrid-midas": {
"model": DPTForDepthEstimation.from_pretrained(f"{local_models}Intel/dpt-hybrid-midas", low_cpu_mem_usage=True),
"feature_extractor": DPTFeatureExtractor.from_pretrained(f"{local_models}Intel/dpt-hybrid-midas"),
"device": "cuda:0"
}