from transformers import CLIPTextModelWithProjection, CLIPTokenizer
import torch
from safetensors.torch import load_file as load_safetensor
from diffusers import AutoencoderKL
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load(tokenizer_path="tokenizer", text_encoder_path="text_encoder"):
    """Load the CLIP text encoder and tokenizer.

    Returns:
        tuple of (clip_model, clip_tokenizer)
    """
    safetensor_fp16 = f"./{text_encoder_path}/model.fp16.safetensors"  # or use model.safetensors
    config_path = f"./{text_encoder_path}/config.json"

    # Load tokenizer
    clip_tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)

    # Load CLIPTextModelWithProjection from the config file: fp16 on GPU, fp32 on CPU
    clip_model = CLIPTextModelWithProjection.from_pretrained(
        text_encoder_path,
        config=config_path,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )

    # Overwrite the weights with the fp16 safetensor checkpoint
    state_dict = load_safetensor(safetensor_fp16)
    clip_model.load_state_dict(state_dict)

    clip_model = clip_model.to(device)
    return clip_model, clip_tokenizer

def load_vae(vae_path='vae'):
    """Load the VAE and keep it on the same device as the text encoder."""
    return AutoencoderKL.from_pretrained(vae_path).to(device)

# Example function for processing prompts
def encode_prompt(prompt, tokenizer, clip_model):
    """Tokenize a prompt and return the encoder's last hidden state."""
    # Move the token ids to the same device as the model to avoid a
    # CPU/CUDA mismatch when running on GPU.
    inputs = tokenizer(prompt, padding=True, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        return clip_model(**inputs).last_hidden_state
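
# Minimal usage sketch: assumes the default local "tokenizer", "text_encoder"
# and "vae" directories hold a compatible CLIP/VAE checkpoint (e.g. exported
# from a Stable Diffusion pipeline); adjust the paths to your setup.
if __name__ == "__main__":
    clip_model, clip_tokenizer = load()
    vae = load_vae()
    prompt_embeds = encode_prompt("a photo of a cat", clip_tokenizer, clip_model)
    print(prompt_embeds.shape)  # (batch_size, sequence_length, hidden_size)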