crl-demo / pages /Episode_Recall_Demo.py
bryan-stearns
Adding demo files
437883c
raw
history blame contribute delete
No virus
6.29 kB
import pandas as pd
import json
import random
import streamlit as st
import shared_streamlit_funcs as my
def generate_episodes(num_channels=10, num_quantities=10, num_latencies=6, records_per_channel=3):
    """Generate random episodes, train a freshly built agent on them, and store them for display.

    Each channel gets ``records_per_channel`` rows. Each row has a distinct
    quantity and a randomly chosen latency. The following are written to
    ``st.session_state``:

    - the value lists ``channels_list``, ``quantities_list`` and ``latencies_list``;
    - the per-value concept maps;
    - the resulting table ``er_df``, sorted by Channel then Quantity.

    Args:
        num_channels: Number of channels, named 'acme-1' .. 'acme-<num_channels>'.
        num_quantities: Number of quantities: '1', '10', '100', ... (powers of ten).
        num_latencies: Number of latencies, taken in order from
            '8h', '12h', '24h', '48h', '72h', ...
        records_per_channel: Number of distinct quantities sampled per channel.

    Raises:
        ValueError: If ``num_latencies`` is below 1, or if ``records_per_channel``
            exceeds the number of quantities (no distinct sample of that size
            exists; the previous implementation looped forever in that case).
            Nothing in session state is modified when this is raised.
    """
    # Define the possible inputs for the agent
    channels = ['acme-' + str(n + 1) for n in range(num_channels)]
    quantities = [str(10 ** n) for n in range(num_quantities)]
    # '8h' and '12h' followed by whole days (24h, 48h, ...). The slice keeps exactly
    # num_latencies entries, so num_latencies=1 yields only '8h' instead of two values.
    latencies = (['8h', '12h'] + [str(n * 24) + 'h' for n in range(1, num_latencies - 1)])[:num_latencies]

    # Validate before touching session state so a bad request leaves the page intact.
    if not latencies:
        raise ValueError("num_latencies must be at least 1, got %r" % (num_latencies,))
    if records_per_channel > len(quantities):
        raise ValueError("records_per_channel (%r) cannot exceed num_quantities (%r)"
                         % (records_per_channel, num_quantities))

    st.session_state["channels_list"] = channels
    st.session_state["quantities_list"] = quantities
    st.session_state["latencies_list"] = latencies

    # Build the table column-wise. Each channel gets records_per_channel rows.
    # random.sample draws the quantities without replacement, so they are distinct
    # within a channel; each row then gets an independently chosen random latency.
    channels_column = []
    quantities_column = []
    latencies_column = []
    for channel in channels:
        for q in random.sample(quantities, records_per_channel):
            channels_column.append(channel)
            quantities_column.append(q)
            latencies_column.append(random.choice(latencies))

    # Remake the agent as a quick way of resetting its episodic memory
    my.make_agent()
    # Read the agent only after make_agent(), which replaces the one in session state.
    agent = st.session_state.agent

    # Define the input concepts for this agent
    channel_concepts = {channel: agent.getConcept("{'request':{'channel':'" + channel + "'}}")
                        for channel in channels}
    quantity_concepts = {quantity: agent.getConcept("{'request':{'quantity':'" + quantity + "'}}")
                         for quantity in quantities}
    latency_concepts = {latency: agent.getConcept("{'request':{'latency':'" + latency + "'}}")
                        for latency in latencies}
    st.session_state["channel_concept_map"] = channel_concepts
    st.session_state["quantity_concept_map"] = quantity_concepts
    st.session_state["latency_concept_map"] = latency_concepts

    # Send each row in the table to the agent as an episode of experience
    for channel, quantity, latency in zip(channels_column, quantities_column, latencies_column):
        episode = st.session_state.Java_ArrayList()
        episode.add(channel_concepts[channel])
        episode.add(quantity_concepts[quantity])
        episode.add(latency_concepts[latency])
        agent.trainEpisode(episode)

    # Compile the dataframe from these columns
    st.session_state["er_df"] = pd.DataFrame(data={'Channel': channels_column,
                                                   'Quantity': quantities_column,
                                                   'Latency': latencies_column})
    st.session_state.er_df.sort_values(['Channel', 'Quantity'], ascending=[True, True], inplace=True)
