pseudotensor committed on
Commit
046eafc
1 Parent(s): 56980e9

fix secrets handling

Browse files
Files changed (4) hide show
  1. app.py +4 -0
  2. cli.py +2 -0
  3. models.py +21 -17
  4. open_strawberry.py +5 -2
app.py CHANGED
@@ -50,6 +50,8 @@ if "verbose" not in st.session_state:
50
  st.session_state.verbose = verbose
51
  if "max_tokens" not in st.session_state:
52
  st.session_state.max_tokens = max_tokens
 
 
53
  if "temperature" not in st.session_state:
54
  st.session_state.temperature = temperature
55
  if "next_prompts" not in st.session_state:
@@ -272,6 +274,8 @@ try:
272
  num_turns=st.session_state.num_turns,
273
  temperature=st.session_state.temperature,
274
  max_tokens=st.session_state.max_tokens,
 
 
275
  verbose=st.session_state.verbose,
276
  )
277
  chunk = next(st.session_state.generator)
 
50
  st.session_state.verbose = verbose
51
  if "max_tokens" not in st.session_state:
52
  st.session_state.max_tokens = max_tokens
53
+ if "seed" not in st.session_state:
54
+ st.session_state.seed = 0
55
  if "temperature" not in st.session_state:
56
  st.session_state.temperature = temperature
57
  if "next_prompts" not in st.session_state:
 
274
  num_turns=st.session_state.num_turns,
275
  temperature=st.session_state.temperature,
276
  max_tokens=st.session_state.max_tokens,
277
+ seed=st.session_state.seed,
278
+ secrets=st.session_state.secrets,
279
  verbose=st.session_state.verbose,
280
  )
281
  chunk = next(st.session_state.generator)
cli.py CHANGED
@@ -1,4 +1,5 @@
1
  import argparse
 
2
  import time
3
 
4
  from src.open_strawberry import get_defaults, manage_conversation
@@ -54,6 +55,7 @@ def go_cli():
54
  temperature=args.temperature,
55
  max_tokens=args.max_tokens,
56
  seed=args.seed,
 
57
  cli_mode=True)
58
  response = ''
59
  conversation_history = []
 
1
  import argparse
2
+ import os
3
  import time
4
 
5
  from src.open_strawberry import get_defaults, manage_conversation
 
55
  temperature=args.temperature,
56
  max_tokens=args.max_tokens,
57
  seed=args.seed,
58
+ secrets=dict(os.environ),
59
  cli_mode=True)
60
  response = ''
61
  conversation_history = []
models.py CHANGED
@@ -25,6 +25,7 @@ def get_anthropic(model: str,
25
  max_tokens: int = 4096,
26
  system: str = '',
27
  chat_history: List[Dict] = None,
 
28
  verbose=False) -> \
29
  Generator[dict, None, None]:
30
  model = model.replace('anthropic:', '')
