import json

import pandas as pd
import streamlit as st

import shared_streamlit_funcs as my

## SET UP STREAMLIT PAGE
# emojis: https://www.webfx.com/tools/emoji-cheat-sheet/
# set_page_config must be the first Streamlit command on the page, so it runs
# before any session-state setup or my.init() (which may itself use Streamlit).
st.set_page_config(page_title="🧠CRL Demo", layout="wide")

# Number of sign/symptom input boxes currently shown (persists across reruns).
if "ld_num_ss_inputs" not in st.session_state:
    st.session_state["ld_num_ss_inputs"] = 1


def increment_ss_inputs():
    """Button callback: show one more sign/symptom input box."""
    st.session_state.ld_num_ss_inputs += 1


def decrement_ss_inputs():
    """Button callback: show one fewer sign/symptom input box, never fewer than 1."""
    st.session_state.ld_num_ss_inputs = max(1, st.session_state.ld_num_ss_inputs - 1)


# Abbreviated display labels for intervention categories (keeps mapping labels short).
_SHORT_CG_LABELS = {
    "Teaching, Guidance, and Counseling": "Teaching...",
    "Case Management": "Case Mngmnt",
    "Surveillance": "Surveillance",
    "Treatments and Procedures": "Treatments...",
}


def short_cg(cg):
    """Return the abbreviated display label for intervention category ``cg``.

    Categories with no abbreviation are returned unchanged, so new categories
    can be added to ``intvn_list`` without breaking the mapping labels.
    """
    return _SHORT_CG_LABELS.get(cg, cg)


def json_to_output_df(json_str, input_list):
    """Convert the agent's raw JSON output into a table aligned with the inputs.

    Row ``i`` of the result corresponds to ``input_list[i]``. The "Output"
    column holds the JSON-encoded first ``associated-item`` of the matching
    response. The "Explanation" column holds the JSON-encoded ``member-data``
    from that response's tested features. Rows with no matching response stay
    as empty strings.

    Args:
        json_str: JSON text, as returned by ``agent.getOutputAsJsonString()``.
        input_list: The sign/symptom strings sent to the agent, in display order.

    Returns:
        A DataFrame with "Output" and "Explanation" columns. If the JSON cannot
        be parsed or has an unexpected shape, the error is printed and an
        empty DataFrame is returned.
    """
    n_inputs = len(input_list)
    outdata = {"Output": [""] * n_inputs, "Explanation": [""] * n_inputs}
    # Shape implied by the lookups below:
    # {<op-name>: [{"output": [{"associated-item": {...}}, ...],
    #               "explanation": {"tested-features": {"member-data":
    #                   {"sign-symptom": [<str>, ...], ...}}}}, ...]}
    try:
        # Parsing happens inside the try so malformed/empty output degrades to
        # an empty table just like a malformed structure does.
        indata = json.loads(json_str)
        filled = set()
        # Process output for each op type, then each response within it
        for opdata in indata.values():
            for response in opdata:
                if "explanation" not in response or "output" not in response:
                    continue
                member_data = response["explanation"]["tested-features"]["member-data"]
                ss_name = member_data["sign-symptom"][0]
                # The same symptom can be chosen in several input boxes; give
                # each response the first matching row not yet filled so that
                # duplicates don't all collapse onto the first row.
                ss_ind = next(
                    (i for i, name in enumerate(input_list)
                     if name == ss_name and i not in filled),
                    None,
                )
                if ss_ind is None:
                    # No free matching row: raises ValueError if the symptom
                    # was never sent, otherwise reuses its first row.
                    ss_ind = input_list.index(ss_name)
                filled.add(ss_ind)
                outdata["Explanation"][ss_ind] = json.dumps(member_data)
                outdata["Output"][ss_ind] = json.dumps(response["output"][0]["associated-item"])
    except Exception as e:
        # Broad catch is deliberate: this is display-only code, and any
        # unexpected output shape should show an empty table, not break the page.
        print("ERROR in LogicDemo json_to_output_df(): " + str(e))
        return pd.DataFrame()
    return pd.DataFrame(data=outdata)


# Initialize the session
if "agent" not in st.session_state:
    my.init()

st.subheader("Cognitive Reasoner Lite Demo")
st.title("Generalized Rule Logic")
st.markdown("**Demonstrates teaching the agent a single rule that lets it respond to many inputs.**")

## Define S/S and intervention concepts
# ss_list[i] is paired by default with intvn_list[i] in the mapping UI below,
# so both lists must have the same length.
ss_list = [
    "Decreased Bowel Sounds",
    "Difficulty Providing Preventive and Therapeutic Health Care",
    "Limited Recall of Long Past Events",
    "Infection",
    "Heartburn/Belching/Indigestion",
    "Electrolyte Imbalance",
    "Difficulty Expressing Grief Responses",
    "Absent/Abnormal Response To Sound",
    "Minimal Shared Activities",
]
# Each entry is (category, target, care-descriptor).
intvn_list = [
    ("Teaching, Guidance, and Counseling", "Anatomy/Physiology", "bowel function"),
    ("Case Management", "Other Community Resources", "long term care options"),
    ("Teaching, Guidance, and Counseling", "Continuity of Care", "simplified routine"),
    ("Teaching, Guidance, and Counseling", "Wellness", "prevention of infection/sepsis"),
    ("Surveillance", "Signs/Symptoms-Physical", "epigastric / heartburn pain or discomfort"),
    ("Surveillance", "Signs/Symptoms-Physical", "intake and output"),
    ("Case Management", "Support Group", "age/cultural/condition-specific groups"),
    ("Teaching, Guidance, and Counseling", "Signs/Symptoms-Physical", "increased hearing loss/other changes"),
    ("Teaching, Guidance, and Counseling", "Behavioral Health Care", "therapy to strengthen family support systems"),
]

