bidiptas commited on
Commit
ee6345a
1 Parent(s): 1c7ae7d

Example generation scripts

Browse files
Files changed (3) hide show
  1. README.md +14 -3
  2. generate.py +140 -0
  3. test.py +38 -0
README.md CHANGED
@@ -22,6 +22,8 @@ This model is designed to be used with the LAVIS library. Please install [salesf
22
 
23
  After loading the model, you can disable the qformer text input to follow the same configuration we used for fine-tuning. However, the model still works well with it enabled, so we recommend users to experiment with both and choose the optimal configuration on a case-by-case basis.
24
 
 
 
25
  ```
26
  import torch
27
  from PIL import Image
@@ -32,13 +34,15 @@ from lavis.common.registry import registry
32
 
33
  import requests
34
 
 
 
35
  url = "https://iliad.stanford.edu/pg-vlm/example_images/ceramic_bowl.jpg"
36
  example_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
37
 
38
  vlm = load_model(
39
  name='blip2_t5_instruct',
40
  model_type='flant5xxl',
41
- checkpoint='pg-vlm/pgvlm_weights.bin', # replace with location of downloaded weights
42
  is_eval=True,
43
  device="cuda" if torch.cuda.is_available() else "cpu"
44
  )
@@ -56,6 +60,13 @@ question_samples = {
56
  'image': torch.stack([processor(example_image)], dim=0).to(vlm.device)
57
  }
58
 
59
- print(vlm.generate(question_samples, length_penalty=0, repetition_penalty=1, num_captions=3))
 
60
  # (['opaque', 'translucent', 'transparent'], tensor([-0.0448, -4.1387, -4.2793], device='cuda:0'))
61
- ```
 
 
 
 
 
 
 
22
 
23
  After loading the model, you can disable the qformer text input to follow the same configuration we used for fine-tuning. However, the model still works well with it enabled, so we recommend users to experiment with both and choose the optimal configuration on a case-by-case basis.
24
 
25
+ Review the generate.py and test.py scripts provided in the Files tab for an example of using PG-InstructBLIP to determine the transparency of an opaque bowl.
26
+
27
  ```
28
  import torch
29
  from PIL import Image
 
34
 
35
  import requests
36
 
37
+ from generate import generate
38
+
39
  url = "https://iliad.stanford.edu/pg-vlm/example_images/ceramic_bowl.jpg"
40
  example_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
41
 
42
  vlm = load_model(
43
  name='blip2_t5_instruct',
44
  model_type='flant5xxl',
45
+ checkpoint='pgvlm_weights.bin', # replace with location of downloaded weights
46
  is_eval=True,
47
  device="cuda" if torch.cuda.is_available() else "cpu"
48
  )
 
60
  'image': torch.stack([processor(example_image)], dim=0).to(vlm.device)
61
  }
62
 
63
+ answers, scores = generate(vlm, question_samples, length_penalty=0, repetition_penalty=1, num_captions=3)
64
+ print(answers, scores)
65
  # (['opaque', 'translucent', 'transparent'], tensor([-0.0448, -4.1387, -4.2793], device='cuda:0'))
66
+ ```
67
+
68
+ Note that the output of the generate function includes the log probabilities of each generation. For categorical properties (like material, transparency, and contents), these probabilities can be interpreted as confidences, as typical with VLMs. In the example above, PG-InstructBLIP is very confident that the ceramic bowl is opaque, which is true.
69
+
70
+ For continuous properties (like mass, fragility, and deformability), we recommend asking yes or no questions like "Is this object heavy?" and comparing the probabilities of the "yes" response between objects to determine which has a larger value.
71
+
72
+ For best results, we also recommend cropping input images to focus on the object in question, because PG-InstructBLIP is fine-tuned on object-centric data.
generate.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ @torch.no_grad()
5
+ def generate(
6
+ vlm,
7
+ samples,
8
+ use_nucleus_sampling=False,
9
+ num_beams=5,
10
+ max_length=256,
11
+ min_length=1,
12
+ top_p=0.9,
13
+ repetition_penalty=1.5,
14
+ length_penalty=1.0,
15
+ num_captions=1,
16
+ temperature=1,
17
+ ):
18
+ if "prompt" in samples.keys():
19
+ prompt = samples["prompt"]
20
+ else:
21
+ prompt = vlm.prompt
22
+
23
+ image = samples["image"]
24
+
25
+ bs = image.size(0)
26
+
27
+ if isinstance(prompt, str):
28
+ prompt = [prompt] * bs
29
+ else:
30
+ assert len(prompt) == bs, "The number of prompts must be equal to the batch size."
31
+
32
+ # For TextCaps
33
+ if "ocr_tokens" in samples.keys() and "{}" in prompt[0]:
34
+ prompt = [p.format(', '.join(samples['ocr_tokens'][i][:30])) for i, p in enumerate(prompt)]
35
+
36
+ query_tokens = vlm.query_tokens.expand(bs, -1, -1)
37
+ if vlm.qformer_text_input:
38
+ # remove ocr tokens in q_former (for eval textvqa)
39
+ # qformer_prompt = prompt
40
+ # qformer_prompt = ['Question: ' + qp.split(' Question: ')[1] for qp in qformer_prompt]
41
+
42
+ text_Qformer = vlm.tokenizer(
43
+ prompt,
44
+ padding='longest',
45
+ truncation=True,
46
+ max_length=vlm.max_txt_len,
47
+ return_tensors="pt",
48
+ ).to(image.device)
49
+ query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
50
+ Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1)
51
+
52
+ # For video data
53
+ if image.dim() == 5:
54
+ inputs_t5, atts_t5 = [], []
55
+ for j in range(image.size(2)):
56
+ this_frame = image[:,:,j,:,:]
57
+ with vlm.maybe_autocast():
58
+ frame_embeds = vlm.ln_vision(vlm.visual_encoder(this_frame))
59
+ frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
60
+
61
+ if vlm.qformer_text_input:
62
+ frame_query_output = vlm.Qformer.bert(
63
+ text_Qformer.input_ids,
64
+ attention_mask = Qformer_atts,
65
+ query_embeds=query_tokens,
66
+ encoder_hidden_states=frame_embeds,
67
+ encoder_attention_mask=frame_atts,
68
+ return_dict=True,
69
+ )
70
+ else:
71
+ frame_query_output = vlm.Qformer.bert(
72
+ query_embeds=query_tokens,
73
+ encoder_hidden_states=frame_embeds,
74
+ encoder_attention_mask=frame_atts,
75
+ return_dict=True,
76
+ )
77
+
78
+ frame_inputs_t5 = vlm.t5_proj(frame_query_output.last_hidden_state[:,:query_tokens.size(1),:])
79
+ frame_atts_t5 = torch.ones(frame_inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
80
+ inputs_t5.append(frame_inputs_t5)
81
+ atts_t5.append(frame_atts_t5)
82
+ inputs_t5 = torch.cat(inputs_t5, dim=1)
83
+ atts_t5 = torch.cat(atts_t5, dim=1)
84
+ else:
85
+ with vlm.maybe_autocast():
86
+ image_embeds = vlm.ln_vision(vlm.visual_encoder(image))
87
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
88
+
89
+ if vlm.qformer_text_input:
90
+ query_output = vlm.Qformer.bert(
91
+ text_Qformer.input_ids,
92
+ attention_mask=Qformer_atts,
93
+ query_embeds=query_tokens,
94
+ encoder_hidden_states=image_embeds,
95
+ encoder_attention_mask=image_atts,
96
+ return_dict=True,
97
+ )
98
+ else:
99
+ query_output = vlm.Qformer.bert(
100
+ query_embeds=query_tokens,
101
+ encoder_hidden_states=image_embeds,
102
+ encoder_attention_mask=image_atts,
103
+ return_dict=True,
104
+ )
105
+
106
+ inputs_t5 = vlm.t5_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
107
+ atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
108
+
109
+ input_tokens = vlm.t5_tokenizer(
110
+ prompt,
111
+ padding="longest",
112
+ return_tensors="pt"
113
+ ).to(image.device)
114
+
115
+ encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
116
+
117
+ with vlm.maybe_autocast(dtype=torch.bfloat16):
118
+ inputs_embeds = vlm.t5_model.encoder.embed_tokens(input_tokens.input_ids)
119
+ inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
120
+
121
+ outputs = vlm.t5_model.generate(
122
+ return_dict_in_generate=True,
123
+ output_scores=True,
124
+ inputs_embeds=inputs_embeds,
125
+ attention_mask=encoder_atts,
126
+ do_sample=use_nucleus_sampling,
127
+ top_p=top_p,
128
+ temperature=temperature,
129
+ num_beams=num_beams,
130
+ max_new_tokens=max_length,
131
+ min_length=min_length,
132
+ repetition_penalty=repetition_penalty,
133
+ length_penalty=length_penalty,
134
+ num_return_sequences=num_captions,
135
+ )
136
+ output_text = vlm.t5_tokenizer.batch_decode(
137
+ outputs.sequences, skip_special_tokens=True
138
+ )
139
+
140
+ return output_text, outputs.sequences_scores
test.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from omegaconf import OmegaConf
4
+
5
+ from lavis.models import load_model, load_preprocess
6
+ from lavis.common.registry import registry
7
+
8
+ import requests
9
+
10
+ from generate import generate
11
+
12
+ url = "https://iliad.stanford.edu/pg-vlm/example_images/ceramic_bowl.jpg"
13
+ example_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
14
+
15
+ vlm = load_model(
16
+ name='blip2_t5_instruct',
17
+ model_type='flant5xxl',
18
+ checkpoint='pgvlm_weights.bin', # replace with location of downloaded weights
19
+ is_eval=True,
20
+ device="cuda" if torch.cuda.is_available() else "cpu"
21
+ )
22
+
23
+ vlm.qformer_text_input = False # Optionally disable qformer text
24
+
25
+ model_cls = registry.get_model_class('blip2_t5_instruct')
26
+ model_type = 'flant5xxl'
27
+ preprocess_cfg = OmegaConf.load(model_cls.default_config_path(model_type)).preprocess
28
+ vis_processors, _ = load_preprocess(preprocess_cfg)
29
+ processor = vis_processors["eval"]
30
+
31
+ question_samples = {
32
+ 'prompt': 'Question: Classify this object as transparent, translucent, or opaque? Respond unknown if you are not sure. Short answer:',
33
+ 'image': torch.stack([processor(example_image)], dim=0).to(vlm.device)
34
+ }
35
+
36
+ answers, scores = generate(vlm, question_samples, length_penalty=0, repetition_penalty=1, num_captions=3)
37
+ print(answers, scores)
38
+ # (['opaque', 'translucent', 'transparent'], tensor([-0.0448, -4.1387, -4.2793], device='cuda:0'))