Miaoran000 committed
Commit 411e7e6
Parent(s): 0856ae9

update summary generation for new models

Files changed (1):
  1. src/backend/model_operations.py (+58, -22)
src/backend/model_operations.py CHANGED
@@ -215,16 +215,34 @@ class SummaryGenerator:
                            {"role": "user", "content": user_prompt}] if 'gpt' in self.model_id
                        else [{"role": "user", "content": system_prompt + '\n' + user_prompt}],
                 temperature=0.0 if 'gpt' in self.model_id.lower() else 1.0, # fixed at 1 for o1 models
-                max_completion_tokens=250 if 'gpt' in self.model_id.lower() else None, # not compatible with o1 series models
+                # max_completion_tokens=250 if 'gpt' in self.model_id.lower() else None, # not compatible with o1 series models
             )
             # print(response)
             result = response.choices[0].message.content
             print(result)
             return result
 
+        elif 'grok' in self.model_id.lower(): # xai
+            XAI_API_KEY = os.getenv("XAI_API_KEY")
+            client = OpenAI(
+                api_key=XAI_API_KEY,
+                base_url="https://api.x.ai/v1",
+            )
+
+            completion = client.chat.completions.create(
+                model=self.model_id.split('/')[-1],
+                messages=[
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": user_prompt},
+                ],
+                temperature=0.0
+            )
+            result = completion.choices[0].message.content
+            print(result)
+            return result
+
         elif 'gemini' in self.model_id.lower():
             vertexai.init(project=os.getenv("GOOGLE_PROJECT_ID"), location="us-central1")
-            gemini_model_id_map = {'gemini-1.5-pro-exp-0827':'gemini-pro-experimental', 'gemini-1.5-flash-exp-0827': 'gemini-flash-experimental'}
             model = GenerativeModel(
                 self.model_id.lower().split('google/')[-1],
                 system_instruction = [system_prompt]
@@ -289,21 +307,23 @@ class SummaryGenerator:
             return response
 
         elif 'claude' in self.model_id.lower(): # using anthropic api
+            print('using Anthropic API')
             client = anthropic.Anthropic()
             message = client.messages.create(
                 model=self.model_id.split('/')[-1],
-                max_tokens=250,
+                max_tokens=1024,
                 temperature=0,
                 system=system_prompt,
                 messages=[
                     {
                         "role": "user",
-                        "content": [
-                            {
-                                "type": "text",
-                                "text": user_prompt
-                            }
-                        ]
+                        # "content": [
+                        #     {
+                        #         "type": "text",
+                        #         "text": user_prompt
+                        #     }
+                        # ]
+                        "content": user_prompt
                     }
                 ]
             )
@@ -311,15 +331,17 @@ class SummaryGenerator:
             print(result)
             return result
 
-        elif 'command-r' in self.model_id.lower():
-            co = cohere.Client(os.getenv('COHERE_API_TOKEN'))
+        elif 'command-r' in self.model_id.lower() or 'aya-expanse' in self.model_id.lower():
+            co = cohere.ClientV2(os.getenv('COHERE_API_TOKEN'))
             response = co.chat(
-                chat_history=[
-                    {"role": "SYSTEM", "message": system_prompt},
+                model=self.model_id.split('/')[-1],
+                messages=[
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": user_prompt}
                 ],
-                message=user_prompt,
+                temperature=0,
             )
-            result = response.text
+            result = response.message.content[0].text
             print(result)
             return result
 
@@ -375,7 +397,10 @@ class SummaryGenerator:
                     trust_remote_code=True
                 )
             else:
-                self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf" if 'openelm' in self.model_id.lower() else self.model_id, trust_remote_code=True)
+                if 'ragamuffin' in self.model_id.lower():
+                    self.tokenizer = AutoTokenizer.from_pretrained(os.path.join('/home/miaoran', self.model_id))
+                else:
+                    self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf" if 'openelm' in self.model_id.lower() else self.model_id, trust_remote_code=True)
             print("Tokenizer loaded")
             if 'jamba' in self.model_id.lower():
                 self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id,
@@ -390,8 +415,14 @@ class SummaryGenerator:
                 )
                 self.processor = AutoProcessor.from_pretrained(self.model_id)
 
+            # elif 'ragamuffin' in self.model_id.lower():
+            #     print('Using ragamuffin')
+            #     self.local_model = AutoModelForCausalLM.from_pretrained(os.path.join('/home/miaoran', self.model_id),
+            #                                                             torch_dtype=torch.bfloat16, # forcing bfloat16 for now
+            #                                                             attn_implementation="flash_attention_2")
+
             else:
-                self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto", torch_dtype="auto")
+                self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto")#torch_dtype="auto"
             # print(self.local_model.device)
             print("Local model loaded")
 
@@ -419,7 +450,7 @@ class SummaryGenerator:
                     # gemma-1.1, mistral-7b does not accept system role
                     {"role": "user", "content": system_prompt + '\n' + user_prompt}
                 ]
-                prompt = self.tokenizer.apply_chat_template(messages,add_generation_prompt=True, tokenize=False)
+                prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
 
             elif 'phi-2' in self.model_id.lower():
                 prompt = system_prompt + '\n' + user_prompt
@@ -451,20 +482,25 @@ class SummaryGenerator:
             # print(prompt)
             # print('-'*50)
             input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
-            with torch.no_grad():
-                outputs = self.local_model.generate(**input_ids, max_new_tokens=250, do_sample=True, temperature=0.01, pad_token_id=self.tokenizer.eos_token_id)
-            if 'glm' in self.model_id.lower():
+            if 'granite' in self.model_id.lower():
+                self.local_model.eval()
+                outputs = self.local_model.generate(**input_ids, max_new_tokens=250)
+            else:
+                with torch.no_grad():
+                    outputs = self.local_model.generate(**input_ids, do_sample=True, max_new_tokens=250, temperature=0.01)#, pad_token_id=self.tokenizer.eos_token_id
+            if 'glm' in self.model_id.lower() or 'ragamuffin' in self.model_id.lower() or 'granite' in self.model_id.lower():
                 outputs = outputs[:, input_ids['input_ids'].shape[1]:]
             elif 'qwen2-vl' in self.model_id.lower() or 'qwen2.5' in self.model_id.lower():
                 outputs = [
                     out_ids[len(in_ids) :] for in_ids, out_ids in zip(input_ids.input_ids, outputs)
                 ]
 
-
             if 'qwen2-vl' in self.model_id.lower():
                 result = self.processor.batch_decode(
                     outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
                 )[0]
+            # elif 'granite' in self.model_id.lower():
+            #     result = self.tokenizer.batch_decode(outputs)[0]
             else:
                 result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
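
For reference, below is a minimal standalone sketch (not part of the repository) of the two new remote-API paths this commit adds: the xAI/Grok branch via the OpenAI-compatible client, and the Cohere ClientV2 chat call. The model ids ("grok-beta", "command-r-plus-08-2024") and the prompts are illustrative assumptions; only the environment variable names, the base URL, and the call shapes come from the diff above.

# Standalone sketch: exercises the two new remote-API branches added in this commit.
# Model ids and prompt strings below are assumptions, not values from the repository.
import os

import cohere
from openai import OpenAI  # xAI serves an OpenAI-compatible endpoint

system_prompt = "You are a helpful summarization assistant."  # illustrative
user_prompt = "Summarize: The quick brown fox jumps over the lazy dog."

# xAI / Grok branch: reuse the OpenAI client, pointed at api.x.ai, as in the diff.
xai_client = OpenAI(api_key=os.getenv("XAI_API_KEY"), base_url="https://api.x.ai/v1")
completion = xai_client.chat.completions.create(
    model="grok-beta",  # assumed model id; the backend derives it from self.model_id
    messages=[
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ],
    temperature=0.0,
)
print(completion.choices[0].message.content)

# Cohere branch: ClientV2 takes OpenAI-style messages and returns content parts,
# hence the switch from response.text to response.message.content[0].text.
co = cohere.ClientV2(os.getenv("COHERE_API_TOKEN"))
response = co.chat(
    model="command-r-plus-08-2024",  # assumed model id
    messages=[
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ],
    temperature=0,
)
print(response.message.content[0].text)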