File size: 6,293 Bytes
437883c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import pandas as pd
import json
import random
import streamlit as st
import shared_streamlit_funcs as my
def generate_episodes(num_channels=10, num_quantities=10, num_latencies=6, records_per_channel=3):
    """Generate random episodes, train a fresh agent on them, and store the table.

    Every channel receives ``records_per_channel`` rows. Within a channel the
    quantities are distinct; each row gets a randomly chosen latency.

    The following ``st.session_state`` entries are (re)written:
    ``channels_list``, ``quantities_list``, ``latencies_list``,
    ``channel_concept_map``, ``quantity_concept_map``, ``latency_concept_map``
    and ``er_df`` (the episode table, sorted by channel then quantity).

    Args:
        num_channels: Number of channel names to create ('acme-1', 'acme-2', ...).
        num_quantities: Number of quantity options (powers of ten, as strings).
        num_latencies: Number of latency options ('8h', '12h', '24h', '48h', ...).
        records_per_channel: Rows generated per channel. Capped at the number
            of available quantities, because quantities never repeat within
            a channel.
    """
    # Number widgets can hand back floats; range() and sampling need ints.
    num_channels = int(num_channels)
    num_quantities = int(num_quantities)
    num_latencies = int(num_latencies)
    records_per_channel = int(records_per_channel)

    state = st.session_state

    # Define the possible inputs for the agent
    channels = ['acme-' + str(n + 1) for n in range(num_channels)]
    quantities = [str(10 ** n) for n in range(num_quantities)]
    latencies = ['8h', '12h']
    latencies.extend(str(n * 24) + "h" for n in range(1, num_latencies - 1))
    # The two sub-day entries are always seeded, so trim the list when fewer
    # were requested. At least one latency is kept so random.choice below
    # always has something to pick from.
    latencies = latencies[:max(num_latencies, 1)]

    state["channels_list"] = channels
    state["quantities_list"] = quantities
    state["latencies_list"] = latencies

    # Quantities are unique per channel, so a channel can never hold more rows
    # than there are quantities. Sampling without replacement replaces the
    # retry loop, which could not terminate once the pool was exhausted.
    picks_per_channel = max(0, min(records_per_channel, len(quantities)))

    channels_column = []
    quantities_column = []
    latencies_column = []
    for channel in channels:
        for quantity in random.sample(quantities, picks_per_channel):
            channels_column.append(channel)
            quantities_column.append(quantity)
            latencies_column.append(random.choice(latencies))

    # Remake the agent as a quick way of resetting its episodic memory
    my.make_agent()
    # Read the agent only after make_agent(): presumably it replaces
    # st.session_state.agent -- confirm in shared_streamlit_funcs.
    agent = state.agent

    def concept_map(field, values):
        # One agent concept per option, keyed by the option's display string.
        return {
            value: agent.getConcept("{'request':{'" + field + "':'" + value + "'}}")
            for value in values
        }

    # Define the input concepts for this agent
    channel_concepts = concept_map('channel', channels)
    quantity_concepts = concept_map('quantity', quantities)
    latency_concepts = concept_map('latency', latencies)
    state["channel_concept_map"] = channel_concepts
    state["quantity_concept_map"] = quantity_concepts
    state["latency_concept_map"] = latency_concepts

    # Send each row in the table to the agent as an episode of experience
    for channel, quantity, latency in zip(channels_column, quantities_column, latencies_column):
        episode = state.Java_ArrayList()
        episode.add(channel_concepts[channel])
        episode.add(quantity_concepts[quantity])
        episode.add(latency_concepts[latency])
        agent.trainEpisode(episode)

    # Compile the dataframe from these columns
    episodes = pd.DataFrame(data={'Channel': channels_column,
                                  'Quantity': quantities_column,
                                  'Latency': latencies_column})
    episodes.sort_values(['Channel', 'Quantity'], ascending=[True, True], inplace=True)
    state["er_df"] = episodes
