# chatdoctor / st_utils.py
# GGYIMAH1031's picture
# uploaded all single files
# a59ab81 verified
"""Streamlit utils."""
from core.agent_builder.loader import (
load_meta_agent_and_tools,
AgentCacheRegistry,
)
from core.agent_builder.base import BaseRAGAgentBuilder
from core.param_cache import ParamCache
from core.constants import (
AGENT_CACHE_DIR,
)
from typing import Optional, cast
from pydantic import BaseModel
from llama_index.agent.types import BaseAgent
import streamlit as st
def update_selected_agent_with_id(selected_id: Optional[str] = None) -> None:
    """Record ``selected_id`` as the active agent and reset derived state.

    The sidebar placeholder label ``"No Agent Selected"`` is stored as
    ``None``; any other value is stored unchanged. The builder agent, agent
    builder and selected cache entries are then set to ``None``.

    Args:
        selected_id: Agent id to select, the placeholder label, or ``None``.
    """
    if selected_id == "No Agent Selected":
        chosen_id = None
    else:
        chosen_id = selected_id
    st.session_state.selected_id = chosen_id

    # Everything below was derived from the previous selection; reset it.
    for stale_key in ("builder_agent", "agent_builder", "selected_cache"):
        setattr(st.session_state, stale_key, None)
## handler for sidebar specifically
def update_selected_agent() -> None:
    """Apply the current ``agent_selector`` value as the selected agent.

    Takes no arguments so it can be passed as a widget ``on_change``
    callback; the choice is read from session state.
    """
    update_selected_agent_with_id(st.session_state.agent_selector)
def get_cached_is_multimodal() -> bool:
    """Tell whether the selected cache has builder type ``"multimodal"``.

    Returns:
        ``False`` when no cache is selected (entry missing or ``None``);
        otherwise whether the cache's ``builder_type`` equals
        ``"multimodal"``.
    """
    nothing_selected = (
        "selected_cache" not in st.session_state
        or st.session_state.selected_cache is None
    )
    if nothing_selected:
        return False

    param_cache = cast(ParamCache, st.session_state.selected_cache)
    return param_cache.builder_type == "multimodal"
def get_is_multimodal() -> bool:
    """Return the ``is_multimodal_st`` flag, seeding it to ``False`` if absent."""
    if "is_multimodal_st" in st.session_state:
        return st.session_state.is_multimodal_st

    # First access: store the default, then read it back.
    st.session_state.is_multimodal_st = False
    return st.session_state.is_multimodal_st
def add_builder_config() -> None:
    """Render the "Builder Config (Advanced)" expander.

    The expander holds a single checkbox bound to the ``is_multimodal_st``
    session key. Its initial value comes from ``get_cached_is_multimodal()``
    and it is disabled whenever a cache is currently selected.
    """
    with st.expander("Builder Config (Advanced)"):
        # The option is locked as soon as a cache is selected.
        cache_is_selected = (
            "selected_cache" in st.session_state
            and st.session_state.selected_cache is not None
        )

        st.checkbox(
            "Enable multimodal search (beta)",
            key="is_multimodal_st",
            on_change=update_selected_agent,
            value=get_cached_is_multimodal(),
            disabled=cache_is_selected,
        )
def add_sidebar() -> None:
    """Render the sidebar radio list used to pick an agent.

    Refreshes ``cur_agent_ids`` in session state from the agent registry,
    offers those ids after a leading "No Agent Selected" entry, and
    preselects the currently selected agent when there is one. Changing the
    selection triggers ``update_selected_agent``.
    """
    with st.sidebar:
        agent_registry = cast(AgentCacheRegistry, st.session_state.agent_registry)
        st.session_state.cur_agent_ids = agent_registry.get_agent_ids()
        choices = ["No Agent Selected"] + st.session_state.cur_agent_ids

        # Default to the placeholder entry. Only move the preselection when
        # the stored id is actually among the choices: a stale id (one the
        # registry no longer lists) would otherwise make ``list.index`` raise
        # ValueError and break the whole sidebar.
        index = 0
        if "selected_id" in st.session_state:
            selected_id = st.session_state.selected_id
            if selected_id in choices:
                index = choices.index(selected_id)

        # display buttons
        st.radio(
            "Agents",
            choices,
            index=index,
            on_change=update_selected_agent,
            key="agent_selector",
        )
