Spaces:
Runtime error
Runtime error
bipin
commited on
Commit
•
ef62727
1
Parent(s):
cf0876f
fixed import error
Browse files- gpt2_story_gen.py +2 -1
gpt2_story_gen.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from transformers import pipeline, CLIPProcessor, CLIPModel
|
|
|
2 |
|
3 |
|
4 |
def generate_story(image_caption, image, genre):
|
@@ -21,7 +22,7 @@ def generate_story(image_caption, image, genre):
|
|
21 |
padding=True
|
22 |
)
|
23 |
clip_ranker_outputs = clip_ranker_model(**clip_ranker_inputs)
|
24 |
-
logits_per_image =
|
25 |
probs = logits_per_image.softmax(dim=1)
|
26 |
story = stories[torch.argmax(probs).item()]
|
27 |
|
|
|
1 |
from transformers import pipeline, CLIPProcessor, CLIPModel
|
2 |
+
import torch
|
3 |
|
4 |
|
5 |
def generate_story(image_caption, image, genre):
|
|
|
22 |
padding=True
|
23 |
)
|
24 |
clip_ranker_outputs = clip_ranker_model(**clip_ranker_inputs)
|
25 |
+
logits_per_image = clip_ranker_outputs.logits_per_image
|
26 |
probs = logits_per_image.softmax(dim=1)
|
27 |
story = stories[torch.argmax(probs).item()]
|
28 |
|