|
import pandas as pd |
|
import json |
|
import streamlit as st |
|
|
|
import shared_streamlit_funcs as my |
|
|
|
# Session-state key holding how many symptom selectboxes the "Input" section
# shows; kept in session state so the count survives Streamlit reruns.
_SS_INPUT_COUNT_KEY = "ld_num_ss_inputs"

if _SS_INPUT_COUNT_KEY not in st.session_state:
    st.session_state[_SS_INPUT_COUNT_KEY] = 1


def increment_ss_inputs():
    """Button callback: show one more symptom selectbox."""
    st.session_state[_SS_INPUT_COUNT_KEY] = st.session_state[_SS_INPUT_COUNT_KEY] + 1


def decrement_ss_inputs():
    """Button callback: show one fewer symptom selectbox, never going below 1."""
    remaining = st.session_state[_SS_INPUT_COUNT_KEY] - 1
    st.session_state[_SS_INPUT_COUNT_KEY] = remaining if remaining > 1 else 1
|
|
|
# Abbreviations used in the intervention selectbox labels, keyed by the full
# intervention category name.
_SHORT_CATEGORY_NAMES = {
    "Teaching, Guidance, and Counseling": "Teaching...",
    "Case Management": "Case Mngmnt",
    "Surveillance": "Surveillance",
    "Treatments and Procedures": "Treatments...",
}


def short_cg(cg):
    """Return the abbreviated display name for intervention category ``cg``.

    Raises:
        KeyError: if ``cg`` is not one of the known categories.
    """
    abbreviation = _SHORT_CATEGORY_NAMES[cg]
    return abbreviation
|
|
|
def json_to_output_df(json_str, input_list):
    """Build a per-input table from the agent's raw JSON output.

    The JSON is expected to be an object whose values are lists of response
    objects. Responses lacking an "explanation" or "output" key are skipped.
    For every other response, the first entry of its
    explanation/tested-features/member-data/sign-symptom list identifies the
    input it answers. Every row of ``input_list`` equal to that symptom gets
    the JSON-encoded member-data ("Explanation") and the JSON-encoded
    output[0]["associated-item"] ("Output").

    Args:
        json_str: raw JSON string returned by the agent.
        input_list: the sign/symptom strings that were sent, one per row.

    Returns:
        A DataFrame with "Output" and "Explanation" columns, one row per entry
        in ``input_list`` (empty strings for unanswered inputs). An empty
        DataFrame is returned if the JSON is malformed, has an unexpected
        shape, or answers a symptom that is not in ``input_list``.
    """
    outdata = {"Output": [""] * len(input_list), "Explanation": [""] * len(input_list)}

    try:
        # Parse inside the try so malformed output takes the same error path
        # as a structurally unexpected response instead of crashing the page.
        indata = json.loads(json_str)

        for opdata in indata.values():
            for response in opdata:
                if "explanation" not in response or "output" not in response:
                    continue

                member_data = response["explanation"]["tested-features"]["member-data"]
                symptom = member_data["sign-symptom"][0]

                # The same symptom may be selected in several input rows;
                # fill all of them, not just the first one list.index() finds.
                rows = [i for i, item in enumerate(input_list) if item == symptom]
                if not rows:
                    raise ValueError(repr(symptom) + " is not in input_list")

                explanation = json.dumps(member_data)
                output = json.dumps(response["output"][0]["associated-item"])
                for i in rows:
                    outdata["Explanation"][i] = explanation
                    outdata["Output"][i] = output
    except (ValueError, KeyError, IndexError, TypeError, AttributeError) as e:
        # ValueError also covers json.JSONDecodeError.
        print("ERROR in LogicDemo json_to_output_df(): " + str(e))
        return pd.DataFrame()

    return pd.DataFrame(data=outdata)
|
|
|
|
|
|
|
# Streamlit requires set_page_config() to be the first Streamlit command on the
# page. my.init() is defined elsewhere and might itself issue Streamlit calls,
# so configure the page before running it.
st.set_page_config(page_title="🧠CRL Demo", layout="wide")

