Jesuscarr committed on
Commit
73d241c
1 Parent(s): 2f8bbf0

Update app.py

Browse files

Code Comments: Add more comments to your code. This will make it easier for others (and you in the future) to understand what each part of the code does.
Error Handling: Add error handling to your code. This will make your application more robust and easier to debug. For example, you could add try/except blocks around areas of your code that might raise exceptions.
Function Documentation: Add docstrings to your functions. This will make it clear what each function does, what parameters it takes, and what it returns.
Code Organization: Consider organizing your code into classes or modules. This can make your code easier to read and maintain. For example, you could have a separate module for all your Gradio interface functions.
Variable Naming: Use more descriptive variable names. This can make your code easier to understand. For example, instead of `cfg`, you could use `config`.
Code Formatting: Follow the PEP 8 style guide for Python code. This will make your code easier to read and more consistent. For example, you should have spaces around operators and after commas, and your lines should not be too long.

Files changed (1) hide show
  1. app.py +132 -76
app.py CHANGED
@@ -20,6 +20,12 @@ from minigpt4.runners import *
20
  from minigpt4.tasks import *
21
 
22
  def parse_args():
 
 
 
 
 
 
23
  parser = argparse.ArgumentParser(description="Demo")
24
  parser.add_argument("--cfg-path", type=str, default='eval_configs/minigpt4.yaml', help="path to configuration file.")
25
  parser.add_argument(
@@ -32,8 +38,13 @@ def parse_args():
32
  args = parser.parse_args()
33
  return args
34
 
35
-
36
  def setup_seeds(config):
 
 
 
 
 
 
37
  seed = config.run_cfg.seed + get_rank()
38
 
39
  random.seed(seed)
@@ -42,37 +53,39 @@ def setup_seeds(config):
42
 
43
  cudnn.benchmark = False
44
  cudnn.deterministic = True
45
-
46
- # ========================================
47
- # Model Initialization
48
- # ========================================
49
 
50
- SHARED_UI_WARNING = f'''### [NOTE] It is possible that you are waiting in a lengthy queue.
 
 
51
 
52
- You can duplicate and use it with a paid private GPU.
 
 
 
 
53
 
54
- <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/Vision-CAIR/minigpt4?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
 
 
55
 
56
- Alternatively, you can also use the demo on our [project page](https://minigpt-4.github.io).
57
- '''
 
 
58
 
59
- print('Initializing Chat')
60
- cfg = Config(parse_args())
61
 
62
- model_config = cfg.model_cfg
63
- model_cls = registry.get_model_class(model_config.arch)
64
- model = model_cls.from_config(model_config).to('cuda:0')
65
-
66
- vis_processor_cfg = cfg.datasets_cfg.cc_align.vis_processor.train
67
- vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
68
- chat = Chat(model, vis_processor)
69
- print('Initialization Finished')
70
 
71
- # ========================================
72
- # Gradio Setting
73
- # ========================================
74
 
75
- def gradio_reset(chat_state, img_list):
 
 
76
  if chat_state is not None:
77
  chat_state.messages = []
78
  if img_list is not None:
@@ -80,6 +93,17 @@ def gradio_reset(chat_state, img_list):
80
  return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False), gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
81
 
82
  def upload_img(gr_img, text_input, chat_state):
 
 
 
 
 
 
 
 
 
 
 
83
  if gr_img is None:
84
  return None, None, gr.update(interactive=True), chat_state, None
85
  chat_state = CONV_VISION.copy()
@@ -88,67 +112,99 @@ def upload_img(gr_img, text_input, chat_state):
88
  return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
89
 
90
  def gradio_ask(user_message, chatbot, chat_state):
 
 
 
 
 
 
 
 
 
 
 
91
  if len(user_message) == 0:
92
  return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
93
  chat.ask(user_message, chat_state)
94
  chatbot = chatbot + [[user_message, None]]
95
  return '', chatbot, chat_state
96
 
97
-
98
  def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature, max_length=2000)[0]
100
  chatbot[-1][1] = llm_message
101
  return chatbot, chat_state, img_list
102
 
103
- title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
104
- description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
105
- article = """<div style='display:flex; gap: 0.25rem; '><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://github.com/TsuTikgiau/blip2-llm/blob/release_prepare/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>
106
- """
107
-
108
- #TODO show examples below
109
-
110
- with gr.Blocks() as demo:
111
- gr.Markdown(title)
112
- gr.Markdown(SHARED_UI_WARNING)
113
- gr.Markdown(description)
114
- gr.Markdown(article)
115
-
116
- with gr.Row():
117
- with gr.Column(scale=0.5):
118
- image = gr.Image(type="pil")
119
- upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
120
- clear = gr.Button("Restart")
121
-
122
- num_beams = gr.Slider(
123
- minimum=1,
124
- maximum=5,
125
- value=1,
126
- step=1,
127
- interactive=True,
128
- label="beam search numbers)",
129
- )
130
-
131
- temperature = gr.Slider(
132
- minimum=0.1,
133
- maximum=2.0,
134
- value=1.0,
135
- step=0.1,
136
- interactive=True,
137
- label="Temperature",
138
- )
139
-
140
-
141
- with gr.Column():
142
- chat_state = gr.State()
143
- img_list = gr.State()
144
- chatbot = gr.Chatbot(label='MiniGPT-4')
145
- text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
146
-
147
- upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
148
-
149
- text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
150
- gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
151
- )
152
- clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
 
 
 
 
 
 
 
 
 
 
153
 
154
- demo.launch(enable_queue=True)
 
20
  from minigpt4.tasks import *
21
 
22
  def parse_args():
