Example generation scripts
- README.md (+14, -3)
- generate.py (+140, -0)
- test.py (+38, -0)
README.md (CHANGED)

````diff
@@ -22,6 +22,8 @@ This model is designed to be used with the LAVIS library. Please install [salesf
 
 After loading the model, you can disable the qformer text input to follow the same configuration we used for fine-tuning. However, the model still works well with it enabled, so we recommend users to experiment with both and choose the optimal configuration on a case-by-case basis.
 
+Review the generate.py and test.py scripts provided in the Files tab for an example of using PG-InstructBLIP to determine the transparency of an opaque bowl.
+
 ```
 import torch
 from PIL import Image
@@ -32,13 +34,15 @@ from lavis.common.registry import registry
 
 import requests
 
+from generate import generate
+
 url = "https://iliad.stanford.edu/pg-vlm/example_images/ceramic_bowl.jpg"
 example_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
 
 vlm = load_model(
     name='blip2_t5_instruct',
     model_type='flant5xxl',
-    checkpoint='
+    checkpoint='pgvlm_weights.bin', # replace with location of downloaded weights
     is_eval=True,
     device="cuda" if torch.cuda.is_available() else "cpu"
 )
@@ -56,6 +60,13 @@ question_samples = {
     'image': torch.stack([processor(example_image)], dim=0).to(vlm.device)
 }
 
-
+answers, scores = generate(vlm, question_samples, length_penalty=0, repetition_penalty=1, num_captions=3)
+print(answers, scores)
 # (['opaque', 'translucent', 'transparent'], tensor([-0.0448, -4.1387, -4.2793], device='cuda:0'))
-```
+```
+
+Note that the output of the generate function includes the log probabilities of each generation. For categorical properties (like material, transparency, and contents), these probabilities can be interpreted as confidences, as typical with VLMs. In the example above, PG-InstructBLIP is very confident that the ceramic bowl is opaque, which is true.
+
+For continuous properties (like mass, fragility, and deformability), we recommend asking yes or no questions like "Is this object heavy?" and comparing the probabilities of the "yes" response between objects to determine which has a larger value.
+
+For best results, we also recommend cropping input images to focus on the object in question, because PG-InstructBLIP is fine-tuned on object-centric data.
````
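The README note above treats the returned scores as log probabilities that can be read as confidences. The snippet below is a minimal, illustrative sketch of that reading, not something shipped with this commit: it reuses the scores printed in the example output and normalizes them over the three returned candidates, which assumes those candidates carry essentially all of the probability mass.

```python
import torch

# Values copied from the example output above; answers[i] corresponds to scores[i].
answers = ['opaque', 'translucent', 'transparent']
scores = torch.tensor([-0.0448, -4.1387, -4.2793])  # per-sequence log probabilities (length_penalty=0)

# Softmax over the returned candidates = exponentiate each log probability and renormalize.
confidences = torch.softmax(scores, dim=0)
for answer, confidence in zip(answers, confidences.tolist()):
    print(f"{answer}: {confidence:.3f}")
# Prints roughly: opaque: 0.970, translucent: 0.016, transparent: 0.014 --
# consistent with the README's claim that the model is confident the bowl is opaque.
```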
generate.py (ADDED, +140 lines)
```python
import torch


@torch.no_grad()
def generate(
    vlm,
    samples,
    use_nucleus_sampling=False,
    num_beams=5,
    max_length=256,
    min_length=1,
    top_p=0.9,
    repetition_penalty=1.5,
    length_penalty=1.0,
    num_captions=1,
    temperature=1,
):
    """Run PG-InstructBLIP generation and return (decoded answers, per-sequence scores).

    Mirrors the LAVIS InstructBLIP (BLIP-2 T5) generation path, but also returns
    `outputs.sequences_scores` so callers can inspect the log probability of each answer.
    """
    if "prompt" in samples.keys():
        prompt = samples["prompt"]
    else:
        prompt = vlm.prompt

    image = samples["image"]

    bs = image.size(0)

    if isinstance(prompt, str):
        prompt = [prompt] * bs
    else:
        assert len(prompt) == bs, "The number of prompts must be equal to the batch size."

    # For TextCaps
    if "ocr_tokens" in samples.keys() and "{}" in prompt[0]:
        prompt = [p.format(', '.join(samples['ocr_tokens'][i][:30])) for i, p in enumerate(prompt)]

    query_tokens = vlm.query_tokens.expand(bs, -1, -1)
    if vlm.qformer_text_input:
        # remove ocr tokens in q_former (for eval textvqa)
        # qformer_prompt = prompt
        # qformer_prompt = ['Question: ' + qp.split(' Question: ')[1] for qp in qformer_prompt]

        text_Qformer = vlm.tokenizer(
            prompt,
            padding='longest',
            truncation=True,
            max_length=vlm.max_txt_len,
            return_tensors="pt",
        ).to(image.device)
        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
        Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)

    # For video data
    if image.dim() == 5:
        inputs_t5, atts_t5 = [], []
        for j in range(image.size(2)):
            this_frame = image[:, :, j, :, :]
            with vlm.maybe_autocast():
                frame_embeds = vlm.ln_vision(vlm.visual_encoder(this_frame))
            frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)

            if vlm.qformer_text_input:
                frame_query_output = vlm.Qformer.bert(
                    text_Qformer.input_ids,
                    attention_mask=Qformer_atts,
                    query_embeds=query_tokens,
                    encoder_hidden_states=frame_embeds,
                    encoder_attention_mask=frame_atts,
                    return_dict=True,
                )
            else:
                frame_query_output = vlm.Qformer.bert(
                    query_embeds=query_tokens,
                    encoder_hidden_states=frame_embeds,
                    encoder_attention_mask=frame_atts,
                    return_dict=True,
                )

            frame_inputs_t5 = vlm.t5_proj(frame_query_output.last_hidden_state[:, :query_tokens.size(1), :])
            frame_atts_t5 = torch.ones(frame_inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
            inputs_t5.append(frame_inputs_t5)
            atts_t5.append(frame_atts_t5)
        inputs_t5 = torch.cat(inputs_t5, dim=1)
        atts_t5 = torch.cat(atts_t5, dim=1)
    else:
        with vlm.maybe_autocast():
            image_embeds = vlm.ln_vision(vlm.visual_encoder(image))
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        if vlm.qformer_text_input:
            query_output = vlm.Qformer.bert(
                text_Qformer.input_ids,
                attention_mask=Qformer_atts,
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
        else:
            query_output = vlm.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

        inputs_t5 = vlm.t5_proj(query_output.last_hidden_state[:, :query_tokens.size(1), :])
        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)

    input_tokens = vlm.t5_tokenizer(
        prompt,
        padding="longest",
        return_tensors="pt"
    ).to(image.device)

    encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)

    with vlm.maybe_autocast(dtype=torch.bfloat16):
        # Prepend the projected Q-Former query embeddings to the embedded prompt tokens
        # and decode with the T5 language model.
        inputs_embeds = vlm.t5_model.encoder.embed_tokens(input_tokens.input_ids)
        inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)

        outputs = vlm.t5_model.generate(
            return_dict_in_generate=True,
            output_scores=True,
            inputs_embeds=inputs_embeds,
            attention_mask=encoder_atts,
            do_sample=use_nucleus_sampling,
            top_p=top_p,
            temperature=temperature,
            num_beams=num_beams,
            max_new_tokens=max_length,
            min_length=min_length,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            num_return_sequences=num_captions,
        )
        output_text = vlm.t5_tokenizer.batch_decode(
            outputs.sequences, skip_special_tokens=True
        )

    return output_text, outputs.sequences_scores
```
test.py (ADDED, +38 lines)
```python
import torch
from PIL import Image
from omegaconf import OmegaConf

from lavis.models import load_model, load_preprocess
from lavis.common.registry import registry

import requests

from generate import generate

url = "https://iliad.stanford.edu/pg-vlm/example_images/ceramic_bowl.jpg"
example_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

vlm = load_model(
    name='blip2_t5_instruct',
    model_type='flant5xxl',
    checkpoint='pgvlm_weights.bin',  # replace with location of downloaded weights
    is_eval=True,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

vlm.qformer_text_input = False  # Optionally disable qformer text

model_cls = registry.get_model_class('blip2_t5_instruct')
model_type = 'flant5xxl'
preprocess_cfg = OmegaConf.load(model_cls.default_config_path(model_type)).preprocess
vis_processors, _ = load_preprocess(preprocess_cfg)
processor = vis_processors["eval"]

question_samples = {
    'prompt': 'Question: Classify this object as transparent, translucent, or opaque? Respond unknown if you are not sure. Short answer:',
    'image': torch.stack([processor(example_image)], dim=0).to(vlm.device)
}

answers, scores = generate(vlm, question_samples, length_penalty=0, repetition_penalty=1, num_captions=3)
print(answers, scores)
# (['opaque', 'translucent', 'transparent'], tensor([-0.0448, -4.1387, -4.2793], device='cuda:0'))
```
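The README above also recommends handling continuous properties (mass, fragility, deformability) by asking yes-or-no questions and comparing the probability of the "yes" response between objects, and cropping images to the object of interest. The sketch below is a hypothetical illustration of that workflow, not a utility shipped with this commit: it reuses the `vlm` and `processor` objects constructed in test.py, and the image files, crop boxes, and `yes_log_prob` helper are placeholder assumptions.

```python
import torch
from PIL import Image

from generate import generate

# `vlm` and `processor` are assumed to be constructed exactly as in test.py above.

def yes_log_prob(vlm, processor, image, question):
    """Score a yes/no question about one image and return the log probability of a 'yes' answer."""
    samples = {
        'prompt': f'Question: {question} Short answer:',
        'image': torch.stack([processor(image)], dim=0).to(vlm.device),
    }
    answers, scores = generate(vlm, samples, length_penalty=0, repetition_penalty=1, num_captions=2)
    by_answer = {answer.strip().lower(): score.item() for answer, score in zip(answers, scores)}
    return by_answer.get('yes', float('-inf'))  # -inf if 'yes' is not among the returned candidates

# Placeholder images and crop boxes; crop to the object in question because
# PG-InstructBLIP is fine-tuned on object-centric data.
bowl = Image.open('ceramic_bowl.jpg').convert('RGB').crop((100, 100, 400, 400))
sponge = Image.open('sponge.jpg').convert('RGB').crop((50, 80, 300, 330))

question = 'Is this object heavy?'
if yes_log_prob(vlm, processor, bowl, question) > yes_log_prob(vlm, processor, sponge, question):
    print('PG-InstructBLIP considers the bowl heavier than the sponge.')
else:
    print('PG-InstructBLIP considers the sponge heavier than the bowl.')
```

Comparing the log probabilities of "yes" across objects is equivalent to comparing the probabilities themselves, since the exponential is monotonic.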