bryan-stearns commited on
Commit
437883c
1 Parent(s): 15e67ef

Adding demo files

Browse files
Logic_Demo.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import json
3
+ import streamlit as st
4
+
5
+ import shared_streamlit_funcs as my
6
+
7
+ if "ld_num_ss_inputs" not in st.session_state:
8
+ st.session_state["ld_num_ss_inputs"] = 1
9
+
10
+ def increment_ss_inputs():
11
+ st.session_state.ld_num_ss_inputs += 1
12
+ def decrement_ss_inputs():
13
+ st.session_state.ld_num_ss_inputs = max(1, st.session_state.ld_num_ss_inputs-1)
14
+
15
+ def short_cg(cg):
16
+ return {"Teaching, Guidance, and Counseling":"Teaching...",
17
+ "Case Management":"Case Mngmnt",
18
+ "Surveillance":"Surveillance",
19
+ "Treatments and Procedures":"Treatments..."}[cg]
20
+
21
+ def json_to_output_df(json_str, input_list):
22
+ indata =json.loads(json_str)
23
+ outdata = {"Output":[""]*len(input_list), "Explanation":[""]*len(input_list)}
24
+ # Format is: {<op-name>:{output:[{associated-item:{...}}], explanation:{tested-features:{...}}}}
25
+ haserr = False
26
+
27
+ try:
28
+ # Process output for each op type
29
+ for opname,opdata in indata.items():
30
+ # Process output for each input
31
+ for response in opdata:
32
+ # Process the output and explanation
33
+ if "explanation" not in response or "output" not in response:
34
+ continue
35
+ ss_ind = input_list.index(response["explanation"]["tested-features"]["member-data"]["sign-symptom"][0])
36
+ outdata["Explanation"][ss_ind] = json.dumps(response["explanation"]["tested-features"]["member-data"])
37
+ outdata["Output"][ss_ind] = json.dumps(response["output"][0]["associated-item"])
38
+ except Exception as e:
39
+ print("ERROR in LogicDemo json_to_output_df(): "+str(e))
40
+ haserr = True
41
+
42
+ if haserr:
43
+ retval = pd.DataFrame()
44
+ else:
45
+ retval = pd.DataFrame(data=outdata)
46
+
47
+ print(retval)
48
+ return retval
49
+
50
+
51
+ # Initialize the session
52
+ if "agent" not in st.session_state:
53
+ my.init()
54
+
55
+ ## SET UP STREAMLIT PAGE
56
+ # emojis: https://www.webfx.com/tools/emoji-cheat-sheet/
57
+ st.set_page_config(page_title="🧠CRL Demo", layout="wide")
58
+ st.subheader("Cognitive Reasoner Lite Demo")
59
+ st.title("Generalized Rule Logic")
60
+ st.markdown("**Demonstrates teaching the agent a single rule that lets it respond to many inputs.**")
61
+
62
+
63
+ ## Define S/S and intervention concepts
64
+ ss_list = [
65
+ "Decreased Bowel Sounds",
66
+ "Difficulty Providing Preventive and Therapeutic Health Care",
67
+ "Limited Recall of Long Past Events",
68
+ "Infection",
69
+ "Heartburn/Belching/Indigestion",
70
+ "Electrolyte Imbalance",
71
+ "Difficulty Expressing Grief Responses",
72
+ "Absent/Abnormal Response To Sound",
73
+ "Minimal Shared Activities"
74
+ ]
75
+ intvn_list = [
76
+ ("Teaching, Guidance, and Counseling","Anatomy/Physiology","bowel function"),
77
+ ("Case Management","Other Community Resources","long term care options"),
78
+ ("Teaching, Guidance, and Counseling","Continuity of Care","simplified routine"),
79
+ ("Teaching, Guidance, and Counseling","Wellness","prevention of infection/sepsis"),
80
+ ("Surveillance","Signs/Symptoms-Physical","epigastric / heartburn pain or discomfort"),
81
+ ("Surveillance","Signs/Symptoms-Physical","intake and output"),
82
+ ("Case Management","Support Group","age/cultural/condition-specific groups"),
83
+ ("Teaching, Guidance, and Counseling","Signs/Symptoms-Physical","increased hearing loss/other changes"),
84
+ ("Teaching, Guidance, and Counseling","Behavioral Health Care","therapy to strengthen family support systems"),
85
+ ]
86
+
87
+ # Reset the agent before defining and linking concepts
88
+ agent_config = my.make_agent()
89
+
90
+ # Allow the user to choose how to map S/Ss to Interventions
91
+ st.header("Training:")
92
+ st.subheader("How do you want the agent to map symptoms to interventions?")
93
+
94
+ map_xpnd = st.expander(label="Mappings",expanded=True)
95
+
96
+ row = map_xpnd.container()
97
+ map_col1, map_col2 = row.columns(2)
98
+ map_col1.subheader("Symptom")
99
+ map_col2.subheader("Intervention")
100
+ intvn_labels = [short_cg(cg)+"; "+tg+"; "+cd for (cg, tg, cd) in intvn_list]
101
+ # cd_list = [list(t) for t in zip(*intvn_list)][-1] # Transpose the list of tuples and convert to a list and get just the last list
102
+ for ind,ss in enumerate(ss_list):
103
+ row = map_xpnd.container()
104
+ map_col1, map_col2 = row.columns(2)
105
+ map_col1.text(ss)
106
+ intvn_select = map_col2.selectbox(label="Maps to Intervention:",options=range(len(intvn_labels)),index=ind, key="mapbox-"+str(ind), format_func=lambda x: intvn_labels[x])
107
+ # Tell the agent to associate this S/S with this intvn
108
+ ss_concept = st.session_state.agent.getConcept("{'member-data':{'sign-symptom':'"+ss+"'}}")
109
+ cg,tg,cd = intvn_list[intvn_select]
110
+ intvn_concept = st.session_state.agent.getConcept("{'intervention':{'category':'"+cg+"','target':'"+tg+"','care-descriptor':'"+cd+"'}}")
111
+ st.session_state.agent.linkConcepts(agent_config.decisionTypeId, "SS-INTVN", ss_concept, intvn_concept)
112
+
113
+ st.subheader("What do you want the agent to report?")
114
+ select_report_attr = st.selectbox(label="Intervention element", options=["Category","Target","Care Descriptor", "All"], index=1)
115
+ report_attr = {"Category":"category", "Target":"target", "Care Descriptor":"care-descriptor", "All":""}[select_report_attr]
116
+
117
+ # Define action behavior to report result (triggered as soon as the intervention concept is active in WM)
118
+ # Report just the active 'target-id' elements of the intervention associated with the matched condition
119
+ intvn_conc = st.session_state.agent.getConcept("{'intervention':null}")
120
+ st.session_state.agent.trainAction(agent_config, intvn_conc, my.ReportActiveConceptActionInList("associated-item", report_attr))
121
+
122
+ st.markdown("---")
123
+ st.header("Input:")
124
+ st.subheader("Choose a request to send to the agent.")
125
+
126
+ ss_input_select_list = [st.selectbox(label="Signs/Symptom:", options=ss_list, index=i, key="ss_in-"+str(i)) for i in range(st.session_state.ld_num_ss_inputs)]
127
+ in_col1, in_col2 = st.columns(16)[0:2]
128
+ in_col1.button(label="+", on_click=increment_ss_inputs)
129
+ in_col2.button(label="–", on_click=decrement_ss_inputs) # em: —, en: –
130
+
131
+
132
+ # Send a partial pattern to the agent's input
133
+ st.session_state.agent.clearInput()
134
+ for select in ss_input_select_list:
135
+ st.session_state.agent.addInput("{'member-data':{'sign-symptom':'"+select+"'}}")
136
+
137
+
138
+ st.markdown("---")
139
+ st.header("Agent Output:")
140
+ # Show the input to the user
141
+ io_col1, io_col2 = st.columns(2)
142
+ io_col1.text("Input sent to agent:")
143
+ io_col1.dataframe(data={'Signs/Symptoms':ss_input_select_list})
144
+ io_col1.text_area(label="Raw JSON Input", value=st.session_state.agent.getInputAsJsonString(), height=200)
145
+
146
+ # Run the agent with the given input to get a corresponding memory
147
+ st.session_state.agent.setMaxOpCycles(-1)
148
+ st.session_state.agent.queryDecision(agent_config.decisionTypeId, 5)
149
+
150
+ output = st.session_state.agent.getOutputAsJsonString()
151
+ query_time_ms = st.session_state.agent.getLastQueryTime()/1000000.0
152
+ io_col2.text("Agent Response: ("+str(query_time_ms)+" ms)")
153
+ io_col2.dataframe(data=json_to_output_df(output, ss_input_select_list),)
154
+ io_col2.text_area(label="Raw JSON Output:",value=output, height=500)
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: Crl Demo
3
- emoji: 🐠
4
  colorFrom: purple