@@ -32,7 +33,7 @@ def get_anthropic(model: str,
32
  # https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
33
  import anthropic
34
 
35
- clawd_key = os.getenv('ANTHROPIC_API_KEY')
36
  clawd_client = anthropic.Anthropic(api_key=clawd_key) if clawd_key else None
37
 
38
  if chat_history is None:
@@ -118,16 +119,16 @@ def get_openai(model: str,
118
  max_tokens: int = 4096,
119
  system: str = '',
120
  chat_history: List[Dict] = None,
 
121
  verbose=False) -> Generator[dict, None, None]:
122
- anthropic_models, openai_models, google_models, groq_models, azure_models, ollama = get_model_names()
123
- if model in ollama:
124
  model = model.replace('ollama:', '')
125
- openai_key = os.getenv('OLLAMA_OPENAI_API_KEY')
126
- openai_base_url = os.getenv('OLLAMA_OPENAI_BASE_URL', 'http://localhost:11434/v1/')
127
  else:
128
  model = model.replace('openai:', '')
129
- openai_key = os.getenv('OPENAI_API_KEY')
130
- openai_base_url = os.getenv('OPENAI_BASE_URL', 'https://api.openai.com/v1')
131
 
132
  from openai import OpenAI
133
 
@@ -206,12 +207,13 @@ def get_google(model: str,
206
  max_tokens: int = 4096,
207
  system: str = '',
208
  chat_history: List[Dict] = None,
 
209
  verbose=False) -> Generator[dict, None, None]:
210
  model = model.replace('google:', '').replace('gemini:', '')
211
 
212
  import google.generativeai as genai
213
 
214
- gemini_key = os.getenv("GEMINI_API_KEY")
215
  genai.configure(api_key=gemini_key)
216
  # Create the model
217
  generation_config = {
@@ -308,12 +310,13 @@ def get_groq(model: str,
308
  max_tokens: int = 4096,
309
  system: str = '',
310
  chat_history: List[Dict] = None,
 
311
  verbose=False) -> Generator[dict, None, None]:
312
  model = model.replace('groq:', '')
313
 
314
  from groq import Groq
315
 
316
- groq_key = os.getenv("GROQ_API_KEY")
317
  client = Groq(api_key=groq_key)
318
 
319
  if chat_history is None:
@@ -352,15 +355,16 @@ def get_openai_azure(model: str,
352
  max_tokens: int = 4096,
353
  system: str = '',
354
  chat_history: List[Dict] = None,
 
355
  verbose=False) -> Generator[dict, None, None]:
356
  model = model.replace('azure:', '').replace('openai_azure:', '')
357
 
358
  from openai import AzureOpenAI
359
 
360
- azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") # e.g. https://project.openai.azure.com
361
- azure_key = os.getenv("AZURE_OPENAI_API_KEY")
362
- azure_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") # i.e. deployment name with some models deployed
363
- azure_api_version = os.getenv('AZURE_OPENAI_API_VERSION', '2024-07-01-preview')
364
  assert azure_endpoint is not None, "Azure OpenAI endpoint not set"
365
  assert azure_key is not None, "Azure OpenAI API key not set"
366
  assert azure_deployment is not None, "Azure OpenAI deployment not set"
@@ -420,15 +424,15 @@ def get_model_names(secrets, on_hf_spaces=False):
420
  else:
421
  anthropic_models = []
422
  if secrets.get('OPENAI_API_KEY'):
423
- if os.getenv('OPENAI_MODEL_NAME'):
424
- openai_models = to_list(os.getenv('OPENAI_MODEL_NAME'))
425
  else:
426
  openai_models = ['gpt-4o', 'gpt-4-turbo-2024-04-09', 'gpt-4o-mini']
427
  else:
428
  openai_models = []
429
  if secrets.get('AZURE_OPENAI_API_KEY'):
430
- if os.getenv('AZURE_OPENAI_MODEL_NAME'):
431
- azure_models = to_list(os.getenv('AZURE_OPENAI_MODEL_NAME'))
432
  else:
433
  azure_models = ['gpt-4o', 'gpt-4-turbo-2024-04-09', 'gpt-4o-mini']
434
  else:
 
25
  max_tokens: int = 4096,
26
  system: str = '',
27
  chat_history: List[Dict] = None,
28
+ secrets: Dict = {},
29
  verbose=False) -> \
30
  Generator[dict, None, None]:
31
  model = model.replace('anthropic:', '')
 
33
  # https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
34
  import anthropic
35
 
36
+ clawd_key = secrets.get('ANTHROPIC_API_KEY')
37
  clawd_client = anthropic.Anthropic(api_key=clawd_key) if clawd_key else None
38
 
39
  if chat_history is None:
 
119
  max_tokens: int = 4096,
120
  system: str = '',
121
  chat_history: List[Dict] = None,
122
+ secrets: Dict = {},
123
  verbose=False) -> Generator[dict, None, None]:
124
+ if model.startswith('ollama:'):
 
125
  model = model.replace('ollama:', '')
126
+ openai_key = secrets.get('OLLAMA_OPENAI_API_KEY')
127
+ openai_base_url = secrets.get('OLLAMA_OPENAI_BASE_URL', 'http://localhost:11434/v1/')
128
  else:
129
  model = model.replace('openai:', '')
130
+ openai_key = secrets.get('OPENAI_API_KEY')
131
+ openai_base_url = secrets.get('OPENAI_BASE_URL', 'https://api.openai.com/v1')
132
 
133
  from openai import OpenAI
134
 
 
207
  max_tokens: int = 4096,
208
  system: str = '',
209
  chat_history: List[Dict] = None,
210
+ secrets: Dict = {},
211
  verbose=False) -> Generator[dict, None, None]:
212
  model = model.replace('google:', '').replace('gemini:', '')
