import os
from huggingface_hub import login

# Authenticate with the Hugging Face Hub using the `hf_token` secret exposed as an environment variable.
login(os.environ['hf_token'])


from transformers import CLIPConfig, CLIPModel
from torch import nn
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

def load_distillclip(model_id, revision=None):
  """Load DistillCLIP weights into a standard transformers CLIPModel."""
  ckpt_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", revision=revision)
  config = CLIPConfig.from_pretrained(model_id)
  model = CLIPModel(config)
  # DistillCLIP's patch embedding has a bias term, unlike the stock CLIP vision tower.
  model.vision_model.embeddings.patch_embedding = nn.Conv2d(
      in_channels=model.config.vision_config.num_channels,
      out_channels=model.vision_model.embeddings.embed_dim,
      kernel_size=model.vision_model.embeddings.patch_size,
      stride=model.vision_model.embeddings.patch_size,
      bias=True,
  )
  # DistillCLIP drops the vision tower's pre-layernorm (the attribute is spelled `pre_layrnorm` in transformers).
  model.vision_model.pre_layrnorm = nn.Identity()
  # The checkpoint stores weights under a `student.` prefix; strip it before loading.
  print(model.load_state_dict({k.removeprefix('student.'): v for k, v in load_file(ckpt_path).items()}))
  return model
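
# A minimal usage sketch for load_distillclip (commented out so the demo is unaffected);
# passing `revision` pins the checkpoint to a specific branch, tag, or commit:
#
#   pinned_model = load_distillclip('Ramos-Ramos/distillclip', revision='main')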

import torch
from torch import nn
from einops import reduce
from tqdm.auto import tqdm

class ZeroShotCLIP(nn.Module):
  """Zero-shot classifier: one text-embedding prototype per class, compared against image embeddings."""

  def __init__(self, model=None, processor=None, classes=(), templates=()):
    super().__init__()

    self.model = model.eval()
    self.processor = processor
    self.classes = list(classes)
    self.templates = list(templates)
    self._init_weights()

  @torch.no_grad()
  def _init_weights(self):
    """Precompute a normalized text embedding per class, averaged over the prompt templates."""
    self.model.eval()
    device = next(self.model.parameters()).device  # keep prompt tensors on the model's device
    weights = []
    for classname in tqdm(self.classes):
      prompts = [template.format(classname) for template in self.templates]
      prompts = self.processor(text=prompts, truncation=True, padding=True, return_tensors='pt')
      embeddings = self.model.get_text_features(**{k: v.to(device) for k, v in prompts.items()}).cpu()
      embeddings /= embeddings.norm(dim=-1, keepdim=True)
      embeddings = reduce(embeddings, 'b d -> d', 'mean')  # average the per-template embeddings
      embeddings /= embeddings.norm()                      # re-normalize the averaged prototype
      weights.append(embeddings)
    weights = torch.stack(weights)
    self.register_buffer('weights', weights)

  @torch.no_grad()
  def forward(self, pixel_values):
    x = self.model.get_image_features(pixel_values=pixel_values)
    x /= x.norm(dim=-1, keepdim=True)
    # Scale cosine similarities by CLIP's learned logit scale (~100) before the softmax.
    return x.mm(self.weights.t()) * 100.00000762939453

  def preprocess_and_forward(self, x):
    x = self.processor(images=x, return_tensors='pt')
    return self(x['pixel_values'])


from transformers import CLIPProcessor

model = load_distillclip('Ramos-Ramos/distillclip')
processor = CLIPProcessor.from_pretrained('Ramos-Ramos/distillclip')
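
# Commented-out sketch: constructing the zero-shot head directly and inspecting the
# precomputed class prototypes (class names and templates here are illustrative):
#
#   clip = ZeroShotCLIP(model=model, processor=processor,
#                       classes=['cat', 'truck'], templates=['a photo of a {}.'])
#   print(clip.weights.shape)  # (num_classes, embed_dim): one averaged text embedding per class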


def infer(image, classes, templates):
  """Zero-shot classify `image` over comma-separated classes using semicolon-separated prompt templates."""
  classes = [label.strip() for label in classes.split(',')]
  print(classes)
  templates = [template.strip() for template in templates.split(';')]
  print(templates)
  clip = ZeroShotCLIP(model=model, processor=processor, classes=classes, templates=templates)
  preds = clip.preprocess_and_forward(image).softmax(dim=1).flatten()
  return {label: score.item() for label, score in zip(classes, preds)}
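
# Illustrative direct call to infer; the PIL image is a hypothetical local file, and the class
# and template strings follow the same comma/semicolon conventions as the Gradio fields:
#
#   from PIL import Image
#   scores = infer(Image.open('example.jpg'), 'cat, truck', 'a photo of a {}.; a blurry photo of a {}.')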


import gradio as gr

title = 'DistillCLIP'
description = 'Zero-shot image classification demo with DistillCLIP'
article = '''DistillCLIP is a distilled version of [CLIP ViT-B/32](https://huggingface.co/openai/clip-vit-base-patch32).

Please refer to the [DistillCLIP model card](https://huggingface.co/Ramos-Ramos/distillclip) for more details on DistillCLIP.

Note: Multiplying the logits by a temperature before the softmax better separates the final scores, so we scale DistillCLIP's image-text similarity scores by the teacher CLIP's temperature.'''
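
# Illustration (not model output) of why the logit scale matters: raw cosine similarities
# produce a nearly uniform softmax, while the ~100x scaled logits clearly separate the classes.
#
#   torch.tensor([0.30, 0.25]).softmax(dim=0)          # ≈ tensor([0.5125, 0.4875])
#   (torch.tensor([0.30, 0.25]) * 100).softmax(dim=0)  # ≈ tensor([0.9933, 0.0067])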

demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(label='Image', type='pil'),
        gr.Textbox(label='Classes', placeholder='cat, truck', info='Classes for classification. Separate classes with commas.'),
        gr.Textbox(label='Prompt(s)', placeholder='a photo of a {}.; a blurry photo of a {}.', info='Prompt templates. Use "{}" as a placeholder for the class. Separate prompts with semicolons.')
    ],
    outputs=gr.Label(label='Class scores'),
    title=title,
    description=description,
    article=article
)
demo.launch()