Spaces:
Runtime error
Runtime error
bipin
commited on
Commit
•
0843a80
1
Parent(s):
1e66cb4
added clip ranking
Browse files- app.py +1 -1
- gpt2_story_gen.py +21 -5
app.py
CHANGED
@@ -14,7 +14,7 @@ def main(pil_image, genre, model, use_beam_search=False):
|
|
14 |
pil_image=pil_image,
|
15 |
use_beam_search=use_beam_search,
|
16 |
)
|
17 |
-
story = generate_story(image_caption, genre.lower())
|
18 |
return story
|
19 |
|
20 |
|
|
|
14 |
pil_image=pil_image,
|
15 |
use_beam_search=use_beam_search,
|
16 |
)
|
17 |
+
story = generate_story(image_caption, image, genre.lower())
|
18 |
return story
|
19 |
|
20 |
|
gpt2_story_gen.py
CHANGED
@@ -1,11 +1,27 @@
|
|
1 |
-
from transformers import pipeline
|
2 |
|
3 |
|
4 |
-
def generate_story(image_caption, genre):
|
5 |
-
|
|
|
|
|
6 |
|
|
|
|
|
|
|
|
|
|
|
7 |
input = f"<BOS> <{genre}> {image_caption}"
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
return story
|
|
|
1 |
+
from transformers import pipeline, CLIPProcessor, CLIPModel
|
2 |
|
3 |
|
4 |
+
def generate_story(image_caption, image, genre):
|
5 |
+
clip_ranker_checkpoint = "openai/clip-vit-base-patch32"
|
6 |
+
clip_ranker_processor = CLIPProcessor.from_pretrained(clip_ranker_checkpoint)
|
7 |
+
clip_ranker_model = CLIPModel.from_pretrained(clip_ranker_checkpoint)
|
8 |
|
9 |
+
story_gen = pipeline(
|
10 |
+
"text-generation",
|
11 |
+
"pranavpsv/genre-story-generator-v2"
|
12 |
+
)
|
13 |
+
|
14 |
input = f"<BOS> <{genre}> {image_caption}"
|
15 |
+
stories = [story_gen(input)[0]['generated_text'].strip(input) for i in range(3)]
|
16 |
+
clip_ranker_inputs = clip_ranker_processor(
|
17 |
+
text=stories,
|
18 |
+
images=image,
|
19 |
+
return_tensors='pt',
|
20 |
+
padding=True
|
21 |
+
)
|
22 |
+
clip_ranker_outputs = clip_ranker_model(**clip_ranker_inputs)
|
23 |
+
logits_per_image = outputs.logits_per_image
|
24 |
+
probs = logits_per_image.softmax(dim=1)
|
25 |
+
story = stories[torch.argmax(probs).item()]
|
26 |
|
27 |
return story
|