Spaces:
Runtime error
Runtime error
bipin
commited on
Commit
•
1e66cb4
1
Parent(s):
207c00c
added model selection option
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ from prefix_clip import download_pretrained_model, generate_caption
|
|
4 |
from gpt2_story_gen import generate_story
|
5 |
|
6 |
|
7 |
-
def main(pil_image, genre, model
|
8 |
model_file = "pretrained_weights.pt"
|
9 |
|
10 |
download_pretrained_model(model.lower(), file_to_save=model_file)
|
@@ -42,9 +42,10 @@ if __name__ == "__main__":
|
|
42 |
"sci_fi",
|
43 |
],
|
44 |
),
|
|
|
45 |
],
|
46 |
outputs=gr.outputs.Textbox(label="Generated story"),
|
47 |
-
examples=[["car.jpg", "drama"], [
|
48 |
enable_queue=True,
|
49 |
)
|
50 |
interface.launch()
|
|
|
4 |
from gpt2_story_gen import generate_story
|
5 |
|
6 |
|
7 |
+
def main(pil_image, genre, model, use_beam_search=False):
|
8 |
model_file = "pretrained_weights.pt"
|
9 |
|
10 |
download_pretrained_model(model.lower(), file_to_save=model_file)
|
|
|
42 |
"sci_fi",
|
43 |
],
|
44 |
),
|
45 |
+
gr.inputs.Radio(choices=["coco", "conceptual"], label="Model")
|
46 |
],
|
47 |
outputs=gr.outputs.Textbox(label="Generated story"),
|
48 |
+
examples=[["car.jpg", "drama", "conceptual"], ["gangster.jpg", "action", "coco"]],
|
49 |
enable_queue=True,
|
50 |
)
|
51 |
interface.launch()
|