chats-bug commited on
Commit
1d4f82c
1 Parent(s): bba74e9

Inefficient implementation

Browse files
Files changed (1) hide show
  1. model.py +4 -7
model.py CHANGED
@@ -52,17 +52,13 @@ class ImageCaptionModel:
52
 
53
  """
54
  pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)
55
-
56
- if diversity_penalty != 0.0:
57
- num_beam_groups = 2
58
- num_captions = num_captions if num_captions % 2 == 0 else num_captions + 1
59
 
60
  generated_ids = self.model.generate(
61
  pixel_values=pixel_values,
62
  max_length=max_length,
63
- num_beams=num_captions,
64
  num_beam_groups=num_beam_groups,
65
- num_return_sequences=num_captions,
66
  temperature=temperature,
67
  top_k=top_k,
68
  top_p=top_p,
@@ -71,8 +67,9 @@ class ImageCaptionModel:
71
  )
72
 
73
  generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
 
74
 
75
- return generated_caption[:num_captions]
76
 
77
 
78
  class GitBaseCocoModel(ImageCaptionModel):
 
52
 
53
  """
54
  pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)
 
 
 
 
55
 
56
  generated_ids = self.model.generate(
57
  pixel_values=pixel_values,
58
  max_length=max_length,
59
+ num_beams=num_captions*2,
60
  num_beam_groups=num_beam_groups,
61
+ num_return_sequences=num_captions*2,
62
  temperature=temperature,
63
  top_k=top_k,
64
  top_p=top_p,
 
67
  )
68
 
69
  generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
70
+ generated_caption = [generated_caption[i] for i in range(0, num_captions*2, 2)]
71
 
72
+ return generated_caption
73
 
74
 
75
  class GitBaseCocoModel(ImageCaptionModel):