cache model
Shanshan Wang committed • Commit ab3d7d0 • Parent(s): f588375
app.py CHANGED
@@ -1,13 +1,14 @@
 import gradio as gr
 from transformers import AutoModel, AutoTokenizer
 import torch
-import
-from PIL import Image
-import logging
-
-logging.basicConfig(level=logging.INFO)
-from torchvision.transforms.functional import InterpolationMode
+import threading
 import os
+
+# caching the model
+model_cache = {}
+tokenizer_cache = {}
+model_lock = threading.Lock()
+
 from huggingface_hub import login
 hf_token = os.environ.get('hf_token', None)
 
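This hunk drops the image/logging imports in favor of `threading` and sets up two module-level caches plus a lock. Reduced to its essentials, the pattern is the sketch below (illustrative only, not code from this commit; `load_fn` stands in for the expensive loader):

import threading

cache = {}
lock = threading.Lock()

def get_or_load(key, load_fn):
    # One lock guards the whole check-then-load sequence, so two concurrent
    # requests for the same key cannot both miss and load the model twice.
    with lock:
        if key not in cache:
            cache[key] = load_fn(key)
        return cache[key]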
@@ -23,25 +24,40 @@ model_paths = {
 def load_model_and_set_image_function(model_name):
     # Get the model path from the model_paths dictionary
     model_path = model_paths[model_name]
-
-
-
-
-
-
-
-
-    ).eval().cuda()
+
+
+    with model_lock:
+        if model_name in model_cache:
+            # model is already loaded; retrieve it from the cache
+            print(f"Model {model_name} is already loaded. Retrieving from cache.")
+
+        else:
+            # load the model and tokenizer
+            print(f"Loading model {model_name}...")
 
-
-
-
-
-
+            model = AutoModel.from_pretrained(
+                model_path,
+                torch_dtype=torch.bfloat16,
+                low_cpu_mem_usage=True,
+                trust_remote_code=True,
+                use_auth_token=hf_token,
+                # device_map="auto"
+            ).eval().cuda()
 
-    return
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_path,
+                trust_remote_code=True,
+                use_fast=False,
+                use_auth_token=hf_token
+            )
+
+            # add the model and tokenizer to the cache
+            model_cache[model_name] = model
+            tokenizer_cache[model_name] = tokenizer
+            print(f"Model {model_name} loaded successfully.")
+
+    return model_name
 
 
 def inference(image_input,
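Note that `model_lock` is held for the entire `from_pretrained` call, so a request that triggers a load blocks every other request, including cache hits for models that are already resident. If that ever became a bottleneck, per-model locks are a common refinement; a minimal sketch under that assumption (not part of this commit):

import threading
from collections import defaultdict

locks_guard = threading.Lock()                 # protects the lock table itself
per_model_locks = defaultdict(threading.Lock)  # one lock per model name
cache = {}

def get_or_load(name, load_fn):
    with locks_guard:             # brief: only creates/fetches the per-name lock
        lock = per_model_locks[name]
    with lock:                    # loads of *different* models can now overlap
        if name not in cache:
            cache[name] = load_fn(name)
        return cache[name]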
@@ -52,22 +68,24 @@ def inference(image_input,
               tile_num,
               chatbot,
               state,
-              model_state,
-              tokenizer_state):
+              model_name):
 
     # Check if model_state is None
-    if model_state is None:
+    if model_name is None:
         chatbot.append(("System", "Please select a model to start the conversation."))
         return chatbot, state, ""
+
+    with model_lock:
+        if model_name not in model_cache:
+            chatbot.append(("System", "Model not loaded. Please wait for the model to load."))
+            return chatbot, state, ""
+        model = model_cache[model_name]
+        tokenizer = tokenizer_cache[model_name]
 
     # Check for empty or invalid user message
     if not user_message or user_message.strip() == '' or user_message.lower() == 'system':
         chatbot.append(("System", "Please enter a valid message to continue the conversation."))
         return chatbot, state, ""
-
-
-    model = model_state
-    tokenizer = tokenizer_state
 
 
     # if image is provided, store it in image_state:
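`inference` now receives only the model name and resolves the heavy objects from the shared cache; the lock is held just for the two dictionary reads, so generation itself runs unlocked and calls can overlap. The lookup could be factored into a helper along these lines (illustrative, using the names from the diff):

def get_cached(model_name):
    # Hold the lock only for the dictionary reads; generation runs outside it.
    with model_lock:
        model = model_cache.get(model_name)
        tokenizer = tokenizer_cache.get(model_name)
    return model, tokenizer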
@@ -122,13 +140,20 @@ def regenerate_response(chatbot,
                         tile_num,
                         state,
                         image_input,
-                        model_state,
-                        tokenizer_state):
+                        model_name):
 
     # Check if model_state is None
-    if model_state is None:
+    if model_name is None:
         chatbot.append(("System", "Please select a model to start the conversation."))
         return chatbot, state
+
+
+    with model_lock:
+        if model_name not in model_cache:
+            chatbot.append(("System", "Model not loaded. Please wait for the model to load."))
+            return chatbot, state
+        model = model_cache[model_name]
+        tokenizer = tokenizer_cache[model_name]
 
     # Check if there is a previous user message
     if chatbot is None or len(chatbot) == 0:
@@ -152,8 +177,6 @@ def regenerate_response(chatbot,
     else:
         state = None
 
-    model = model_state
-    tokenizer = tokenizer_state
     # Set generation config
     do_sample = (float(temperature) != 0.0)
 
@@ -195,8 +218,8 @@ with gr.Blocks() as demo:
 
     state= gr.State()
     model_state = gr.State()
-    tokenizer_state = gr.State()
-    image_load_function_state = gr.State()
+    # tokenizer_state = gr.State()
+    # image_load_function_state = gr.State()
 
     with gr.Row():
         model_dropdown = gr.Dropdown(
@@ -209,14 +232,14 @@ with gr.Blocks() as demo:
     model_dropdown.change(
         fn=load_model_and_set_image_function,
         inputs=[model_dropdown],
-        outputs=[model_state, tokenizer_state]
+        outputs=[model_state]
     )
 
     # Load the default model when the app starts
     demo.load(
         fn=load_model_and_set_image_function,
         inputs=[model_dropdown],
-        outputs=[model_state, tokenizer_state]
+        outputs=[model_state]
     )
 
     with gr.Row(equal_height=True):
@@ -282,8 +305,7 @@ with gr.Blocks() as demo:
             tile_num,
             chatbot,
             state,
-            model_state,
-            tokenizer_state
+            model_state
         ],
         outputs=[chatbot, state, user_input]
     )
@@ -298,8 +320,7 @@ with gr.Blocks() as demo:
             tile_num,
             state,
             image_input,
-            model_state,
-            tokenizer_state,
+            model_state
         ],
         outputs=[chatbot, state]
     )
@@ -319,5 +340,5 @@ with gr.Blocks() as demo:
         inputs = [image_input, user_input],
         label = "examples",
     )
-    demo.queue(
-    demo.launch()
+    demo.queue()
+    demo.launch(max_threads=10)
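With `demo.queue()` enabled and `demo.launch(max_threads=10)`, Gradio serves requests from a pool of worker threads, which is exactly the situation the lock is there for. A self-contained sketch (hypothetical `fake_load`, not from the app) showing that ten concurrent callers produce a single load:

import threading
import time

cache, lock = {}, threading.Lock()
load_count = 0

def fake_load(name):
    global load_count
    load_count += 1
    time.sleep(0.1)              # stand-in for an expensive model load
    return f"model:{name}"

def get_or_load(name):
    with lock:
        if name not in cache:
            cache[name] = fake_load(name)
        return cache[name]

threads = [threading.Thread(target=get_or_load, args=("m",)) for _ in range(10)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(load_count)                # prints 1: the lock ensured a single load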