# Run the shared initialisation only once per session.
# NOTE(review): presumably my.init() creates st.session_state.agent, which the
# rest of this page uses — confirm in shared_streamlit_funcs.
if "agent" not in st.session_state:
    my.init()

st.subheader("Cognitive Reasoner Lite Demo")
st.title("Generalized Rule Logic")
st.markdown("**Demonstrates teaching the agent a single rule that lets it respond to many inputs.**")
|
|
|
|
|
|
|
# Each sign/symptom paired with the intervention it maps to by default. The
# Training section preselects intervention i for symptom i, so the pairing is
# written out explicitly here. Interventions are
# (category, target, care-descriptor) tuples.
_DEFAULT_MAPPINGS = [
    ("Decreased Bowel Sounds",
     ("Teaching, Guidance, and Counseling", "Anatomy/Physiology", "bowel function")),
    ("Difficulty Providing Preventive and Therapeutic Health Care",
     ("Case Management", "Other Community Resources", "long term care options")),
    ("Limited Recall of Long Past Events",
     ("Teaching, Guidance, and Counseling", "Continuity of Care", "simplified routine")),
    ("Infection",
     ("Teaching, Guidance, and Counseling", "Wellness", "prevention of infection/sepsis")),
    ("Heartburn/Belching/Indigestion",
     ("Surveillance", "Signs/Symptoms-Physical", "epigastric / heartburn pain or discomfort")),
    ("Electrolyte Imbalance",
     ("Surveillance", "Signs/Symptoms-Physical", "intake and output")),
    ("Difficulty Expressing Grief Responses",
     ("Case Management", "Support Group", "age/cultural/condition-specific groups")),
    ("Absent/Abnormal Response To Sound",
     ("Teaching, Guidance, and Counseling", "Signs/Symptoms-Physical", "increased hearing loss/other changes")),
    ("Minimal Shared Activities",
     ("Teaching, Guidance, and Counseling", "Behavioral Health Care", "therapy to strengthen family support systems")),
]

# Sign/symptom names offered for training and as agent input.
ss_list = [symptom for symptom, _ in _DEFAULT_MAPPINGS]

# Intervention tuples offered in the mapping selectboxes.
intvn_list = [intervention for _, intervention in _DEFAULT_MAPPINGS]
|
|
|
|
|
# Agent configuration from the shared helpers; its decisionTypeId is used
# below for the symptom->intervention links.
agent_config = my.make_agent()

st.header("Training:")
st.subheader("How do you want the agent to map symptoms to interventions?")

map_xpnd = st.expander(label="Mappings", expanded=True)

# Header row of the two-column mapping table.
row = map_xpnd.container()
map_col1, map_col2 = row.columns(2)
map_col1.subheader("Symptom")
map_col2.subheader("Intervention")

# Compact "category; target; care-descriptor" labels for the selectboxes.
intvn_labels = [short_cg(cg) + "; " + tg + "; " + cd for (cg, tg, cd) in intvn_list]

for ind, ss in enumerate(ss_list):
    row = map_xpnd.container()
    map_col1, map_col2 = row.columns(2)
    map_col1.text(ss)

    # Preselect the intervention at the same position. Fall back to the first
    # one so a symptom without a counterpart cannot produce an out-of-range
    # default index.
    default_index = ind if ind < len(intvn_labels) else 0
    intvn_select = map_col2.selectbox(
        label="Maps to Intervention:",
        options=range(len(intvn_labels)),
        index=default_index,
        key="mapbox-" + str(ind),
        format_func=lambda x: intvn_labels[x],
    )

    # NOTE(review): concept strings are built by concatenation inside single
    # quotes. A value containing ' would likely break them (none currently
    # do) — confirm the agent's quoting rules before adding such values.
    ss_concept = st.session_state.agent.getConcept("{'member-data':{'sign-symptom':'" + ss + "'}}")

    cg, tg, cd = intvn_list[intvn_select]
    intvn_concept = st.session_state.agent.getConcept(
        "{'intervention':{'category':'" + cg + "','target':'" + tg + "','care-descriptor':'" + cd + "'}}"
    )

    # NOTE(review): this runs on every Streamlit rerun. Whether re-linking a
    # symptom replaces or adds to its previous link depends on linkConcepts —
    # verify.
    st.session_state.agent.linkConcepts(agent_config.decisionTypeId, "SS-INTVN", ss_concept, intvn_concept)
