Spaces:
Runtime error
Runtime error
File size: 2,139 Bytes
cae4936 8985863 cae4936 70b2a7d 8985863 cae4936 70b2a7d cae4936 32613f0 cae4936 32613f0 cae4936 70b2a7d cae4936 1e66cb4 cae4936 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import gradio as gr
from prefix_clip import download_pretrained_model, generate_caption
from gpt2_story_gen import generate_story
# Local file names that the two pretrained CLIP-prefix captioning
# checkpoints are saved under; `main` selects one of these by model name.
coco_weights = 'coco_weights.pt'
conceptual_weights = 'conceptual_weights.pt'

# Make both checkpoints available at import time (coco first, then
# conceptual) via download_pretrained_model, so either can be used later.
for _model_name, _weights_path in (
    ('coco', coco_weights),
    ('conceptual', conceptual_weights),
):
    download_pretrained_model(_model_name, file_to_save=_weights_path)
# Drop the loop temporaries so the module namespace matches the original.
del _model_name, _weights_path
def main(pil_image, genre, model, n_stories, use_beam_search=False):
    """Caption an image with a CLIP-prefix model, then generate stories from it.

    Args:
        pil_image: The input image, forwarded to both ``generate_caption`` and
            ``generate_story`` (the UI configures the Image input with
            ``type="pil"``, so presumably a PIL image).
        genre: Story genre name; lower-cased before being passed on.
        model: Captioning checkpoint to use, ``"coco"`` or ``"conceptual"``
            (case-insensitive).
        n_stories: Number of stories, forwarded unchanged to ``generate_story``.
        use_beam_search: Forwarded unchanged to ``generate_caption``.

    Returns:
        Whatever ``generate_story`` returns for the computed caption.

    Raises:
        ValueError: If ``model`` or ``genre`` is empty/``None``, or ``model``
            is not one of the known checkpoints.
    """
    # Validate everything before doing any (expensive) model work; an empty
    # selection would otherwise fail with an opaque AttributeError on .lower().
    if not model:
        raise ValueError("Please select a captioning model ('coco' or 'conceptual').")
    if not genre:
        raise ValueError("Please select a story genre.")

    model_key = model.lower()
    if model_key == 'coco':
        model_file = coco_weights
    elif model_key == 'conceptual':
        model_file = conceptual_weights
    else:
        # Without this branch model_file would be unbound and the call below
        # would raise UnboundLocalError.
        raise ValueError(
            f"Unknown model {model!r}; expected 'coco' or 'conceptual'."
        )

    image_caption = generate_caption(
        model_path=model_file,
        pil_image=pil_image,
        use_beam_search=use_beam_search,
    )
    story = generate_story(image_caption, pil_image, genre.lower(), n_stories)
    return story
if __name__ == "__main__":
    # Text shown on the Gradio page.
    title = "Image to Story"
    article = "Combines the power of [clip prefix captioning](https://github.com/rmokady/CLIP_prefix_caption) with [gpt2 story generator](https://huggingface.co/pranavpsv/genre-story-generator-v2) to create stories of different genres from image"
    description = "Drop an image and generate stories of different genre based on that image"
    interface = gr.Interface(
        main,
        title=title,
        description=description,
        article=article,
        # Input order must match main(pil_image, genre, model, n_stories).
        # Defaults are set so main() is not handed an empty selection
        # (presumably None) when the user leaves a control untouched.
        inputs=[
            gr.inputs.Image(type="pil", source="upload", label="Input"),
            gr.inputs.Dropdown(
                type="value",
                default="drama",
                label="Story genre",
                choices=[
                    "superhero",
                    "action",
                    "drama",
                    "horror",
                    "thriller",
                    "sci_fi",
                ],
            ),
            gr.inputs.Radio(
                choices=["coco", "conceptual"], default="coco", label="Model"
            ),
            gr.inputs.Dropdown(
                choices=[1, 2, 3], default=1, label="No. of stories", type="value"
            ),
        ],
        outputs=gr.outputs.Textbox(label="Generated story"),
        # Every example row must supply one value per input component:
        # (image, genre, model, n_stories). The original rows omitted n_stories.
        examples=[
            ["car.jpg", "drama", "conceptual", 1],
            ["gangster.jpg", "action", "coco", 1],
        ],
        enable_queue=True,
    )
    interface.launch()
|