class CurrentSessionState(BaseModel):
    """Current session state.

    Typed bundle of the values that ``get_current_state`` reads out of
    ``st.session_state`` and returns to its caller.
    """

    # arbitrary types
    class Config:
        # Needed because several fields below are plain project classes
        # rather than pydantic models.
        arbitrary_types_allowed = True

    # Registry stored under ``st.session_state.agent_registry``.
    agent_registry: AgentCacheRegistry
    # Id of the selected agent; None when no agent is selected.
    selected_id: Optional[str]
    # Cache for the selected agent; None when no agent is selected.
    selected_cache: Optional[ParamCache]
    # Builder stored under ``st.session_state.agent_builder``.
    agent_builder: BaseRAGAgentBuilder
    # Populated from ``agent_builder.cache`` by ``get_current_state``.
    cache: ParamCache
    # Agent stored under ``st.session_state.builder_agent``.
    builder_agent: BaseAgent
def get_current_state() -> CurrentSessionState:
    """Populate any missing session-state entries and return a snapshot.

    Works through session state in three steps:

    1. Seed ``agent_registry``, ``cur_agent_ids`` and ``selected_id`` when
       they are not present yet.
    2. If no cache is selected, load the one for ``selected_id`` from the
       registry (or store ``None`` when no agent is selected).
    3. If the builder agent or the agent builder is missing, create both via
       ``load_meta_agent_and_tools`` - from the selected cache when there is
       one, otherwise from scratch.

    Returns:
        A ``CurrentSessionState`` holding the agent registry, selected id,
        selected cache, agent builder, the builder's cache and the builder
        agent.
    """
    state = st.session_state

    def _unset(key: str) -> bool:
        """Return True when ``key`` is absent from session state or None."""
        return key not in state or state[key] is None

    # Step 1: seed the basic entries. A registry object is constructed on
    # every call but only stored when session state has none yet.
    fresh_registry = AgentCacheRegistry(str(AGENT_CACHE_DIR))
    if "agent_registry" not in state:
        state.agent_registry = fresh_registry
    if "cur_agent_ids" not in state:
        state.cur_agent_ids = fresh_registry.get_agent_ids()
    if "selected_id" not in state:
        state.selected_id = None

    # Step 2: resolve the selected cache if it doesn't exist.
    if _unset("selected_cache"):
        if state.selected_id is None:
            state.selected_cache = None
        else:
            # load agent from directory
            stored_registry = cast(AgentCacheRegistry, state.agent_registry)
            state.selected_cache = stored_registry.get_agent_cache(state.selected_id)

    # Step 3: (re)create the builder agent / agent builder pair.
    if _unset("builder_agent") or _unset("agent_builder"):
        if _unset("selected_cache"):
            # create builder agent / tools from new cache
            loader_kwargs = {
                "agent_registry": state.agent_registry,
                "is_multimodal": get_is_multimodal(),
            }
        else:
            # create builder agent / tools from selected cache
            loader_kwargs = {
                "cache": state.selected_cache,
                "agent_registry": state.agent_registry,
                # NOTE: we will probably generalize this later into different
                # builder configs
                "is_multimodal": get_cached_is_multimodal(),
            }
        state.builder_agent, state.agent_builder = load_meta_agent_and_tools(
            **loader_kwargs
        )

    return CurrentSessionState(
        agent_registry=state.agent_registry,
        selected_id=state.selected_id,
        selected_cache=state.selected_cache,
        agent_builder=state.agent_builder,
        cache=state.agent_builder.cache,
        builder_agent=state.builder_agent,
    )