crl-demo / Logic_Demo.py
bryan-stearns
Fixing add/remove button text/enabling
8895767
import pandas as pd
import json
import streamlit as st
import shared_streamlit_funcs as my
# First run of this session: begin with a single sign/symptom input.
if "ld_num_ss_inputs" not in st.session_state:
    st.session_state.ld_num_ss_inputs = 1
def increment_ss_inputs():
    """Raise the session's sign/symptom input count (``ld_num_ss_inputs``) by one."""
    current_count = st.session_state.ld_num_ss_inputs
    st.session_state.ld_num_ss_inputs = current_count + 1
def decrement_ss_inputs():
    """Lower the session's sign/symptom input count by one, never going below 1."""
    reduced_count = st.session_state.ld_num_ss_inputs - 1
    # Clamp so that at least one input always remains
    st.session_state.ld_num_ss_inputs = reduced_count if reduced_count > 1 else 1
def short_cg(cg):
    """Return the abbreviated display label for an intervention category.

    Args:
        cg: Full category name; must be one of the four names listed below.

    Returns:
        The shortened label used when building intervention option labels.

    Raises:
        KeyError: If ``cg`` is not one of the known category names.
    """
    abbreviations = {
        "Teaching, Guidance, and Counseling": "Teaching...",
        "Case Management": "Case Mngmnt",
        "Surveillance": "Surveillance",
        "Treatments and Procedures": "Treatments...",
    }
    return abbreviations[cg]
def json_to_output_df(json_str, input_list):
    """Convert the agent's raw JSON output into a table aligned with the inputs.

    The JSON is expected to look like::

        {<op-name>: [{"output": [{"associated-item": {...}}],
                      "explanation": {"tested-features": {"member-data": {...}}}},
                     ...]}

    Each response is matched to a row by looking up the first entry of its
    ``member-data["sign-symptom"]`` list in ``input_list``. Responses that lack
    an "output" or "explanation" key are skipped.

    Args:
        json_str: JSON text produced by the agent.
        input_list: Sign/symptom strings that were sent to the agent; the
            result has one row per entry, in the same order.

    Returns:
        A DataFrame with "Output" and "Explanation" columns holding JSON
        strings (blank where no response matched that input). An empty
        DataFrame is returned when the text is not valid JSON or does not
        have the expected shape; the error is printed to stdout.
    """
    row_count = len(input_list)
    outdata = {"Output": [""] * row_count, "Explanation": [""] * row_count}
    try:
        # Parsing happens inside the guarded region so that malformed or empty
        # agent output degrades to an empty table instead of crashing the page.
        indata = json.loads(json_str)
        # Process output for each op type
        for opdata in indata.values():
            # Process output for each input
            for response in opdata:
                if "explanation" not in response or "output" not in response:
                    continue
                member_data = response["explanation"]["tested-features"]["member-data"]
                ss_ind = input_list.index(member_data["sign-symptom"][0])
                outdata["Explanation"][ss_ind] = json.dumps(member_data)
                outdata["Output"][ss_ind] = json.dumps(response["output"][0]["associated-item"])
    except Exception as e:
        # Deliberately broad: this is a best-effort display helper, so any
        # unexpected shape is reported and shown as an empty table.
        print("ERROR in LogicDemo json_to_output_df(): " + str(e))
        return pd.DataFrame()
    return pd.DataFrame(data=outdata)
# Initialize the session: call the shared initializer only when the session
# state does not yet hold an "agent" entry.
agent_missing = "agent" not in st.session_state
if agent_missing:
    my.init()
## SET UP STREAMLIT PAGE
# Set the browser-tab title and a wide layout, then render the page headings.
# emojis: https://www.webfx.com/tools/emoji-cheat-sheet/
st.set_page_config(page_title="🧠CRL Demo", layout="wide")
st.subheader("Cognitive Reasoner Lite Demo")
st.title("Generalized Rule Logic")
# One-sentence summary of what this demo page shows (rendered in bold)
st.markdown("**Demonstrates teaching the agent a single rule that lets it respond to many inputs.**")
## Define S/S and intervention concepts
# Signs/symptoms (S/S) offered in the demo. Order matters: the mapping
# selectbox created for the S/S at position i defaults to intvn_list[i].
ss_list = [
"Decreased Bowel Sounds",
"Difficulty Providing Preventive and Therapeutic Health Care",
"Limited Recall of Long Past Events",
"Infection",
"Heartburn/Belching/Indigestion",
"Electrolyte Imbalance",
"Difficulty Expressing Grief Responses",
"Absent/Abnormal Response To Sound",
"Minimal Shared Activities"
]
# Interventions as (category, target, care-descriptor) tuples, listed in the
# same order as ss_list so that equal indices form the default S/S -> intervention pairing.
# Every category used here must be a key known to short_cg().
intvn_list = [
("Teaching, Guidance, and Counseling","Anatomy/Physiology","bowel function"),
("Case Management","Other Community Resources","long term care options"),
("Teaching, Guidance, and Counseling","Continuity of Care","simplified routine"),
("Teaching, Guidance, and Counseling","Wellness","prevention of infection/sepsis"),
("Surveillance","Signs/Symptoms-Physical","epigastric / heartburn pain or discomfort"),
("Surveillance","Signs/Symptoms-Physical","intake and output"),
("Case Management","Support Group","age/cultural/condition-specific groups"),
("Teaching, Guidance, and Counseling","Signs/Symptoms-Physical","increased hearing loss/other changes"),
("Teaching, Guidance, and Counseling","Behavioral Health Care","therapy to strengthen family support systems"),
]
# Reset the agent before defining and linking concepts
agent_config = my.make_agent()

