Spaces:
Running
on
A10G
Running
on
A10G
import numpy | |
import torch.nn as nn | |
from transformers import CLIPTokenizer, CLIPTextModel | |
import transformers | |
transformers.logging.set_verbosity_error() | |
""" | |
Loading CLIPTextModel will produce the following warning:
- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task
or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model
that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
According to the following issue, this warning is safe:
https://github.com/CompVis/stable-diffusion/issues/97
It is expected, since the vision backbone of the CLIP model is not needed to run Stable Diffusion.
You can safely ignore the warning; it is not an error.
This CLIP usage follows U-ViT and is the same as in Stable Diffusion.
""" | |
class AbstractEncoder(nn.Module):
    """Base ``nn.Module`` for encoders; subclasses are expected to override ``encode``."""

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        """Placeholder for the encoding entry point.

        Accepts any positional and keyword arguments and ignores them.

        Raises:
            NotImplementedError: Always; this base class provides no encoding.
        """
        raise NotImplementedError()
class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face).

    The tokenizer and text model are loaded from the ``tokenizer`` and
    ``text_encoder`` subfolders of ``sd_path``. All parameters are frozen
    (``requires_grad = False``) and the text model is put in eval mode.

    Args:
        sd_path: Checkpoint location passed to ``from_pretrained`` for both
            the tokenizer and the text model.
        device: Stored on the instance as ``self.device`` for callers that
            read it. Token ids are placed on the device that actually holds
            the text model's weights, so this value does not need to be kept
            in sync after ``.to(...)``.
        max_length: Every prompt is padded or truncated to this many tokens.
    """

    def __init__(self, sd_path, device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer", use_fast=False)
        self.transformer = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder")
        self.device = device
        self.max_length = max_length
        self.freeze()

    def freeze(self):
        """Switch the text model to eval mode and disable gradients for all parameters."""
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return the text model's ``last_hidden_state``.

        Args:
            text: Prompt or list of prompts accepted by the tokenizer.

        Returns:
            The ``last_hidden_state`` tensor produced by the CLIP text model.
        """
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        # Follow the weights rather than the constructor-time ``self.device``:
        # after ``module.to(other_device)`` (or on a CPU-only host with the
        # default device="cuda") the stored string no longer matches the
        # model, and the embedding lookup would fail with a device mismatch.
        weights_device = next(self.transformer.parameters()).device
        tokens = batch_encoding["input_ids"].to(weights_device)
        outputs = self.transformer(input_ids=tokens)
        return outputs.last_hidden_state

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)
class TextEmbedder(nn.Module):
    """
    Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance.

    Dropped prompts are replaced by the literal string ``"None"`` before
    being passed to the text encoder.

    Args:
        args: Forwarded unchanged as the first positional argument of
            ``FrozenCLIPEmbedder`` (its ``sd_path``).
        dropout_prob: Probability with which each prompt is dropped when
            ``forward`` is called with ``train=True``.
    """

    def __init__(self, args, dropout_prob=0.1):
        super().__init__()
        # The attribute name (spelling included) is kept as-is: renaming a
        # submodule changes the module's state-dict keys.
        self.text_encodder = FrozenCLIPEmbedder(args)
        self.dropout_prob = dropout_prob

    def token_drop(self, text_prompts, force_drop_ids=None):
        """
        Drops text to enable classifier-free guidance.

        Args:
            text_prompts: Sequence of prompt strings.
            force_drop_ids: Optional per-prompt mask; entries equal to 1 are
                dropped. May be a list/tuple, a NumPy array, or a tensor
                (on any device). When ``None``, each prompt is dropped
                independently with probability ``self.dropout_prob``.

        Returns:
            A list of strings in which every dropped prompt is ``"None"``.
        """
        if force_drop_ids is None:
            drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob
        else:
            # Tensors (possibly on GPU or tracked by autograd) are brought to
            # host memory first so NumPy can consume them.
            if hasattr(force_drop_ids, "detach"):
                force_drop_ids = force_drop_ids.detach().cpu().numpy()
            # Convert before comparing: ``[1, 0] == 1`` on a plain list is a
            # single ``False``, which would silently drop nothing.
            drop_ids = numpy.asarray(force_drop_ids) == 1
        # ``tolist`` yields plain ``str`` objects instead of NumPy scalars.
        return numpy.where(drop_ids, "None", text_prompts).tolist()

    def forward(self, text_prompts, train, force_drop_ids=None):
        """Embed ``text_prompts``, applying prompt dropout when requested.

        Args:
            text_prompts: Sequence of prompt strings.
            train: When truthy and ``dropout_prob > 0``, prompts are randomly dropped.
            force_drop_ids: Optional explicit drop mask; when given, dropout
                is applied regardless of ``train``.

        Returns:
            The output of the underlying text encoder for the (possibly
            modified) prompts.
        """
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            text_prompts = self.token_drop(text_prompts, force_drop_ids)
        return self.text_encodder(text_prompts)
if __name__ == '__main__':
    # Smoke test: embed the same prompts through TextEmbedder and through a
    # bare FrozenCLIPEmbedder and compare the results.
    #
    # Usage: python <this file> <sd_path>
    #   <sd_path> is the checkpoint location handed to both embedders.
    #
    # Reference example from the CLIPTextModel documentation:
    #   >>> from transformers import AutoTokenizer, CLIPTextModel
    #   >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
    #   >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    #   >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
    #   >>> outputs = model(**inputs)
    #   >>> last_hidden_state = outputs.last_hidden_state
    #   >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
    import sys

    import torch

    # Both embedders require a checkpoint path; fail with a usage message
    # instead of a TypeError from a missing constructor argument.
    if len(sys.argv) != 2:
        raise SystemExit(f"usage: python {sys.argv[0]} <sd_path>")
    sd_path = sys.argv[1]

    device = "cuda" if torch.cuda.is_available() else "cpu"

    text_encoder = TextEmbedder(sd_path, dropout_prob=0.00001).to(device)
    # TextEmbedder builds its encoder with the default device; keep the
    # encoder's recorded device in sync with where the module was moved.
    text_encoder.text_encodder.device = device
    text_encoder1 = FrozenCLIPEmbedder(sd_path, device=device).to(device)

    text_prompt = ["a photo of a cat", "a photo of a dog", 'a photo of a dog human']
    # Alternative input: text_prompt = ('None', 'None', 'None')

    output = text_encoder(text_prompts=text_prompt, train=True)
    output1 = text_encoder1(text_prompt)

    print(output.shape)
    print(output1.shape)
    # With train=True a prompt is still dropped with probability 1e-5, in
    # which case this comparison prints False.
    print((output == output1).all())