pseudotensor committed on
Commit
6f2c714
1 Parent(s): 046eafc

cerebras and assert

Browse files
Files changed (2) hide show
  1. app.py +6 -4
  2. models.py +56 -8
app.py CHANGED
@@ -2,11 +2,12 @@ import os
2
 
3
  import streamlit as st
4
  import time
 
5
  try:
6
- from src.models import get_all_model_names
7
  from src.open_strawberry import get_defaults, manage_conversation
8
  except (ModuleNotFoundError, ImportError):
9
- from models import get_all_model_names
10
  from open_strawberry import get_defaults, manage_conversation
11
 
12
  (model, system_prompt, initial_prompt, expected_answer,
@@ -158,6 +159,7 @@ if 'secrets' not in st.session_state:
158
  GEMINI_API_KEY='',
159
  # MISTRAL_API_KEY='',
160
  GROQ_API_KEY='',
 
161
  ANTHROPIC_API_KEY='',
162
  )
163
 
@@ -166,7 +168,7 @@ if 'secrets' not in st.session_state:
166
 
167
 
168
  def update_model_selection():
169
- visible_models1 = get_all_model_names(st.session_state.secrets, on_hf_spaces)
170
  if visible_models1 and "model_name" in st.session_state:
171
  if st.session_state.model_name not in visible_models1:
172
  st.session_state.model_name = visible_models1[0]
@@ -177,7 +179,7 @@ if 'model_name' not in st.session_state or not st.session_state.model_name:
177
  update_model_selection()
178
 
179
  # Model selection
180
- visible_models = get_all_model_names(st.session_state.secrets, on_hf_spaces)
181
  st.sidebar.selectbox("Select Model", visible_models, key="model_name",
182
  disabled=st.session_state.conversation_started)
183
  st.sidebar.checkbox("Show Next", value=show_next, key="show_next", disabled=st.session_state.conversation_started)
 
2
 
3
  import streamlit as st
4
  import time
5
+
6
  try:
7
+ from src.models import get_model_names
8
  from src.open_strawberry import get_defaults, manage_conversation
9
  except (ModuleNotFoundError, ImportError):
10
+ from models import get_model_names
11
  from open_strawberry import get_defaults, manage_conversation
12
 
13
  (model, system_prompt, initial_prompt, expected_answer,
 
159
  GEMINI_API_KEY='',
160
  # MISTRAL_API_KEY='',
161
  GROQ_API_KEY='',
162
+ CEREBRAS_OPENAI_API_KEY='',
163
  ANTHROPIC_API_KEY='',
164
  )
165
 
 
168
 
169
 
170
  def update_model_selection():
171
+ visible_models1 = get_model_names(st.session_state.secrets, on_hf_spaces)
172
  if visible_models1 and "model_name" in st.session_state:
173
  if st.session_state.model_name not in visible_models1:
174
  st.session_state.model_name = visible_models1[0]
 
179
  update_model_selection()
180
 
181
  # Model selection
182
+ visible_models = get_model_names(st.session_state.secrets, on_hf_spaces)
183
  st.sidebar.selectbox("Select Model", visible_models, key="model_name",
184
  disabled=st.session_state.conversation_started)
185
  st.sidebar.checkbox("Show Next", value=show_next, key="show_next", disabled=st.session_state.conversation_started)
