bipin commited on
Commit
ef62727
1 Parent(s): cf0876f

fixed import error

Browse files
Files changed (1) hide show
  1. gpt2_story_gen.py +2 -1
gpt2_story_gen.py CHANGED
@@ -1,4 +1,5 @@
1
  from transformers import pipeline, CLIPProcessor, CLIPModel
 
2
 
3
 
4
  def generate_story(image_caption, image, genre):
@@ -21,7 +22,7 @@ def generate_story(image_caption, image, genre):
21
  padding=True
22
  )
23
  clip_ranker_outputs = clip_ranker_model(**clip_ranker_inputs)
24
- logits_per_image = outputs.logits_per_image
25
  probs = logits_per_image.softmax(dim=1)
26
  story = stories[torch.argmax(probs).item()]
27
 
 
1
  from transformers import pipeline, CLIPProcessor, CLIPModel
2
+ import torch
3
 
4
 
5
  def generate_story(image_caption, image, genre):
 
22
  padding=True
23
  )
24
  clip_ranker_outputs = clip_ranker_model(**clip_ranker_inputs)
25
+ logits_per_image = clip_ranker_outputs.logits_per_image
26
  probs = logits_per_image.softmax(dim=1)
27
  story = stories[torch.argmax(probs).item()]
28