bipin commited on
Commit
0843a80
1 Parent(s): 1e66cb4

added clip ranking

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. gpt2_story_gen.py +21 -5
app.py CHANGED
@@ -14,7 +14,7 @@ def main(pil_image, genre, model, use_beam_search=False):
14
  pil_image=pil_image,
15
  use_beam_search=use_beam_search,
16
  )
17
- story = generate_story(image_caption, genre.lower())
18
  return story
19
 
20
 
 
14
  pil_image=pil_image,
15
  use_beam_search=use_beam_search,
16
  )
17
+ story = generate_story(image_caption, image, genre.lower())
18
  return story
19
 
20
 
gpt2_story_gen.py CHANGED
@@ -1,11 +1,27 @@
1
- from transformers import pipeline
2
 
3
 
4
- def generate_story(image_caption, genre):
5
- story_gen = pipeline("text-generation", "pranavpsv/genre-story-generator-v2")
 
 
6
 
 
 
 
 
 
7
  input = f"<BOS> <{genre}> {image_caption}"
8
- story = story_gen(input)[0]["generated_text"]
9
- story = f"{story.strip(input)}"
 
 
 
 
 
 
 
 
 
10
 
11
  return story
 
1
from transformers import pipeline, CLIPProcessor, CLIPModel
import torch  # BUG FIX: torch.argmax was used below but torch was never imported


def generate_story(image_caption, image, genre):
    """Generate a genre-conditioned story from an image caption, then use
    CLIP to pick the candidate story that best matches the source image.

    Parameters
    ----------
    image_caption : str
        Caption describing the image; seeds the story generator.
    image :
        The source image handed to the CLIP processor for ranking
        (presumably a PIL image — TODO confirm against the caller).
    genre : str
        Genre tag understood by the story model, e.g. "action" (the
        caller lowercases it before passing it in).

    Returns
    -------
    str
        The candidate story with the highest CLIP image-text similarity.
    """
    clip_ranker_checkpoint = "openai/clip-vit-base-patch32"
    clip_ranker_processor = CLIPProcessor.from_pretrained(clip_ranker_checkpoint)
    clip_ranker_model = CLIPModel.from_pretrained(clip_ranker_checkpoint)

    story_gen = pipeline(
        "text-generation",
        "pranavpsv/genre-story-generator-v2",
    )

    # `prompt` instead of `input` — the original shadowed the builtin.
    prompt = f"<BOS> <{genre}> {image_caption}"
    # Sample three candidate stories.
    # BUG FIX: the original used .strip(prompt), which strips any *characters*
    # occurring in `prompt` from both ends of the text — not the prefix string.
    # removeprefix (Python 3.9+) removes exactly the leading prompt.
    stories = [
        story_gen(prompt)[0]["generated_text"].removeprefix(prompt)
        for _ in range(3)
    ]

    # Score each candidate story against the image with CLIP.
    clip_ranker_inputs = clip_ranker_processor(
        text=stories,
        images=image,
        return_tensors="pt",
        padding=True,
    )
    clip_ranker_outputs = clip_ranker_model(**clip_ranker_inputs)
    # BUG FIX: the original read `outputs.logits_per_image`, but `outputs`
    # is undefined in this scope — the model output is clip_ranker_outputs.
    logits_per_image = clip_ranker_outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)
    # Keep the story with the highest image-text similarity.
    story = stories[torch.argmax(probs).item()]

    return story