Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -22,6 +22,10 @@ def inference(raw_image, question, decoding_strategy):
|
|
22 |
inputs["do_sample"] = True
|
23 |
inputs["top_k"] = 50
|
24 |
inputs["top_p"] = 0.95
|
|
|
|
|
|
|
|
|
25 |
|
26 |
out = model_image_captioning.generate(**inputs)
|
27 |
return processor.batch_decode(out, skip_special_tokens=True)[0]
|
@@ -29,7 +33,7 @@ def inference(raw_image, question, decoding_strategy):
|
|
29 |
inputs = [
|
30 |
gr.inputs.Image(type='pil'),
|
31 |
gr.inputs.Textbox(lines=2, label="Context (optional)"),
|
32 |
-
gr.inputs.Radio(choices=[
|
33 |
]
|
34 |
outputs = gr.outputs.Textbox(label="Output")
|
35 |
|
|
|
22 |
inputs["do_sample"] = True
|
23 |
inputs["top_k"] = 50
|
24 |
inputs["top_p"] = 0.95
|
25 |
+
elif decoding_strategy == "Contrastive search":
|
26 |
+
inputs["penalty_alpha"] = 0.6
|
27 |
+
inputs["top_k"] = 4
|
28 |
+
inputs["max_length"] = 512
|
29 |
|
30 |
out = model_image_captioning.generate(**inputs)
|
31 |
return processor.batch_decode(out, skip_special_tokens=True)[0]
|
|
|
33 |
inputs = [
|
34 |
gr.inputs.Image(type='pil'),
|
35 |
gr.inputs.Textbox(lines=2, label="Context (optional)"),
|
36 |
+
gr.inputs.Radio(choices=["Beam search","Nucleus sampling", "Contrastive search"], type="value", default="Nucleus sampling", label="Caption Decoding Strategy")
|
37 |
]
|
38 |
outputs = gr.outputs.Textbox(label="Output")
|
39 |
|