# Reset the agent before defining and linking concepts
agent_config = my.make_agent()

# Allow the user to choose how to map S/Ss to Interventions
st.header("Training:")
st.subheader("How do you want the agent to map symptoms to interventions?")
map_xpnd = st.expander(label="Mappings", expanded=True)
row = map_xpnd.container()
map_col1, map_col2 = row.columns(2)
map_col1.subheader("Symptom")
map_col2.subheader("Intervention")
intvn_labels = [short_cg(cg) + "; " + tg + "; " + cd for (cg, tg, cd) in intvn_list]
for ind, ss in enumerate(ss_list):
    row = map_xpnd.container()
    map_col1, map_col2 = row.columns(2)
    map_col1.text(ss)
    # Options are indices into intvn_list; the default pairs symptom i with intervention i.
    intvn_select = map_col2.selectbox(
        label="Maps to Intervention:",
        options=range(len(intvn_labels)),
        index=ind,
        key="mapbox-" + str(ind),
        format_func=lambda x: intvn_labels[x],
    )
    # Tell the agent to associate this S/S with this intvn
    ss_concept = st.session_state.agent.getConcept("{'member-data':{'sign-symptom':'" + ss + "'}}")
    cg, tg, cd = intvn_list[intvn_select]
    intvn_concept = st.session_state.agent.getConcept(
        "{'intervention':{'category':'" + cg + "','target':'" + tg + "','care-descriptor':'" + cd + "'}}"
    )
    st.session_state.agent.linkConcepts(agent_config.decisionTypeId, "SS-INTVN", ss_concept, intvn_concept)

st.subheader("What do you want the agent to report?")
select_report_attr = st.selectbox(
    label="Intervention element",
    options=["Category", "Target", "Care Descriptor", "All"],
    index=1,
)
report_attr = {"Category": "category", "Target": "target", "Care Descriptor": "care-descriptor", "All": ""}[select_report_attr]
# Define action behavior to report result (triggered as soon as the intervention concept is active in WM).
# Reports the intervention element chosen above (report_attr; "" when "All" is selected)
# of the intervention associated with the matched condition.
intvn_conc = st.session_state.agent.getConcept("{'intervention':null}")
st.session_state.agent.trainAction(agent_config, intvn_conc, my.ReportActiveConceptActionInList("associated-item", report_attr))

st.markdown("---")
st.header("Input:")
st.subheader("Choose a request to send to the agent.")
# Keep the number of input boxes within the number of available symptoms.
if st.session_state.ld_num_ss_inputs > len(ss_list):
    st.session_state.ld_num_ss_inputs = len(ss_list)
ss_input_select_list = [
    st.selectbox(label="Signs/Symptom:", options=ss_list, index=i, key="ss_in-" + str(i))
    for i in range(st.session_state.ld_num_ss_inputs)
]
in_col1, in_col2 = st.columns(8)[0:2]
in_col1.button(label="New Input", on_click=increment_ss_inputs, disabled=(st.session_state.ld_num_ss_inputs >= len(ss_list)))
in_col2.button(label="Remove Input", on_click=decrement_ss_inputs, disabled=(st.session_state.ld_num_ss_inputs <= 1))
# em: —, en: –

# Send a partial pattern to the agent's input
st.session_state.agent.clearInput()
for select in ss_input_select_list:
    st.session_state.agent.addInput("{'member-data':{'sign-symptom':'" + select + "'}}")
st.markdown("---")
st.header("Agent Output:")

# Left pane: echo what was sent to the agent, both as a table and as raw JSON.
io_col1, io_col2 = st.columns(2)
io_col1.text("Input sent to agent:")
io_col1.dataframe(data={"Signs/Symptoms": ss_input_select_list})
io_col1.text_area(
    label="Raw JSON Input",
    value=st.session_state.agent.getInputAsJsonString(),
    height=200,
)

# Run the decision query against the current input, then collect its result and timing.
st.session_state.agent.setMaxOpCycles(-1)
st.session_state.agent.queryDecision(agent_config.decisionTypeId, 5)
output = st.session_state.agent.getOutputAsJsonString()
# NOTE(review): the 1e6 divisor assumes getLastQueryTime() reports nanoseconds — confirm.
query_time_ms = st.session_state.agent.getLastQueryTime() / 1e6

# Right pane: the agent's response as an input-aligned table and as raw JSON.
io_col2.text(f"Agent Response: ({query_time_ms} ms)")
io_col2.dataframe(data=json_to_output_df(output, ss_input_select_list))
io_col2.text_area(label="Raw JSON Output:", value=output, height=500)