Shanshan Wang committed
Commit ab3d7d0 (1 parent: f588375)

cache model

Files changed (1): app.py (+66, −45)
app.py CHANGED
@@ -1,13 +1,14 @@
 import gradio as gr
-from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
+from transformers import AutoModel, AutoTokenizer
 import torch
-import torchvision.transforms as T
-from PIL import Image
-import logging
-
-logging.basicConfig(level=logging.INFO)
-from torchvision.transforms.functional import InterpolationMode
+import threading
 import os
+
+# cache the model
+model_cache = {}
+tokenizer_cache = {}
+model_lock = threading.Lock()
+
 from huggingface_hub import login
 hf_token = os.environ.get('hf_token', None)
 
@@ -23,25 +24,40 @@ model_paths = {
 def load_model_and_set_image_function(model_name):
     # Get the model path from the model_paths dictionary
     model_path = model_paths[model_name]
 
-    # Load the model
-    model = AutoModel.from_pretrained(
-        model_path,
-        torch_dtype=torch.bfloat16,
-        low_cpu_mem_usage=True,
-        trust_remote_code=True,
-        use_auth_token=hf_token,
-        device_map="auto"
-    ).eval().cuda()
-
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_path,
-        trust_remote_code=True,
-        use_fast=False,
-        use_auth_token=hf_token
-    )
-
-    return model, tokenizer
+    with model_lock:
+        if model_name in model_cache:
+            # model is already loaded; retrieve it from the cache
+            print(f"Model {model_name} is already loaded. Retrieving from cache.")
+        else:
+            # load the model and tokenizer
+            print(f"Loading model {model_name}...")
+            model = AutoModel.from_pretrained(
+                model_path,
+                torch_dtype=torch.bfloat16,
+                low_cpu_mem_usage=True,
+                trust_remote_code=True,
+                use_auth_token=hf_token,
+                # device_map="auto"
+            ).eval().cuda()
+
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_path,
+                trust_remote_code=True,
+                use_fast=False,
+                use_auth_token=hf_token
+            )
+
+            # add the model and tokenizer to the cache
+            model_cache[model_name] = model
+            tokenizer_cache[model_name] = tokenizer
+            print(f"Model {model_name} loaded successfully.")
+
+    return model_name
 
 
 def inference(image_input,
@@ -52,22 +68,24 @@ def inference(image_input,
               tile_num,
               chatbot,
               state,
-              model_state,
-              tokenizer_state):
+              model_name):
 
     # Check if model_state is None
-    if model_state is None or tokenizer_state is None:
+    if model_name is None:
         chatbot.append(("System", "Please select a model to start the conversation."))
         return chatbot, state, ""
 
+    with model_lock:
+        if model_name not in model_cache:
+            chatbot.append(("System", "Model not loaded. Please wait for the model to load."))
+            return chatbot, state, ""
+        model = model_cache[model_name]
+        tokenizer = tokenizer_cache[model_name]
+
     # Check for empty or invalid user message
     if not user_message or user_message.strip() == '' or user_message.lower() == 'system':
         chatbot.append(("System", "Please enter a valid message to continue the conversation."))
         return chatbot, state, ""
-
-    model = model_state
-    tokenizer = tokenizer_state
 
     # if image is provided, store it in image_state:
@@ -122,13 +140,20 @@ def regenerate_response(chatbot,
                         tile_num,
                         state,
                         image_input,
-                        model_state,
-                        tokenizer_state):
+                        model_name):
 
     # Check if model_state is None
-    if model_state is None or tokenizer_state is None:
+    if model_name is None:
         chatbot.append(("System", "Please select a model to start the conversation."))
         return chatbot, state
 
+    with model_lock:
+        if model_name not in model_cache:
+            chatbot.append(("System", "Model not loaded. Please wait for the model to load."))
+            return chatbot, state
+        model = model_cache[model_name]
+        tokenizer = tokenizer_cache[model_name]
+
     # Check if there is a previous user message
     if chatbot is None or len(chatbot) == 0:
@@ -152,8 +177,6 @@ def regenerate_response(chatbot,
     else:
         state = None
 
-    model = model_state
-    tokenizer = tokenizer_state
     # Set generation config
     do_sample = (float(temperature) != 0.0)
 
@@ -195,8 +218,8 @@ with gr.Blocks() as demo:
 
     state= gr.State()
     model_state = gr.State()
-    tokenizer_state = gr.State()
-    image_load_function_state = gr.State()
+    # tokenizer_state = gr.State()
+    # image_load_function_state = gr.State()
 
     with gr.Row():
         model_dropdown = gr.Dropdown(
@@ -209,14 +232,14 @@ with gr.Blocks() as demo:
     model_dropdown.change(
         fn=load_model_and_set_image_function,
         inputs=[model_dropdown],
-        outputs=[model_state, tokenizer_state]
+        outputs=[model_state]
     )
 
     # Load the default model when the app starts
     demo.load(
         fn=load_model_and_set_image_function,
         inputs=[model_dropdown],
-        outputs=[model_state, tokenizer_state]
+        outputs=[model_state]
     )
 
     with gr.Row(equal_height=True):
@@ -282,8 +305,7 @@ with gr.Blocks() as demo:
            tile_num,
            chatbot,
            state,
-           model_state,
-           tokenizer_state
+           model_state
        ],
        outputs=[chatbot, state, user_input]
    )
@@ -298,8 +320,7 @@ with gr.Blocks() as demo:
            tile_num,
            state,
            image_input,
-           model_state,
-           tokenizer_state,
+           model_state
        ],
        outputs=[chatbot, state]
    )
@@ -319,5 +340,5 @@ with gr.Blocks() as demo:
        inputs = [image_input, user_input],
        label = "examples",
    )
-    demo.queue(concurrency_count=4,max_size=10)
-    demo.launch()
+    demo.queue()
+    demo.launch(max_threads=10)
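
The heart of the change is a process-wide cache guarded by a lock: the event handlers now pass the model name (a string) through gr.State and look the heavyweight objects up in module-level dicts, instead of storing a model and tokenizer in per-session state. A minimal, self-contained sketch of the same pattern — the get_model helper and MODEL_PATHS mapping are illustrative names, not part of the committed app:

    import threading
    import torch
    from transformers import AutoModel, AutoTokenizer

    MODEL_PATHS = {"Model-A": "org/model-a"}   # illustrative name-to-path mapping
    model_cache, tokenizer_cache = {}, {}
    model_lock = threading.Lock()

    def get_model(model_name, hf_token=None):
        """Return a cached (model, tokenizer) pair, loading it on first use."""
        with model_lock:   # serialize the check-and-load across request threads
            if model_name not in model_cache:
                path = MODEL_PATHS[model_name]
                model_cache[model_name] = AutoModel.from_pretrained(
                    path,
                    torch_dtype=torch.bfloat16,
                    low_cpu_mem_usage=True,
                    trust_remote_code=True,
                    use_auth_token=hf_token,
                ).eval().cuda()
                tokenizer_cache[model_name] = AutoTokenizer.from_pretrained(
                    path,
                    trust_remote_code=True,
                    use_fast=False,
                    use_auth_token=hf_token,
                )
            return model_cache[model_name], tokenizer_cache[model_name]

Holding one lock around both the membership test and the load means concurrent requests for the same model block until the first load finishes, at the cost of also serializing loads of different models.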
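The last hunk also reads like a Gradio 4 migration: queue() no longer accepts concurrency_count there (per-event concurrency is configured via default_concurrency_limit instead), and launch(max_threads=...) sizes the shared worker pool. Assuming Gradio 4.x, the dropped limits would map roughly to the following hypothetical equivalent:

    # hypothetical Gradio 4.x equivalent of queue(concurrency_count=4, max_size=10)
    demo.queue(max_size=10, default_concurrency_limit=4)
    demo.launch(max_threads=10)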