5
  colorTo: red
6
  sdk: docker
 
1
  ---
2
+ title: Cognitive Reasoner Lite - Demo
3
+ emoji: 🧠
4
  colorFrom: purple
5
  colorTo: red
6
  sdk: docker
pages/Episode_Recall_Demo.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import json
3
+ import random
4
+ import streamlit as st
5
+
6
+ import shared_streamlit_funcs as my
7
+
8
+
9
+ def generate_episodes(num_channels=10, num_quantities=10, num_latencies=6, records_per_channel=3):
10
+ # Generate random episodes as combinations of the above concept elements
11
+ # and store them in a table for display
12
+
13
+ # Define the possible inputs for the agent
14
+ st.session_state["channels_list"] = ['acme-'+str(n+1) for n in range(num_channels)]
15
+ st.session_state["quantities_list"] = [str(int(pow(10,n))) for n in range(num_quantities)]
16
+ st.session_state["latencies_list"] = list(['8h','12h'])
17
+ st.session_state["latencies_list"].extend([str(n*24)+"h" for n in range(1,num_latencies-1)])
18
+
19
+ channels_column = []
20
+ quantities_column = []
21
+ latencies_column = []
22
+ for channel in st.session_state.channels_list:
23
+ # Get 3 random quantity:latency pairs
24
+ ch_quantities = []
25
+ for _ in range(records_per_channel):
26
+ q = None
27
+ while True:
28
+ q = random.choice(st.session_state.quantities_list)
29
+ if q not in ch_quantities:
30
+ break
31
+ ch_quantities.append(q)
32
+ # Add a row with this channel and each quantity with a random latency
33
+ for q in ch_quantities:
34
+ channels_column.append(channel)
35
+ quantities_column.append(q)
36
+ latencies_column.append(random.choice(st.session_state.latencies_list))
37
+
38
+ # Remake the agent as a quick way of resetting its episodic memory
39
+ my.make_agent()
40
+
41
+ # Define the input concepts for this agent
42
+ st.session_state["channel_concept_map"] = {channel: st.session_state.agent.getConcept("{'request':{'channel':'"+channel+"'}}") for channel in st.session_state.channels_list}
43
+ st.session_state["quantity_concept_map"] = {quantity: st.session_state.agent.getConcept("{'request':{'quantity':'"+quantity+"'}}") for quantity in st.session_state.quantities_list}
44
+ st.session_state["latency_concept_map"] = {latency: st.session_state.agent.getConcept("{'request':{'latency':'"+latency+"'}}") for latency in st.session_state.latencies_list}
45
+
46
+ # Send each row in the table to the agent as an episode of experience
47
+ for (channel, quantity, latency) in zip(channels_column,quantities_column,latencies_column):
48
+ inList = st.session_state.Java_ArrayList()
49
+ inList.add(st.session_state.channel_concept_map[channel])
50
+ inList.add(st.session_state.quantity_concept_map[quantity])
51
+ inList.add(st.session_state.latency_concept_map[latency])
52
+ st.session_state.agent.trainEpisode(inList)
53
+
54
+ # Compile the dataframe from these columns
55
+ st.session_state["er_df"] = pd.DataFrame(data={'Channel':channels_column,
56
+ 'Quantity':quantities_column,
57
+ 'Latency':latencies_column})
58
+ st.session_state.er_df.sort_values(['Channel', 'Quantity'], ascending=[True, True], inplace=True)
59
+
60
+ def json_to_output_df(json_str):
61
+ data =json.loads(json_str)
62
+ haserr = False
63
+ if 'channel' not in data or 'quantity' not in data or 'latency' not in data:
64
+ haserr = True
65
+ print("** Missing expected data from agent output.")
66
+
67
+ if not haserr:
68
+ retval = pd.DataFrame(data={'Channel':data['channel'],
69
+ 'Quantity':data['quantity'],
70
+ 'Latency':data['latency']})
71
+ else:
72
+ retval = pd.DataFrame()
73
+
74
+ return retval
75
+
76
+ # Initialize the session
77
+ if "agent" not in st.session_state:
78
+ my.init()
79
+
80
+ if "er_df" not in st.session_state:
81
+ generate_episodes()
82
+
83
+ ## SET UP STREAMLIT PAGE
84
+ # emojis: https://www.webfx.com/tools/emoji-cheat-sheet/
85
+ st.set_page_config(page_title="🧠CRL Demo", layout="centered")
86
+ st.subheader("Cognitive Reasoner Lite Demo")
87
+ st.title("Episodic Recall")
88
+ st.markdown("**Demonstrates automatic recall of prior experiences given a partial prompt.**")
89
+
90
+ st.header("Episodes shown to the agent:")
91
+ st.markdown("*(randomly generated)*")
92
+
93
+ eps_col1, eps_col2 = st.columns(2)
94
+ eps_col1.dataframe(st.session_state.er_df)
95
+ eps_col1.text("Number of rows: "+str(len(st.session_state.er_df)))
96
+
97
+ epgen_cha_count = eps_col2.number_input(label="Number of Channels", min_value=1, value=10)
98
+ epgen_qua_count = eps_col2.number_input(label="Number of Quantities", min_value=1, value=6)
99
+ epgen_lat_count = eps_col2.number_input(label="Number of Latencies", min_value=1, value=6)
100
+ epgen_recs_per_channel = eps_col2.number_input(label="Records per Channel", min_value=1, max_value=int(epgen_qua_count), value=3)
101
+
102
+ if eps_col2.button("Regenerate Episodes"):
103
+ generate_episodes(epgen_cha_count, epgen_qua_count, epgen_lat_count, epgen_recs_per_channel)
104
+ st.experimental_rerun()
105
+
106
+
107
+ st.markdown("---")
108
+ st.header("Input:")
109
+ st.subheader("This will be shown to the agent.")
110
+ select_col1, select_col2 = st.columns(2)
111
+ channel_select = select_col1.selectbox(label="Query Channel:", options=st.session_state.channels_list)
112
+ quantity_select = select_col2.selectbox(label="Query Quanity:", options=st.session_state.quantities_list)
113
+
114
+ # Send a partial pattern to the agent's input
115
+ st.session_state.agent.clearInput()
116
+ st.session_state.agent.addInput("{'request':{'channel':'"+channel_select+"', 'quantity':'"+str(quantity_select)+"'}}")
117
+ # st.session_state.agent.addInput("{'request':{'channel':'acme-1', 'quantity':'100'}}")
118
+
119
+
120
+ st.markdown("---")
121
+ st.header("Agent Output:")
122
+ # Show the input to the user
123
+ io_col1, io_col2 = st.columns(2)
124
+ io_col1.text("Input sent to agent:")
125
+ io_col1.dataframe(data={'Channel':[channel_select],
126
+ 'Quantity':[quantity_select]})
127
+ io_col1.text_area(label="Raw JSON Input", value=st.session_state.agent.getInputAsJsonString(), height=200)
128
+
129
+ # Run the agent with the given input to get a corresponding memory
130
+ st.session_state.agent.setMaxOpCycles(1) # FIXME: Why can so many episodes get recalled?
131
+ st.session_state.agent.queryPatternCompletion()
132
+
133
+ output = st.session_state.agent.getOutputAsJsonString()
134
+ query_time_ms = st.session_state.agent.getLastQueryTime()/1000000.0
135
+ io_col2.text("Agent Response: ("+str(query_time_ms)+" ms)")
136
+ io_col2.dataframe(data=json_to_output_df(output))
137
+ io_col2.text_area(label="Raw JSON Output:",value=output, height=300)
shared_streamlit_funcs.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from distutils.command.config import config
2
+ import streamlit as st
3
+
4
+ import jnius_config
5
+ if not jnius_config.vm_running:
6
+ jnius_config.set_classpath("/Users/bstearn1/Documents/Projects/OptumGitHub/bstearn1/CognitiveReasonerLite/CognitiveReasonerLite.jar")
7
+ from jnius import autoclass # For running Java. See https://pyjnius.readthedocs.io/en/latest/ for documentation.
8
+
9
+
10
+ CRL_PACKAGE = "com.optum.cogtech.crl."
11
+
12
+ def make_agent(config_name="agent_demo"):
13
+ # Start the CRL engine
14
+ st.session_state["agent"] = st.session_state.Java_Agent()
15
+ # Configure the decision making
16
+ decConfig = st.session_state.Java_DecisionConfig(config_name)
17
+ decConfig.selectAll()
18
+ st.session_state.agent.addSettings(decConfig)
19
+ # Configure debug printing
20
+ # st.session_state.agent.logger.disable()
21
+ st.session_state.agent.logger.setWriteToFile(False)
22
+ st.session_state.agent.logger.setEnableLogCycles(True)
23
+ st.session_state.agent.logger.setEnableLogContexts(True)
24
+ st.session_state.agent.logger.setEnableLogOperators(True)
25
+ st.session_state.agent.logger.setEnableLogActivation(True)
26
+ return decConfig
27
+
28
+ def init():
29
+ # Define the Java<->Python interface needed to run the CRL jar
30
+ st.session_state["Java_ArrayList"] = autoclass('java.util.ArrayList')
31
+ st.session_state["Java_Agent"] = autoclass(CRL_PACKAGE+"Agent")
32
+ st.session_state["Java_DecisionConfig"] = autoclass(CRL_PACKAGE+"DecisionConfig")
33
+ st.session_state["Java_Concept"] = autoclass(CRL_PACKAGE+"Concept")
34
+ st.session_state["Java_ActionReportActiveConcept"] = autoclass(CRL_PACKAGE+"ActionReportActiveConcept")
35
+
36
+ make_agent()
37
+
38
+ def ReportActiveConceptActionInList(outputAttribute, attributeForReportValue):
39
+ collection = st.session_state.Java_ArrayList()
40
+ collection.add(st.session_state.Java_ActionReportActiveConcept(outputAttribute, attributeForReportValue))
41
+ return collection