maxschulz-COL commited on
Commit
1d14b94
1 Parent(s): 9fa94f5

Update to make ready for xAI

Browse files
Files changed (4) hide show
  1. actions.py +9 -1
  2. app.py +6 -1
  3. requirements.in +1 -1
  4. requirements.txt +1 -1
actions.py CHANGED
@@ -28,7 +28,12 @@ except ImportError:
28
  logger = logging.getLogger(__name__)
29
  logger.setLevel(logging.INFO) # TODO: remove manual setting and make centrally controlled
30
 
31
- SUPPORTED_VENDORS = {"OpenAI": ChatOpenAI, "Anthropic": ChatAnthropic, "Mistral": ChatMistralAI}
 
 
 
 
 
32
 
33
  SUPPORTED_MODELS = {
34
  "OpenAI": [
@@ -43,6 +48,7 @@ SUPPORTED_MODELS = {
43
  "claude-3-haiku-20240307",
44
  ],
45
  "Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"],
 
46
  }
47
  DEFAULT_TEMPERATURE = 0.1
48
  DEFAULT_RETRY = 3
@@ -62,6 +68,8 @@ def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input):
62
  )
63
  if vendor_input == "Mistral":
64
  llm = vendor(model=model, mistral_api_key=api_key, mistral_api_url=api_base, temperature=DEFAULT_TEMPERATURE)
 
 
65
 
66
  vizro_ai = VizroAI(model=llm)
67
  ai_outputs = vizro_ai.plot(df, user_prompt, max_debug_retry=DEFAULT_RETRY, return_elements=True)
 
28
  logger = logging.getLogger(__name__)
29
  logger.setLevel(logging.INFO) # TODO: remove manual setting and make centrally controlled
30
 
31
+ SUPPORTED_VENDORS = {
32
+ "OpenAI": ChatOpenAI,
33
+ "Anthropic": ChatAnthropic,
34
+ "Mistral": ChatMistralAI,
35
+ "xAI": ChatOpenAI,
36
+ }
37
 
38
  SUPPORTED_MODELS = {
39
  "OpenAI": [
 
48
  "claude-3-haiku-20240307",
49
  ],
50
  "Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"],
51
+ "xAI": ["grok-beta"],
52
  }
53
  DEFAULT_TEMPERATURE = 0.1
54
  DEFAULT_RETRY = 3
 
68
  )
69
  if vendor_input == "Mistral":
70
  llm = vendor(model=model, mistral_api_key=api_key, mistral_api_url=api_base, temperature=DEFAULT_TEMPERATURE)
71
+ if vendor_input == "xAI":
72
+ llm = vendor(model=model, openai_api_key=api_key, openai_api_base=api_base, temperature=DEFAULT_TEMPERATURE)
73
 
74
  vizro_ai = VizroAI(model=llm)
75
  ai_outputs = vizro_ai.plot(df, user_prompt, max_debug_retry=DEFAULT_RETRY, return_elements=True)
app.py CHANGED
@@ -70,6 +70,7 @@ SUPPORTED_MODELS = {
70
  "claude-3-haiku-20240307",
71
  ],
72
  "Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"],
 
73
  }
74
 
75
 
@@ -180,7 +181,11 @@ plot_page = vm.Page(
180
  MyDropdown(
181
  options=SUPPORTED_MODELS["OpenAI"], value="gpt-4o-mini", multi=False, id="model-dropdown-id"
182
  ),
183
- OffCanvas(id="settings", options=["OpenAI", "Anthropic", "Mistral"], value="OpenAI"),
 
 
 
 
184
  UserPromptTextArea(id="text-area-id"),
185
  # Modal(id="modal"),
186
  ],
 
70
  "claude-3-haiku-20240307",
71
  ],
72
  "Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"],
73
+ "xAI": ["grok-beta"],
74
  }
75
 
76
 
 
181
  MyDropdown(
182
  options=SUPPORTED_MODELS["OpenAI"], value="gpt-4o-mini", multi=False, id="model-dropdown-id"
183
  ),
184
+ OffCanvas(
185
+ id="settings",
186
+ options=["OpenAI", "Anthropic", "Mistral", "xAI"],
187
+ value="OpenAI",
188
+ ),
189
  UserPromptTextArea(id="text-area-id"),
190
  # Modal(id="modal"),
191
  ],
requirements.in CHANGED
@@ -1,5 +1,5 @@
1
  gunicorn
2
- vizro-ai>=0.3.1
3
  black
4
  openpyxl
5
  langchain_anthropic
 
1
  gunicorn
2
+ vizro-ai>=0.3.2
3
  black
4
  openpyxl
5
  langchain_anthropic
requirements.txt CHANGED
@@ -266,7 +266,7 @@ urllib3==2.2.3
266
  # via requests
267
  vizro==0.1.23
268
  # via vizro-ai
269
- vizro-ai==0.3.1
270
  # via -r requirements.in
271
  werkzeug==3.0.4
272
  # via
 
266
  # via requests
267
  vizro==0.1.23
268
  # via vizro-ai
269
+ vizro-ai==0.3.2
270
  # via -r requirements.in
271
  werkzeug==3.0.4
272
  # via