Miaoran000 committed
Commit 411e7e6
Parent(s): 0856ae9

update summary generation for new models

Files changed: src/backend/model_operations.py (+58 -22)
src/backend/model_operations.py
CHANGED
@@ -215,16 +215,34 @@ class SummaryGenerator:
                     {"role": "user", "content": user_prompt}] if 'gpt' in self.model_id
                     else [{"role": "user", "content": system_prompt + '\n' + user_prompt}],
                 temperature=0.0 if 'gpt' in self.model_id.lower() else 1.0, # fixed at 1 for o1 models
-                max_completion_tokens=250 if 'gpt' in self.model_id.lower() else None, # not compatible with o1 series models
+                # max_completion_tokens=250 if 'gpt' in self.model_id.lower() else None, # not compatible with o1 series models
             )
             # print(response)
             result = response.choices[0].message.content
             print(result)
             return result
 
+        elif 'grok' in self.model_id.lower(): # xai
+            XAI_API_KEY = os.getenv("XAI_API_KEY")
+            client = OpenAI(
+                api_key=XAI_API_KEY,
+                base_url="https://api.x.ai/v1",
+            )
+
+            completion = client.chat.completions.create(
+                model=self.model_id.split('/')[-1],
+                messages=[
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": user_prompt},
+                ],
+                temperature=0.0
+            )
+            result = completion.choices[0].message.content
+            print(result)
+            return result
+
         elif 'gemini' in self.model_id.lower():
             vertexai.init(project=os.getenv("GOOGLE_PROJECT_ID"), location="us-central1")
-            gemini_model_id_map = {'gemini-1.5-pro-exp-0827':'gemini-pro-experimental', 'gemini-1.5-flash-exp-0827': 'gemini-flash-experimental'}
             model = GenerativeModel(
                 self.model_id.lower().split('google/')[-1],
                 system_instruction = [system_prompt]
@@ -289,21 +307,23 @@ class SummaryGenerator:
             return response
 
         elif 'claude' in self.model_id.lower(): # using anthropic api
+            print('using Anthropic API')
             client = anthropic.Anthropic()
             message = client.messages.create(
                 model=self.model_id.split('/')[-1],
-                max_tokens=
+                max_tokens=1024,
                 temperature=0,
                 system=system_prompt,
                 messages=[
                     {
                         "role": "user",
-                        "content": [
-                            {
-                                "type": "text",
-                                "text": user_prompt
-                            }
-                        ]
+                        # "content": [
+                        #     {
+                        #         "type": "text",
+                        #         "text": user_prompt
+                        #     }
+                        # ]
+                        "content": user_prompt
                     }
                 ]
             )
@@ -311,15 +331,17 @@ class SummaryGenerator:
             print(result)
             return result
 
-        elif 'command-r' in self.model_id.lower():
-            co = cohere.
+        elif 'command-r' in self.model_id.lower() or 'aya-expanse' in self.model_id.lower():
+            co = cohere.ClientV2(os.getenv('COHERE_API_TOKEN'))
             response = co.chat(
-
-
+                model=self.model_id.split('/')[-1],
+                messages=[
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": user_prompt}
                 ],
-
+                temperature=0,
             )
-            result = response.text
+            result = response.message.content[0].text
             print(result)
             return result
 
@@ -375,7 +397,10 @@ class SummaryGenerator:
                 trust_remote_code=True
             )
         else:
-
+            if 'ragamuffin' in self.model_id.lower():
+                self.tokenizer = AutoTokenizer.from_pretrained(os.path.join('/home/miaoran', self.model_id))
+            else:
+                self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf" if 'openelm' in self.model_id.lower() else self.model_id, trust_remote_code=True)
         print("Tokenizer loaded")
         if 'jamba' in self.model_id.lower():
             self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id,
@@ -390,8 +415,14 @@ class SummaryGenerator:
             )
             self.processor = AutoProcessor.from_pretrained(self.model_id)
 
+        # elif 'ragamuffin' in self.model_id.lower():
+        #     print('Using ragamuffin')
+        #     self.local_model = AutoModelForCausalLM.from_pretrained(os.path.join('/home/miaoran', self.model_id),
+        #         torch_dtype=torch.bfloat16, # forcing bfloat16 for now
+        #         attn_implementation="flash_attention_2")
+
         else:
-            self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto"
+            self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto")#torch_dtype="auto"
         # print(self.local_model.device)
         print("Local model loaded")
 
@@ -419,7 +450,7 @@ class SummaryGenerator:
                 # gemma-1.1, mistral-7b does not accept system role
                 {"role": "user", "content": system_prompt + '\n' + user_prompt}
             ]
-            prompt = self.tokenizer.apply_chat_template(messages,add_generation_prompt=True, tokenize=False)
+            prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
 
         elif 'phi-2' in self.model_id.lower():
             prompt = system_prompt + '\n' + user_prompt
@@ -451,20 +482,25 @@ class SummaryGenerator:
         # print(prompt)
         # print('-'*50)
         input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
-
-
-
+        if 'granite' in self.model_id.lower():
+            self.local_model.eval()
+            outputs = self.local_model.generate(**input_ids, max_new_tokens=250)
+        else:
+            with torch.no_grad():
+                outputs = self.local_model.generate(**input_ids, do_sample=True, max_new_tokens=250, temperature=0.01)#, pad_token_id=self.tokenizer.eos_token_id
+        if 'glm' in self.model_id.lower() or 'ragamuffin' in self.model_id.lower() or 'granite' in self.model_id.lower():
            outputs = outputs[:, input_ids['input_ids'].shape[1]:]
         elif 'qwen2-vl' in self.model_id.lower() or 'qwen2.5' in self.model_id.lower():
             outputs = [
                 out_ids[len(in_ids) :] for in_ids, out_ids in zip(input_ids.input_ids, outputs)
             ]
 
-
         if 'qwen2-vl' in self.model_id.lower():
             result = self.processor.batch_decode(
                 outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
             )[0]
+        # elif 'granite' in self.model_id.lower():
+        #     result = self.tokenizer.batch_decode(outputs)[0]
         else:
             result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
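Note on the new xAI branch: Grok is served through an OpenAI-compatible endpoint, which is why the commit reuses the stock OpenAI client with base_url pointed at https://api.x.ai/v1. Below is a minimal standalone sketch of that call pattern, assuming the same XAI_API_KEY environment variable as the diff and a hypothetical model name ("grok-beta") chosen only for illustration.

# Minimal sketch of the xAI call pattern added in this commit.
# Assumes: XAI_API_KEY is set, openai>=1.0 installed, "grok-beta" is a placeholder model name.
import os
from openai import OpenAI

client = OpenAI(
    api_key=os.getenv("XAI_API_KEY"),      # same env var the diff reads
    base_url="https://api.x.ai/v1",        # xAI's OpenAI-compatible endpoint
)

completion = client.chat.completions.create(
    model="grok-beta",                     # hypothetical; the backend derives this from model_id
    messages=[
        {"role": "system", "content": "You are a concise summarizer."},
        {"role": "user", "content": "Summarize: The quick brown fox jumps over the lazy dog."},
    ],
    temperature=0.0,
)
print(completion.choices[0].message.content)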
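Note on the new local-generation path: generate() on decoder-only models returns the prompt tokens followed by the completion, which is why the diff slices outputs[:, input_ids['input_ids'].shape[1]:] for the glm/ragamuffin/granite branches before decoding. A toy sketch of that slicing, using gpt2 purely as a stand-in model:

# Toy illustration of stripping prompt tokens before decoding,
# as done for the glm/ragamuffin/granite branches in the diff.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The summary of the article is", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False,
                             pad_token_id=tokenizer.eos_token_id)

full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
new_only = tokenizer.decode(outputs[0, inputs["input_ids"].shape[1]:],
                            skip_special_tokens=True)
print(full_text)   # prompt + continuation
print(new_only)    # continuation only, i.e. what the backend keeps as the summary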