213
 
214
  import google.generativeai as genai
215
 
216
+ gemini_key = secrets.get("GEMINI_API_KEY")
217
  genai.configure(api_key=gemini_key)
218
  # Create the model
219
  generation_config = {
 
310
  max_tokens: int = 4096,
311
  system: str = '',
312
  chat_history: List[Dict] = None,
313
+ secrets: Dict = {},
314
  verbose=False) -> Generator[dict, None, None]:
315
  model = model.replace('groq:', '')
316
 
317
  from groq import Groq
318
 
319
+ groq_key = secrets.get("GROQ_API_KEY")
320
  client = Groq(api_key=groq_key)
321
 
322
  if chat_history is None:
 
355
  max_tokens: int = 4096,
356
  system: str = '',
357
  chat_history: List[Dict] = None,
358
+ secrets: Dict = {},
359
  verbose=False) -> Generator[dict, None, None]:
360
  model = model.replace('azure:', '').replace('openai_azure:', '')
361
 
362
  from openai import AzureOpenAI
363
 
364
+ azure_endpoint = secrets.get("AZURE_OPENAI_ENDPOINT") # e.g. https://project.openai.azure.com
365
+ azure_key = secrets.get("AZURE_OPENAI_API_KEY")
366
+ azure_deployment = secrets.get("AZURE_OPENAI_DEPLOYMENT") # i.e. deployment name with some models deployed
367
+ azure_api_version = secrets.get('AZURE_OPENAI_API_VERSION', '2024-07-01-preview')
368
  assert azure_endpoint is not None, "Azure OpenAI endpoint not set"
369
  assert azure_key is not None, "Azure OpenAI API key not set"
370
  assert azure_deployment is not None, "Azure OpenAI deployment not set"
 
424
  else:
425
  anthropic_models = []
426
  if secrets.get('OPENAI_API_KEY'):
427
+ if secrets.get('OPENAI_MODEL_NAME'):
428
+ openai_models = to_list(secrets.get('OPENAI_MODEL_NAME'))
429
  else:
430
  openai_models = ['gpt-4o', 'gpt-4-turbo-2024-04-09', 'gpt-4o-mini']
431
  else:
432
  openai_models = []
433
  if secrets.get('AZURE_OPENAI_API_KEY'):
434
+ if secrets.get('AZURE_OPENAI_MODEL_NAME'):
435
+ azure_models = to_list(secrets.get('AZURE_OPENAI_MODEL_NAME'))
436
  else:
437
  azure_models = ['gpt-4o', 'gpt-4-turbo-2024-04-09', 'gpt-4o-mini']
438
  else:
open_strawberry.py CHANGED
@@ -290,7 +290,8 @@ def manage_conversation(model: str,
290
  temperature: float = 0.3,
291
  max_tokens: int = 4096,
292
  seed: int = 1234,
293
- verbose: bool = False
 
294
  ) -> Generator[Dict, None, list]:
295
  if seed == 0:
296
  seed = random.randint(0, 1000000)
@@ -344,7 +345,9 @@ def manage_conversation(model: str,
344
  thinking_time = time.time()
345
  response_text = ''
346
  for chunk in get_model_func(model, prompt, system=system, chat_history=chat_history,
347
- temperature=temperature, max_tokens=max_tokens, verbose=verbose):
 
 
348
  if 'text' in chunk and chunk['text']:
349
  response_text += chunk['text']
350
  yield {"role": "assistant", "content": chunk['text'], "streaming": True, "chat_history": chat_history,
 
290
  temperature: float = 0.3,
291
  max_tokens: int = 4096,
292
  seed: int = 1234,
293
+ secrets: Dict = {},
294
+ verbose: bool = False,
295
  ) -> Generator[Dict, None, list]:
296
  if seed == 0:
297
  seed = random.randint(0, 1000000)
 
345
  thinking_time = time.time()
346
  response_text = ''
347
  for chunk in get_model_func(model, prompt, system=system, chat_history=chat_history,
348
+ temperature=temperature, max_tokens=max_tokens,
349
+ secrets=secrets,
350
+ verbose=verbose):
351
  if 'text' in chunk and chunk['text']:
352
  response_text += chunk['text']
353
  yield {"role": "assistant", "content": chunk['text'], "streaming": True, "chat_history": chat_history,