|
|
|
# Selectbox label -> intervention attribute passed to the report action.
# NOTE(review): "All" maps to "", presumably meaning "report the whole
# intervention" — confirm in ReportActiveConceptActionInList.
_REPORT_ATTR_BY_LABEL = {
    "Category": "category",
    "Target": "target",
    "Care Descriptor": "care-descriptor",
    "All": "",
}

st.subheader("What do you want the agent to report?")
select_report_attr = st.selectbox(
    label="Intervention element",
    options=list(_REPORT_ATTR_BY_LABEL),
    index=1,
)
report_attr = _REPORT_ATTR_BY_LABEL[select_report_attr]

# Train a single action on the general 'intervention' concept, so one rule
# serves every symptom->intervention link made above.
intvn_conc = st.session_state.agent.getConcept("{'intervention':null}")
report_action = my.ReportActiveConceptActionInList("associated-item", report_attr)
st.session_state.agent.trainAction(agent_config, intvn_conc, report_action)
|
|
|
st.markdown("---")
st.header("Input:")
st.subheader("Choose a request to send to the agent.")

# Never show more symptom selectboxes than there are symptoms to choose from.
max_inputs = len(ss_list)
if st.session_state.ld_num_ss_inputs > max_inputs:
    st.session_state.ld_num_ss_inputs = max_inputs

input_count = st.session_state.ld_num_ss_inputs

# One selectbox per input slot; slot k defaults to the k-th symptom.
ss_input_select_list = [
    st.selectbox(label="Signs/Symptom:", options=ss_list, index=slot, key=f"ss_in-{slot}")
    for slot in range(input_count)
]

# Add/remove buttons sit in the first two of eight narrow columns.
add_col, remove_col = st.columns(8)[:2]
add_col.button(label="New Input", on_click=increment_ss_inputs, disabled=input_count >= max_inputs)
remove_col.button(label="Remove Input", on_click=decrement_ss_inputs, disabled=input_count <= 1)

# Replace the agent's input with the currently selected symptoms.
agent = st.session_state.agent
agent.clearInput()
for symptom in ss_input_select_list:
    agent.addInput(f"{{'member-data':{{'sign-symptom':'{symptom}'}}}}")
|
|
|
|
|
st.markdown("---")
st.header("Agent Output:")

agent = st.session_state.agent
sent_col, response_col = st.columns(2)

# Left column: what was sent, as a table and as raw JSON.
sent_col.text("Input sent to agent:")
sent_col.dataframe(data={'Signs/Symptoms': ss_input_select_list})
sent_col.text_area(label="Raw JSON Input", value=agent.getInputAsJsonString(), height=200)

# NOTE(review): -1 presumably removes the op-cycle limit — confirm against the
# agent API.
agent.setMaxOpCycles(-1)
agent.queryDecision(agent_config.decisionTypeId, 5)

output = agent.getOutputAsJsonString()
# Divided by 1e6 for display in ms; presumably getLastQueryTime() reports
# nanoseconds — confirm.
query_time_ms = agent.getLastQueryTime() / 1000000.0

# Right column: the agent's answer, as a per-input table and as raw JSON.
response_col.text(f"Agent Response: ({query_time_ms} ms)")
response_col.dataframe(data=json_to_output_df(output, ss_input_select_list))
response_col.text_area(label="Raw JSON Output:", value=output, height=500)
|
|