23
+ """
24
+ Parse command line arguments.
25
+
26
+ Returns:
27
+ argparse.Namespace: Parsed command line arguments.
28
+ """
29
  parser = argparse.ArgumentParser(description="Demo")
30
  parser.add_argument("--cfg-path", type=str, default='eval_configs/minigpt4.yaml', help="path to configuration file.")
31
  parser.add_argument(
 
38
  args = parser.parse_args()
39
  return args
40
 
 
41
  def setup_seeds(config):
42
+ """
43
+ Set up random seeds for reproducibility.
44
+
45
+ Parameters:
46
+ config (Config): Configuration object.
47
+ """
48
  seed = config.run_cfg.seed + get_rank()
49
 
50
  random.seed(seed)
 
53
 
54
  cudnn.benchmark = False
55
  cudnn.deterministic = True
 
 
 
 
56
 
57
+ def initialize_chat():
58
+ """
59
+ Initialize the chat model.
60
 
61
+ Returns:
62
+ Chat: Initialized chat model.
63
+ """
64
+ print('Initializing Chat')
65
+ config = Config(parse_args())
66
 
67
+ model_config = config.model_cfg
68
+ model_cls = registry.get_model_class(model_config.arch)
69
+ model = model_cls.from_config(model_config).to('cuda:0')
70
 
71
+ vis_processor_cfg = config.datasets_cfg.cc_align.vis_processor.train
72
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
73
+ chat = Chat(model, vis_processor)
74
+ print('Initialization Finished')
75
 
76
+ return chat
 
77
 
78
+ def gradio_reset(chat_state, img_list):
79
+ """
80
+ Reset the Gradio interface.
 
 
 
 
 
81
 
82
+ Parameters:
83
+ chat_state (gr.State): The current state of the chat.
84
+ img_list (gr.State): The current list of images.
85
 
86
+ Returns:
87
+ tuple: Updated Gradio interface elements.
88
+ """
89
  if chat_state is not None:
90
  chat_state.messages = []
91
  if img_list is not None:
 
93
  return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False), gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
94
 
95
  def upload_img(gr_img, text_input, chat_state):
96
+ """
97
+ Upload an image and update the Gradio interface.
98
+
99
+ Parameters:
100
+ gr_img (gr.Image): The uploaded image.
101
+ text_input (gr.Textbox): The text input box.
102
+ chat_state (gr.State): The current state of the chat.
103
+
104
+ Returns:
105
+ tuple: Updated Gradio interface elements.
106
+ """
107
  if gr_img is None:
108
  return None, None, gr.update(interactive=True), chat_state, None
109
  chat_state = CONV_VISION.copy()
 
112
  return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
113
 
114
  def gradio_ask(user_message, chatbot, chat_state):
115
+ """
116
+ Process user message and update the Gradio interface.
117
+
118
+ Parameters:
119
+ user_message (str): The message input by the user.
120
+ chatbot (list): The current state of the chatbot.
121
+ chat_state (gr.State): The current state of the chat.
122
+
123
+ Returns:
124
+ tuple: Updated Gradio interface elements.
125
+ """
126
  if len(user_message) == 0:
127
  return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
128
  chat.ask(user_message, chat_state)
129
  chatbot = chatbot + [[user_message, None]]
130
  return '', chatbot, chat_state
131
 
 
132
  def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
133
+ """
134
+ Generate a chatbot answer and update the Gradio interface.
135
+
136
+ Parameters:
137
+ chatbot (list): The current state of the chatbot.
138
+ chat_state (gr.State): The current state of the chat.
139
+ img_list (gr.State): The current list of images.
140
+ num_beams (int): The number of beams for the beam search.
141
+ temperature (float): The temperature for the generation.
142
+
143
+ Returns:
144
+ tuple: Updated Gradio interface elements.
145
+ """
146
  llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature, max_length=2000)[0]
147
  chatbot[-1][1] = llm_message
148
  return chatbot, chat_state, img_list
149
 
150
+ def main():
151
+ """
152
+ Main function to run the Gradio interface.
153
+ """
154
+ # Initialize the chat model
155
+ chat = initialize_chat()
156
+
157
+ # Set up the Gradio interface
158
+ title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
159
+ description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
160
+ article = """<div style='display:flex; gap: 0.25rem; '><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://github.com/TsuTikgiau/blip2-llm/blob/release_prepare/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>
161
+ """
162
+
163
+ with gr.Blocks() as demo:
164
+ gr.Markdown(title)
165
+ gr.Markdown(SHARED_UI_WARNING)
166
+ gr.Markdown(description)
167
+ gr.Markdown(article)
168
+
169
+ with gr.Row():
170
+ with gr.Column(scale=0.5):
171
+ image = gr.Image(type="pil")
172
+ upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
173
+ clear = gr.Button("Restart")
174
+
175
+ num_beams = gr.Slider(
176
+ minimum=1,
177
+ maximum=5,
178
+ value=1,
179
+ step=1,
180
+ interactive=True,
181
+ label="beam search numbers)",
182
+ )
183
+
184
+ temperature = gr.Slider(
185
+ minimum=0.1,
186
+ maximum=2.0,
187
+ value=1.0,
188
+ step=0.1,
189
+ interactive=True,
190
+ label="Temperature",
191
+ )
192
+
193
+ with gr.Column():
194
+ chat_state = gr.State()
195
+ img_list = gr.State()
196
+ chatbot = gr.Chatbot(label='MiniGPT-4')
197
+ text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
198
+
199
+ upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
200
+
201
+ text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
202
+ gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
203
+ )
204
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
205
+
206
+ demo.launch(enable_queue=True)
207
+
208
+ if __name__ == "__main__":
209
+ main()
210