from transformers import CLIPTextModelWithProjection, CLIPTokenizer
import torch
from safetensors.torch import load_file as load_safetensor
from diffusers import AutoencoderKL

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def load(tokenizer_path="tokenizer", text_encoder_path="text_encoder"):
    """Load the CLIP text encoder and tokenizer.

    Returns:
        tuple of (clip_model, clip_tokenizer)
    """
    safetensor_fp16 = f"./{text_encoder_path}/model.fp16.safetensors"  # or use model.safetensors
    config_path = f"./{text_encoder_path}/config.json"

    # Load tokenizer
    clip_tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)

    # Load CLIPTextModelWithProjection from the config file
    clip_model = CLIPTextModelWithProjection.from_pretrained(
        text_encoder_path,
        config=config_path,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )

    # Overwrite the weights with the fp16 safetensor checkpoint
    # (load_state_dict casts each tensor to the matching parameter's dtype)
    state_dict = load_safetensor(safetensor_fp16)
    clip_model.load_state_dict(state_dict)

    clip_model = clip_model.to(device)
    return clip_model, clip_tokenizer
def load_vae(vae_path='vae'):
    return AutoencoderKL.from_pretrained(vae_path)
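
# A minimal sketch of decoding latents back to images with the VAE above.
# It assumes a Stable-Diffusion-style AutoencoderKL: the scaling_factor
# (typically 0.18215) and the roughly [-1, 1] output range are assumptions
# about that family of checkpoints, not part of the original code.
def decode_latents(latents, vae):
    latents = latents / vae.config.scaling_factor  # undo SD latent scaling
    with torch.no_grad():
        images = vae.decode(latents).sample  # (batch, 3, H, W), roughly in [-1, 1]
    return (images / 2 + 0.5).clamp(0, 1)  # rescale to [0, 1] for viewing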
# Example function for processing prompts
def encode_prompt(prompt, tokenizer, clip_model):
    # Move the tokenized inputs to the same device as the model
    inputs = tokenizer(prompt, padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        return clip_model(**inputs).last_hidden_state
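
# Example usage (a sketch; assumes ./tokenizer, ./text_encoder and ./vae
# hold the files referenced above):
if __name__ == "__main__":
    clip_model, clip_tokenizer = load()
    embeddings = encode_prompt("a photo of a cat", clip_tokenizer, clip_model)
    print(embeddings.shape)  # (batch, seq_len, hidden_size); size depends on the checkpoint
    vae = load_vae().to(device)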