File size: 3,051 Bytes
0d08077
7d06c4c
0d08077
 
a95ba86
0d08077
df766f8
 
 
 
0d08077
43f7561
76537bc
0d08077
bc65b96
 
a95ba86
df766f8
 
 
 
 
 
0d08077
 
 
a95ba86
0d08077
 
 
 
 
 
a95ba86
0d08077
 
 
 
a95ba86
 
 
 
 
 
 
0d08077
7d06c4c
a95ba86
df766f8
 
 
bba74e9
 
 
2bcaca6
 
 
 
 
df766f8
0d08077
2e7d5a4
df766f8
 
0d08077
a95ba86
 
0d08077
2e7d5a4
0d08077
bc65b96
a95ba86
 
 
 
 
 
 
 
 
bc65b96
 
a95ba86
bc65b96
43f7561
 
8467d12
43f7561
 
 
0d08077
 
a95ba86
 
2e7d5a4
 
 
a95ba86
2e7d5a4
 
a95ba86
2e7d5a4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
import torch
from PIL import Image

from model import BlipBaseModel, GitBaseCocoModel

# Registry of selectable captioning models: the dropdown label shown in the UI
# maps to the wrapper class, which is instantiated with a device string.
# Insertion order matters: it is the order the choices appear in the dropdown.
MODELS = dict(
	[
		("Git-Base-COCO", GitBaseCocoModel),
		("Blip Base", BlipBaseModel),
	]
)

# examples = [["Image1.png"], ["Image2.png"], ["Image3.png"]]

# Instantiated caption models keyed by (model name, device). Constructing a
# model wrapper is done once per process instead of on every request.
_MODEL_CACHE = {}


def _get_model(model_name, device):
	"""Return a cached instance of the model registered as *model_name* on *device*."""
	key = (model_name, device)
	model = _MODEL_CACHE.get(key)
	if model is None:
		# NOTE(review): assumes construction is the expensive step (loading
		# pretrained weights) and that instances are safe to reuse across requests.
		model = MODELS[model_name](device)
		_MODEL_CACHE[key] = model
	return model


def generate_captions(
	image,
	num_captions,
	model_name,
	max_length,
	temperature,
	top_k,
	top_p,
	repetition_penalty,
	diversity_penalty,
):
	"""
	Generates captions for the given image.

	-----
	Parameters:
	image: PIL.Image | None
		The image to generate captions for. None when the user submitted
		without uploading an image.
	num_captions: int
		The number of captions to generate.
	model_name: str
		Key into MODELS selecting which captioning model to use.
	max_length, temperature, top_k, top_p, repetition_penalty, diversity_penalty:
		Forwarded (after type normalization) to the model's generate method;
		see that method for their meaning.
	-----
	Returns:
	str
		The generated captions joined into a single string, one per line.
	-----
	Raises:
	gr.Error
		If no image was provided.
	"""
	if image is None:
		# Surface a readable message in the UI instead of a failure inside generate.
		raise gr.Error("Please upload an image first.")

	# Normalize slider values to the types generate expects.
	# Gradio sliders may deliver either ints or floats depending on the value,
	# so coerce each argument explicitly.
	num_captions = int(num_captions)
	max_length = int(max_length)
	top_k = int(top_k)
	temperature = float(temperature)
	top_p = float(top_p)
	repetition_penalty = float(repetition_penalty)
	diversity_penalty = float(diversity_penalty)

	device = "cuda" if torch.cuda.is_available() else "cpu"
	model = _get_model(model_name, device)

	captions = model.generate(
		image=image,
		max_length=max_length,
		num_captions=num_captions,
		temperature=temperature,
		top_k=top_k,
		top_p=top_p,
		repetition_penalty=repetition_penalty,
		diversity_penalty=diversity_penalty,
	)

	# The single Textbox output shows all captions, one per line.
	return "\n".join(captions)

title = "AI tool for generating captions for images"
description = "This tool uses pretrained models to generate captions for images."

# Model preselected in the dropdown and used by the bundled examples.
DEFAULT_MODEL = "Blip Base"

# Example rows: one value per input component, in the same order as `inputs`
# below (image, number of captions, model, max length, temperature, top-k,
# top-p, repetition penalty, diversity penalty). Values match slider defaults.
EXAMPLES = [
	[f"Image{i}.png", 1, DEFAULT_MODEL, 50, 1.0, 50, 1.0, 2.0, 2.0]
	for i in (1, 2, 3)
]

interface = gr.Interface(
	fn=generate_captions,
	inputs=[
		gr.components.Image(type="pil", label="Image"),
		gr.components.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Captions to Generate"),
		gr.components.Dropdown(list(MODELS), label="Model", value=DEFAULT_MODEL),
		gr.components.Slider(minimum=20, maximum=100, step=5, value=50, label="Maximum Caption Length"),
		gr.components.Slider(minimum=0.1, maximum=10.0, step=0.1, value=1.0, label="Temperature"),
		gr.components.Slider(minimum=1, maximum=100, step=1, value=50, label="Top K"),
		# Top-p (nucleus sampling) is a cumulative probability mass, so it must
		# lie in (0, 1]. Larger values are rejected by Hugging Face generation
		# (assuming model.generate forwards top_p there).
		gr.components.Slider(minimum=0.1, maximum=1.0, step=0.05, value=1.0, label="Top P"),
		gr.components.Slider(minimum=1.0, maximum=10.0, step=0.1, value=2.0, label="Repetition Penalty"),
		gr.components.Slider(minimum=0.0, maximum=10.0, step=0.1, value=2.0, label="Diversity Penalty"),
	],
	outputs=[
		gr.components.Textbox(label="Caption"),
	],
	# Set image examples to be displayed in the interface.
	examples=EXAMPLES,
	title=title,
	description=description,
	allow_flagging="never",
)


if __name__ == "__main__":
	# Launch the interface with request queuing enabled.
	# `.queue()` replaces the deprecated `launch(enable_queue=True)` flag,
	# which was removed in Gradio 4.
	interface.queue().launch(debug=True)