# Training section: let the user choose which intervention each S/S maps to
st.header("Training:")
st.subheader("How do you want the agent to map symptoms to interventions?")
map_xpnd = st.expander(label="Mappings", expanded=True)

# Column headings for the mapping table
header_row = map_xpnd.container()
symptom_heading, intvn_heading = header_row.columns(2)
symptom_heading.subheader("Symptom")
intvn_heading.subheader("Intervention")

# One display label per intervention: "<short category>; <target>; <care descriptor>"
intvn_labels = [
    "; ".join((short_cg(category), target, descriptor))
    for category, target, descriptor in intvn_list
]


def _intvn_label(option_index):
    """Return the display label for the intervention at ``option_index``."""
    return intvn_labels[option_index]


# One row per S/S: its name on the left, its intervention selectbox on the right
for row_index, symptom in enumerate(ss_list):
    mapping_row = map_xpnd.container()
    symptom_col, intvn_col = mapping_row.columns(2)
    symptom_col.text(symptom)
    # Options are indices into intvn_list; the default is the same-position entry
    chosen_index = intvn_col.selectbox(
        label="Maps to Intervention:",
        options=range(len(intvn_labels)),
        index=row_index,
        key="mapbox-" + str(row_index),
        format_func=_intvn_label,
    )

    # Tell the agent to associate this S/S with the chosen intervention
    symptom_concept = st.session_state.agent.getConcept(
        "{'member-data':{'sign-symptom':'" + symptom + "'}}"
    )
    category, target, descriptor = intvn_list[chosen_index]
    intvn_concept = st.session_state.agent.getConcept(
        "{'intervention':{'category':'" + category
        + "','target':'" + target
        + "','care-descriptor':'" + descriptor + "'}}"
    )
    st.session_state.agent.linkConcepts(
        agent_config.decisionTypeId, "SS-INTVN", symptom_concept, intvn_concept
    )
st.subheader("What do you want the agent to report?")

# Menu label -> intervention attribute name handed to the report action
# ("All" maps to an empty attribute name)
report_attr_by_label = {
    "Category": "category",
    "Target": "target",
    "Care Descriptor": "care-descriptor",
    "All": "",
}
select_report_attr = st.selectbox(
    label="Intervention element",
    options=list(report_attr_by_label),
    index=1,
)
report_attr = report_attr_by_label[select_report_attr]

# Define action behavior to report result (triggered as soon as the intervention concept is active in WM):
# report the selected element of the intervention under the "associated-item" key
intvn_conc = st.session_state.agent.getConcept("{'intervention':null}")
report_action = my.ReportActiveConceptActionInList("associated-item", report_attr)
st.session_state.agent.trainAction(agent_config, intvn_conc, report_action)
st.markdown("---")
st.header("Input:")
st.subheader("Choose a request to send to the agent.")

# Never ask for more inputs than there are symptoms to choose from
max_inputs = len(ss_list)
if st.session_state.ld_num_ss_inputs > max_inputs:
    st.session_state.ld_num_ss_inputs = max_inputs

# One selectbox per requested input; box i defaults to the i-th symptom
ss_input_select_list = []
for box_index in range(st.session_state.ld_num_ss_inputs):
    chosen = st.selectbox(
        label="Signs/Symptom:",
        options=ss_list,
        index=box_index,
        key="ss_in-" + str(box_index),
    )
    ss_input_select_list.append(chosen)

# Add/remove buttons sit in the first two of eight narrow columns
button_columns = st.columns(8)
add_col = button_columns[0]
remove_col = button_columns[1]
add_col.button(
    label="New Input",
    on_click=increment_ss_inputs,
    disabled=(st.session_state.ld_num_ss_inputs >= max_inputs),
)
remove_col.button(
    label="Remove Input",
    on_click=decrement_ss_inputs,
    disabled=(st.session_state.ld_num_ss_inputs <= 1),
)

# Send a partial pattern to the agent's input, one entry per selected symptom
st.session_state.agent.clearInput()
for chosen_symptom in ss_input_select_list:
    st.session_state.agent.addInput(
        "{'member-data':{'sign-symptom':'" + chosen_symptom + "'}}"
    )
st.markdown("---")
st.header("Agent Output:")
request_col, response_col = st.columns(2)

# Left column: show the user what was sent to the agent
request_col.text("Input sent to agent:")
request_col.dataframe(data={'Signs/Symptoms': ss_input_select_list})
raw_input_json = st.session_state.agent.getInputAsJsonString()
request_col.text_area(label="Raw JSON Input", value=raw_input_json, height=200)

# Run the agent with the given input to get a corresponding memory
st.session_state.agent.setMaxOpCycles(-1)
st.session_state.agent.queryDecision(agent_config.decisionTypeId, 5)
output = st.session_state.agent.getOutputAsJsonString()
# Scaled by 1e6 for display in ms (presumably the agent reports nanoseconds - TODO confirm)
elapsed_ms = st.session_state.agent.getLastQueryTime() / 1000000.0

# Right column: the timing, the parsed response table, and the raw JSON
response_col.text("Agent Response: (" + str(elapsed_ms) + " ms)")
response_table = json_to_output_df(output, ss_input_select_list)
response_col.dataframe(data=response_table)
response_col.text_area(label="Raw JSON Output:", value=output, height=500)