Li commited on
Commit
8fedd76
1 Parent(s): f533bf3

update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -0
app.py CHANGED
@@ -38,7 +38,36 @@ if "vision_encoder.logit_scale"in model_state_dict:
38
  del model_state_dict["vision_encoder.visual.ln_post.weight"]
39
  del model_state_dict["vision_encoder.visual.ln_post.bias"]
40
  flamingo.load_state_dict(model_state_dict, strict=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
 
 
42
 
43
  def generate(
44
  idx,
 
38
  del model_state_dict["vision_encoder.visual.ln_post.weight"]
39
  del model_state_dict["vision_encoder.visual.ln_post.bias"]
40
  flamingo.load_state_dict(model_state_dict, strict=True)
41
+ def get_outputs(
42
+ model,
43
+ batch_images,
44
+ attention_mask,
45
+ max_generation_length,
46
+ min_generation_length,
47
+ num_beams,
48
+ length_penalty,
49
+ input_ids,
50
+ image_start_index_list=None,
51
+ image_nums=None,
52
+ bad_words_ids=None,
53
+ ):
54
+ # and torch.cuda.amp.autocast(dtype=torch.float16)
55
+ with torch.inference_mode():
56
+ outputs = model.generate(
57
+ batch_images,
58
+ input_ids,
59
+ attention_mask=attention_mask,
60
+ max_new_tokens=max_generation_length,
61
+ min_length=min_generation_length,
62
+ num_beams=num_beams,
63
+ length_penalty=length_penalty,
64
+ image_start_index_list=image_start_index_list,
65
+ image_nums=image_nums,
66
+ bad_words_ids=bad_words_ids,
67
+ )
68
 
69
+ outputs = outputs[:, len(input_ids[0]) :]
70
+ return outputs
71
 
72
  def generate(
73
  idx,