def json_to_output_df(json_str):
    """Convert the agent's JSON output into a Channel/Quantity/Latency DataFrame.

    Args:
        json_str: JSON text from the agent. The top level must be an object
            with 'channel', 'quantity' and 'latency' keys. The values are
            assumed to be either parallel lists (one entry per recalled row)
            or single scalars. This assumption is not checked against the
            agent here; TODO confirm.

    Returns:
        A DataFrame with columns Channel, Quantity and Latency. Returns an empty
        DataFrame (after printing a notice) if the output is not a JSON object
        containing all three keys.

    Raises:
        json.JSONDecodeError: If json_str is not valid JSON.
    """
    data = json.loads(json_str)
    # Guard the type first: membership tests on None raise TypeError, and on a
    # JSON string they would do substring matching instead of key lookup.
    if not isinstance(data, dict) or any(key not in data for key in ('channel', 'quantity', 'latency')):
        print("** Missing expected data from agent output.")
        return pd.DataFrame()
    columns = {'Channel': data['channel'],
               'Quantity': data['quantity'],
               'Latency': data['latency']}
    # pandas refuses to build a frame from all-scalar values without an index.
    # A row reported as plain scalars is therefore wrapped into one-element
    # columns. List values are passed through unchanged.
    if not any(isinstance(value, (list, dict)) for value in columns.values()):
        columns = {name: [value] for name, value in columns.items()}
    return pd.DataFrame(data=columns)
# Initialize the session: run one-time setup (my.init) when no agent is stored yet,
# and generate an initial episode table the first time this page runs.
if "agent" not in st.session_state:
    my.init()
if "er_df" not in st.session_state:
    generate_episodes()

## SET UP STREAMLIT PAGE
# emojis: https://www.webfx.com/tools/emoji-cheat-sheet/
st.set_page_config(page_title="🧠CRL Demo", layout="centered")
st.subheader("Cognitive Reasoner Lite Demo")
st.title("Episodic Recall")
st.markdown("**Demonstrates automatic recall of prior experiences given a partial prompt.**")

st.header("Episodes shown to the agent:")
st.markdown("*(randomly generated)*")
eps_col1, eps_col2 = st.columns(2)
eps_col1.dataframe(st.session_state.er_df)
eps_col1.text("Number of rows: "+str(len(st.session_state.er_df)))

# Controls for regenerating the episode table.
epgen_cha_count = eps_col2.number_input(label="Number of Channels", min_value=1, value=10)
epgen_qua_count = eps_col2.number_input(label="Number of Quantities", min_value=1, value=6)
epgen_lat_count = eps_col2.number_input(label="Number of Latencies", min_value=1, value=6)
# Quantities within a channel are distinct, so records per channel are capped at the
# quantity count. The default must respect that cap too: Streamlit rejects a value
# above max_value, which crashed the page whenever fewer than 3 quantities were chosen.
epgen_recs_max = int(epgen_qua_count)
epgen_recs_per_channel = eps_col2.number_input(label="Records per Channel", min_value=1,
                                               max_value=epgen_recs_max, value=min(3, epgen_recs_max))
if eps_col2.button("Regenerate Episodes"):
    generate_episodes(epgen_cha_count, epgen_qua_count, epgen_lat_count, epgen_recs_per_channel)
    # st.rerun superseded the deprecated st.experimental_rerun. Fall back to the
    # old name only on Streamlit versions that predate st.rerun.
    if hasattr(st, "rerun"):
        st.rerun()
    else:
        st.experimental_rerun()
st.markdown("---")
st.header("Input:")
st.subheader("This will be shown to the agent.")
select_col1, select_col2 = st.columns(2)
channel_select = select_col1.selectbox(label="Query Channel:", options=st.session_state.channels_list)
quantity_select = select_col2.selectbox(label="Query Quantity:", options=st.session_state.quantities_list)

# Send a partial pattern (channel and quantity only, no latency) to the agent's input.
# The query below asks the agent to complete it from its trained episodes.
st.session_state.agent.clearInput()
st.session_state.agent.addInput("{'request':{'channel':'"+channel_select+"', 'quantity':'"+str(quantity_select)+"'}}")

st.markdown("---")
st.header("Agent Output:")
# Show the input to the user
io_col1, io_col2 = st.columns(2)
io_col1.text("Input sent to agent:")
io_col1.dataframe(data={'Channel': [channel_select],
                        'Quantity': [quantity_select]})
io_col1.text_area(label="Raw JSON Input", value=st.session_state.agent.getInputAsJsonString(), height=200)

# Run the agent with the given input to get a corresponding memory
st.session_state.agent.setMaxOpCycles(1)  # FIXME: Why can so many episodes get recalled?
st.session_state.agent.queryPatternCompletion()
output = st.session_state.agent.getOutputAsJsonString()
# Divided by 1e6 for display in ms (presumably getLastQueryTime reports nanoseconds; TODO confirm)
query_time_ms = st.session_state.agent.getLastQueryTime()/1000000.0
io_col2.text("Agent Response: ("+str(query_time_ms)+" ms)")
io_col2.dataframe(data=json_to_output_df(output))
io_col2.text_area(label="Raw JSON Output:", value=output, height=300)