def json_to_output_df(json_str):
    """Convert the agent's JSON output into a Channel/Quantity/Latency table.

    Args:
        json_str: JSON text expected to decode to an object holding
            'channel', 'quantity' and 'latency' entries.

    Returns:
        A DataFrame with 'Channel', 'Quantity' and 'Latency' columns. An empty
        DataFrame is returned (and a message printed) when the text is not
        valid JSON, is not a JSON object, or lacks any of the three entries.
    """
    try:
        data = json.loads(json_str)
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError; TypeError covers a None input.
        print("** Agent output is not valid JSON.")
        return pd.DataFrame()

    required_keys = ('channel', 'quantity', 'latency')
    if not isinstance(data, dict) or any(key not in data for key in required_keys):
        print("** Missing expected data from agent output.")
        return pd.DataFrame()

    columns = {'Channel': data['channel'],
               'Quantity': data['quantity'],
               'Latency': data['latency']}
    # pandas refuses a dict made only of scalars (it needs an index), so a
    # single record is wrapped into one-element columns.
    if not any(isinstance(value, (list, dict)) for value in columns.values()):
        columns = {name: [value] for name, value in columns.items()}
    return pd.DataFrame(data=columns)
# Initialize the session: create the agent and the first batch of episodes
# only once; later reruns of the script reuse what is in session state.
if "agent" not in st.session_state:
    my.init()
if "er_df" not in st.session_state:
    generate_episodes()

## SET UP STREAMLIT PAGE
# emojis: https://www.webfx.com/tools/emoji-cheat-sheet/
st.set_page_config(page_title="🧠CRL Demo", layout="centered")
st.subheader("Cognitive Reasoner Lite Demo")
st.title("Episodic Recall")
st.markdown("**Demonstrates automatic recall of prior experiences given a partial prompt.**")

# --- Episode table (left) and regeneration controls (right) ---
st.header("Episodes shown to the agent:")
st.markdown("*(randomly generated)*")
eps_col1, eps_col2 = st.columns(2)
eps_col1.dataframe(st.session_state.er_df)
eps_col1.text("Number of rows: " + str(len(st.session_state.er_df)))
epgen_cha_count = eps_col2.number_input(label="Number of Channels", min_value=1, value=10)
epgen_qua_count = eps_col2.number_input(label="Number of Quantities", min_value=1, value=6)
epgen_lat_count = eps_col2.number_input(label="Number of Latencies", min_value=1, value=6)
# Quantities are unique within a channel, so a channel cannot have more
# records than there are quantities. The default must also stay within
# max_value, otherwise the widget is rejected when fewer than 3 quantities
# are selected.
max_recs_per_channel = int(epgen_qua_count)
epgen_recs_per_channel = eps_col2.number_input(label="Records per Channel",
                                               min_value=1,
                                               max_value=max_recs_per_channel,
                                               value=min(3, max_recs_per_channel))
if eps_col2.button("Regenerate Episodes"):
    generate_episodes(int(epgen_cha_count),
                      int(epgen_qua_count),
                      int(epgen_lat_count),
                      int(epgen_recs_per_channel))
    # Restart the script so the table above shows the new episodes.
    # st.rerun supersedes st.experimental_rerun; fall back for older releases.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()

st.markdown("---")
st.header("Input:")
st.subheader("This will be shown to the agent.")
select_col1, select_col2 = st.columns(2)
channel_select = select_col1.selectbox(label="Query Channel:", options=st.session_state.channels_list)
quantity_select = select_col2.selectbox(label="Query Quantity:", options=st.session_state.quantities_list)

# Send a partial pattern to the agent's input (channel and quantity only;
# no latency is supplied).
st.session_state.agent.clearInput()
st.session_state.agent.addInput("{'request':{'channel':'"+channel_select+"', 'quantity':'"+str(quantity_select)+"'}}")
# st.session_state.agent.addInput("{'request':{'channel':'acme-1', 'quantity':'100'}}")

st.markdown("---")
st.header("Agent Output:")
# Show the input to the user
io_col1, io_col2 = st.columns(2)
io_col1.text("Input sent to agent:")
io_col1.dataframe(data={'Channel': [channel_select],
                        'Quantity': [quantity_select]})
io_col1.text_area(label="Raw JSON Input", value=st.session_state.agent.getInputAsJsonString(), height=200)

# Run the agent with the given input to get a corresponding memory
st.session_state.agent.setMaxOpCycles(1)  # FIXME: Why can so many episodes get recalled?
st.session_state.agent.queryPatternCompletion()
output = st.session_state.agent.getOutputAsJsonString()
# Presumably getLastQueryTime() reports nanoseconds, hence the division to
# milliseconds -- confirm against the agent API.
query_time_ms = st.session_state.agent.getLastQueryTime() / 1000000.0
io_col2.text("Agent Response: (" + str(query_time_ms) + " ms)")
io_col2.dataframe(data=json_to_output_df(output))
io_col2.text_area(label="Raw JSON Output:", value=output, height=300)
|