howard-hou committed on
Commit
c25fbe0
1 Parent(s): a9b31ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import os, gc
3
  import torch
 
4
  from transformers import CLIPImageProcessor
5
  from huggingface_hub import hf_hub_download
6
 
@@ -33,7 +34,7 @@ image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
33
  ##########################################################################
34
  def generate_prompt(instruction):
35
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
36
- return f"{instruction}\n\nAssistant:"
37
 
38
  def generate(
39
  ctx,
@@ -57,10 +58,8 @@ def generate(
57
  for i in range(int(token_count)):
58
  if i == 0:
59
  input_ids = pipeline.encode(ctx)
60
- print(input_ids)
61
  text_embs = model.w['emb.weight'][input_ids]
62
- input_embs = torch.cat((image_features, text_embs), dim=0)
63
- print(input_embs.shape)
64
  out, state = model.forward(embs=input_embs, state=None)
65
  else:
66
  input_ids = [token]
@@ -103,12 +102,18 @@ examples = [
103
  "What are the things I should be cautious about when I visit here?",
104
  ]
105
  ]
 
106
  def chatbot(image, question):
107
  if image is None:
108
  yield "Please upload an image."
109
  return
110
  image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
111
  image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
 
 
 
 
 
112
  input_text = generate_prompt(question)
113
  for output in generate(input_text, image_features):
114
  yield output
@@ -119,7 +124,7 @@ with gr.Blocks(title=title) as demo:
119
  image = gr.Image(type='pil', label="Image")
120
  with gr.Column():
121
  prompt = gr.Textbox(lines=5, label="Prompt",
122
- value="Please upload an image and ask a question.")
123
  with gr.Row():
124
  submit = gr.Button("Submit", variant="primary")
125
  clear = gr.Button("Clear", variant="secondary")
 
1
  import gradio as gr
2
  import os, gc
3
  import torch
4
+ import torch.nn.functional as F
5
  from transformers import CLIPImageProcessor
6
  from huggingface_hub import hf_hub_download
7
 
 
34
  ##########################################################################
35
  def generate_prompt(instruction):
36
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
37
+ return f"\n{instruction}\n\nAssistant:"
38
 
39
  def generate(
40
  ctx,
 
58
  for i in range(int(token_count)):
59
  if i == 0:
60
  input_ids = pipeline.encode(ctx)
 
61
  text_embs = model.w['emb.weight'][input_ids]
62
+ input_embs = torch.cat((image_features, text_embs), dim=0)[-ctx_limit:]
 
63
  out, state = model.forward(embs=input_embs, state=None)
64
  else:
65
  input_ids = [token]
 
102
  "What are the things I should be cautious about when I visit here?",
103
  ]
104
  ]
105
+
106
  def chatbot(image, question):
107
  if image is None:
108
  yield "Please upload an image."
109
  return
110
  image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
111
  image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
112
+ # apply layer norm to image feature, very important
113
+ image_features = F.layer_norm(image_features,
114
+ (image_features.shape[-1],),
115
+ weight=model.w['blocks.0.ln0.weight'],
116
+ bias=model.w['blocks.0.ln0.bias'])
117
  input_text = generate_prompt(question)
118
  for output in generate(input_text, image_features):
119
  yield output
 
124
  image = gr.Image(type='pil', label="Image")
125
  with gr.Column():
126
  prompt = gr.Textbox(lines=5, label="Prompt",
127
+ value="Render a clear and concise summary of the photo.")
128
  with gr.Row():
129
  submit = gr.Button("Submit", variant="primary")
130
  clear = gr.Button("Clear", variant="secondary")