models.py CHANGED
@@ -349,6 +349,52 @@ def get_groq(model: str,
349
  yield dict(output_tokens=output_tokens, input_tokens=input_tokens)
350
 
351
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  def get_openai_azure(model: str,
353
  prompt: str,
354
  temperature: float = 0,
@@ -449,6 +495,10 @@ def get_model_names(secrets, on_hf_spaces=False):
449
  'mixtral-8x7b-32768']
450
  else:
451
  groq_models = []
 
 
 
 
452
  if secrets.get('OLLAMA_OPENAI_API_KEY'):
453
  ollama_model = os.environ['OLLAMA_OPENAI_MODEL_NAME']
454
  ollama_model = to_list(ollama_model)
@@ -456,22 +506,18 @@ def get_model_names(secrets, on_hf_spaces=False):
456
  ollama_model = []
457
 
458
  groq_models = ['groq:' + x for x in groq_models]
 
459
  azure_models = ['azure:' + x for x in azure_models]
460
  openai_models = ['openai:' + x for x in openai_models]
461
  google_models = ['google:' + x for x in google_models]
462
  anthropic_models = ['anthropic:' + x for x in anthropic_models]
463
  ollama = ['ollama:' + x if 'ollama:' not in x else x for x in ollama_model]
464
 
465
- return anthropic_models, openai_models, google_models, groq_models, azure_models, ollama
466
-
467
-
468
- def get_all_model_names(secrets, on_hf_spaces=False):
469
- anthropic_models, openai_models, google_models, groq_models, azure_models, ollama = get_model_names(secrets,
470
- on_hf_spaces=on_hf_spaces)
471
- return anthropic_models + openai_models + google_models + groq_models + azure_models + ollama
472
 
473
 
474
  def get_model_api(model: str):
 
475
  if model.startswith('anthropic:'):
476
  return get_anthropic
477
  elif model.startswith('openai:') or model.startswith('ollama:'):
@@ -480,8 +526,10 @@ def get_model_api(model: str):
480
  return get_google
481
  elif model.startswith('groq:'):
482
  return get_groq
 
 
483
  elif model.startswith('azure:'):
484
  return get_openai_azure
485
  else:
486
  raise ValueError(
487
- f"Unsupported model: {model}. Ensure to add prefix (e.g. openai:, google:, groq:, azure:, ollama:, anthropic:)")
 
349
  yield dict(output_tokens=output_tokens, input_tokens=input_tokens)
350
 
351
 
352
def get_cerebras(model: str,
                 prompt: str,
                 temperature: float = 0,
                 max_tokens: int = 4096,
                 system: str = '',
                 chat_history: List[Dict] = None,
                 secrets: Dict = None,
                 verbose=False) -> Generator[dict, None, None]:
    """Stream a chat completion from Cerebras and yield it piece by piece.

    Args:
        model: Model name; any ``cerebras:`` prefix is removed before the
            request is made.
        prompt: Text of the final user message.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Completion token limit forwarded to the API.
        system: Content of the leading system message.
        chat_history: Earlier messages placed between the system message and
            the new user message. ``None`` means no history. The caller's
            list is never modified.
        secrets: Mapping from which ``CEREBRAS_OPENAI_API_KEY`` is read.
            ``None`` is treated as an empty mapping.
        verbose: When true, print the token counts once streaming finishes.

    Yields:
        ``dict(text=...)`` for every chunk carrying non-empty content, then
        one final ``dict(output_tokens=..., input_tokens=...)``. Both counts
        stay 0 if no chunk reported usage.
    """
    # context_length is only 8207
    model = model.replace('cerebras:', '')

    # Function-scope import: the SDK is only needed once this generator runs.
    from cerebras.cloud.sdk import Cerebras

    # A ``{}`` default would be one dict object shared by every call.
    if secrets is None:
        secrets = {}
    api_key = secrets.get("CEREBRAS_OPENAI_API_KEY")
    client = Cerebras(api_key=api_key)

    # Concatenation builds a new list, so the caller's history is untouched.
    messages = ([{"role": "system", "content": system}]
                + list(chat_history or [])
                + [{"role": "user", "content": prompt}])

    # NOTE(review): helper is defined elsewhere in this file; presumably it
    # retries the create call with backoff -- confirm against its definition.
    stream = openai_completion_with_backoff(client,
                                            messages=messages,
                                            model=model,
                                            temperature=temperature,
                                            max_tokens=max_tokens,
                                            stream=True,
                                            )

    output_tokens = 0
    input_tokens = 0
    for chunk in stream:
        # Guard against chunks with an empty ``choices`` list (e.g. a
        # usage-only chunk) instead of failing with IndexError.
        if chunk.choices:
            text = chunk.choices[0].delta.content
            if text:
                yield dict(text=text)
        if chunk.usage:
            output_tokens = chunk.usage.completion_tokens
            input_tokens = chunk.usage.prompt_tokens

    if verbose:
        print(f"Output tokens: {output_tokens}")
        print(f"Input tokens: {input_tokens}")
    yield dict(output_tokens=output_tokens, input_tokens=input_tokens)
396
+
397
+
398
  def get_openai_azure(model: str,
399
  prompt: str,
400
  temperature: float = 0,
 
495
  'mixtral-8x7b-32768']
496
  else:
497
  groq_models = []
498
+ if secrets.get('CEREBRAS_OPENAI_API_KEY'):
499
+ cerebras_models = ['llama3.1-70b', 'llama3.1-8b']
500
+ else:
501
+ cerebras_models = []
502
  if secrets.get('OLLAMA_OPENAI_API_KEY'):
503
  ollama_model = os.environ['OLLAMA_OPENAI_MODEL_NAME']
504
  ollama_model = to_list(ollama_model)
 
506
  ollama_model = []
507
 
508
  groq_models = ['groq:' + x for x in groq_models]
509
+ cerebras_models = ['cerebras:' + x for x in cerebras_models]
510
  azure_models = ['azure:' + x for x in azure_models]
511
  openai_models = ['openai:' + x for x in openai_models]
512
  google_models = ['google:' + x for x in google_models]
513
  anthropic_models = ['anthropic:' + x for x in anthropic_models]
514
  ollama = ['ollama:' + x if 'ollama:' not in x else x for x in ollama_model]
515
 
516
+ return anthropic_models + openai_models + google_models + groq_models + cerebras_models + azure_models + ollama
 
 
 
 
 
 
517
 
518
 
519
  def get_model_api(model: str):
520
+ assert model not in ['', None], "Model not set, need to add API key to have models appear and select one."
521
  if model.startswith('anthropic:'):
522
  return get_anthropic
523
  elif model.startswith('openai:') or model.startswith('ollama:'):
 
526
  return get_google
527
  elif model.startswith('groq:'):
528
  return get_groq
529
+ elif model.startswith('cerebras:'):
530
+ return get_cerebras
531
  elif model.startswith('azure:'):
532
  return get_openai_azure
533
  else:
534
  raise ValueError(
535
+ f"Unsupported model: {model}. Ensure to add prefix (e.g. openai:, google:, groq:, cerebras:, azure:, ollama:, anthropic:)")