|
import pandas as pd |
|
import json |
|
import random |
|
import streamlit as st |
|
|
|
import shared_streamlit_funcs as my |
|
|
|
|
|
def generate_episodes(num_channels=10, num_quantities=10, num_latencies=6, records_per_channel=3):
    """Randomly generate episodes, train the agent on them and cache the results.

    Builds the channel / quantity / latency vocabularies, samples
    ``records_per_channel`` distinct quantities for every channel (each paired
    with a randomly chosen latency), creates a fresh agent via
    ``my.make_agent()`` and trains it with one episode per generated record.

    Args:
        num_channels: Number of channels, named ``acme-1`` .. ``acme-<N>``.
        num_quantities: Number of quantities; the i-th one is ``str(10 ** i)``.
        num_latencies: Number of latency labels, taken in order from
            ``8h, 12h, 24h, 48h, 72h, ...``.
        records_per_channel: Number of distinct quantities recorded per channel.

    Session-state keys written: ``channels_list``, ``quantities_list``,
    ``latencies_list``, ``channel_concept_map``, ``quantity_concept_map``,
    ``latency_concept_map`` and, last of all, ``er_df`` (the episode table
    sorted by Channel, then Quantity).

    Raises:
        ValueError: if ``num_latencies`` is below 1, or ``records_per_channel``
            is negative or larger than ``num_quantities`` (previously the
            latter made the sampling loop spin forever).
    """
    # Widget values may arrive as floats; range() and slicing need ints.
    num_channels = int(num_channels)
    num_quantities = int(num_quantities)
    num_latencies = int(num_latencies)
    records_per_channel = int(records_per_channel)

    # Validate before touching session state so a bad request leaves the
    # previously generated episodes intact.
    if num_latencies < 1:
        raise ValueError("num_latencies must be at least 1, got " + str(num_latencies))
    if not 0 <= records_per_channel <= num_quantities:
        raise ValueError(
            "records_per_channel must be between 0 and num_quantities ("
            + str(num_quantities) + "), got " + str(records_per_channel)
        )

    channels = ['acme-' + str(n + 1) for n in range(num_channels)]
    quantities = [str(10 ** n) for n in range(num_quantities)]
    # 8h and 12h first, then whole days (24h, 48h, ...). Truncate so exactly
    # num_latencies labels exist even when num_latencies < 2.
    latencies = (['8h', '12h'] + [str(n * 24) + 'h' for n in range(1, num_latencies - 1)])[:num_latencies]

    st.session_state["channels_list"] = channels
    st.session_state["quantities_list"] = quantities
    st.session_state["latencies_list"] = latencies

    channels_column = []
    quantities_column = []
    latencies_column = []
    for channel in channels:
        # Distinct quantities per channel: sampling without replacement.
        for quantity in random.sample(quantities, records_per_channel):
            channels_column.append(channel)
            quantities_column.append(quantity)
            latencies_column.append(random.choice(latencies))

    # make_agent() is expected to (re)create st.session_state.agent, which is
    # read right after; NOTE(review): confirm in shared_streamlit_funcs.
    my.make_agent()
    agent = st.session_state.agent

    def _concept_map(field, values):
        # One agent concept per value, keyed by the value itself.
        return {value: agent.getConcept("{'request':{'" + field + "':'" + value + "'}}") for value in values}

    channel_map = _concept_map('channel', channels)
    quantity_map = _concept_map('quantity', quantities)
    latency_map = _concept_map('latency', latencies)
    st.session_state["channel_concept_map"] = channel_map
    st.session_state["quantity_concept_map"] = quantity_map
    st.session_state["latency_concept_map"] = latency_map

    # Each record becomes one training episode of three concepts.
    for channel, quantity, latency in zip(channels_column, quantities_column, latencies_column):
        in_list = st.session_state.Java_ArrayList()
        in_list.add(channel_map[channel])
        in_list.add(quantity_map[quantity])
        in_list.add(latency_map[latency])
        agent.trainEpisode(in_list)

    er_df = pd.DataFrame(data={'Channel': channels_column,
                               'Quantity': quantities_column,
                               'Latency': latencies_column})
    er_df.sort_values(['Channel', 'Quantity'], ascending=[True, True], inplace=True)
    # Written last: the page uses the presence of "er_df" to decide whether
    # episodes still need generating.
    st.session_state["er_df"] = er_df
|
|
|
def json_to_output_df(json_str):
    """Convert the agent's JSON output into a Channel/Quantity/Latency table.

    Args:
        json_str: JSON text produced by the agent. Expected to be an object
            with ``channel``, ``quantity`` and ``latency`` entries, each either
            a list (one row per element) or a single scalar (one row).

    Returns:
        A ``pandas.DataFrame`` with columns ``Channel``, ``Quantity`` and
        ``Latency``. Returns an empty DataFrame in these cases:

        - the text is not valid JSON;
        - the text is not a JSON object;
        - any of the three keys is missing;
        - the values cannot form a table (e.g. lists of different lengths).

        A diagnostic line starting with ``**`` is printed in each of those cases.
    """
    try:
        data = json.loads(json_str)
    except (TypeError, ValueError) as err:  # JSONDecodeError is a ValueError
        print("** Could not parse agent output as JSON: " + str(err))
        return pd.DataFrame()

    required = ('channel', 'quantity', 'latency')
    if not isinstance(data, dict) or any(key not in data for key in required):
        print("** Missing expected data from agent output.")
        return pd.DataFrame()

    columns = {'Channel': data['channel'],
               'Quantity': data['quantity'],
               'Latency': data['latency']}
    # pandas refuses to build a frame from scalars alone ("must pass an
    # index"), so a single scalar record is promoted to a one-row table.
    if not any(isinstance(value, (list, dict)) for value in columns.values()):
        columns = {name: [value] for name, value in columns.items()}

    try:
        return pd.DataFrame(data=columns)
    except ValueError as err:  # e.g. columns of different lengths
        print("** Could not build a table from agent output: " + str(err))
        return pd.DataFrame()
|
|
|
|
|
# Streamlit requires set_page_config to be the first Streamlit command of the
# script, so configure the page before any session/agent setup runs.
st.set_page_config(page_title="🧠CRL Demo", layout="centered")

# Generator settings shared by the first-run episode generation and the
# "Regenerate Episodes" controls, so the controls match the data on screen.
DEFAULT_NUM_CHANNELS = 10
DEFAULT_NUM_QUANTITIES = 10
DEFAULT_NUM_LATENCIES = 6
DEFAULT_RECORDS_PER_CHANNEL = 3

# One-time setup per browser session.
if "agent" not in st.session_state:
    my.init()

if "er_df" not in st.session_state:
    generate_episodes(DEFAULT_NUM_CHANNELS, DEFAULT_NUM_QUANTITIES,
                      DEFAULT_NUM_LATENCIES, DEFAULT_RECORDS_PER_CHANNEL)

st.subheader("Cognitive Reasoner Lite Demo")
st.title("Episodic Recall")
st.markdown("**Demonstrates automatic recall of prior experiences given a partial prompt.**")

# --- Episodes the agent was trained on ------------------------------------
st.header("Episodes shown to the agent:")
st.markdown("*(randomly generated)*")

eps_col1, eps_col2 = st.columns(2)
eps_col1.dataframe(st.session_state.er_df)
eps_col1.text("Number of rows: " + str(len(st.session_state.er_df)))

epgen_cha_count = eps_col2.number_input(label="Number of Channels", min_value=1, value=DEFAULT_NUM_CHANNELS)
epgen_qua_count = eps_col2.number_input(label="Number of Quantities", min_value=1, value=DEFAULT_NUM_QUANTITIES)
epgen_lat_count = eps_col2.number_input(label="Number of Latencies", min_value=1, value=DEFAULT_NUM_LATENCIES)
# Records per channel cannot exceed the number of quantities (each record uses
# a distinct quantity). The default is clamped too: Streamlit raises if a
# number_input's value lies above its max_value.
epgen_recs_per_channel = eps_col2.number_input(
    label="Records per Channel",
    min_value=1,
    max_value=int(epgen_qua_count),
    value=min(DEFAULT_RECORDS_PER_CHANNEL, int(epgen_qua_count)),
)

if eps_col2.button("Regenerate Episodes"):
    generate_episodes(epgen_cha_count, epgen_qua_count, epgen_lat_count, epgen_recs_per_channel)
    # st.experimental_rerun is deprecated in favour of st.rerun; use the new
    # name when this Streamlit version provides it.
    (getattr(st, "rerun", None) or st.experimental_rerun)()

# --- Partial prompt -------------------------------------------------------
st.markdown("---")
st.header("Input:")
st.subheader("This will be shown to the agent.")

select_col1, select_col2 = st.columns(2)
channel_select = select_col1.selectbox(label="Query Channel:", options=st.session_state.channels_list)
quantity_select = select_col2.selectbox(label="Query Quantity:", options=st.session_state.quantities_list)

agent = st.session_state.agent
# Only channel and quantity are given; the latency is left for the agent to recall.
agent.clearInput()
agent.addInput("{'request':{'channel':'" + channel_select + "', 'quantity':'" + str(quantity_select) + "'}}")

# --- Agent response -------------------------------------------------------
st.markdown("---")
st.header("Agent Output:")

io_col1, io_col2 = st.columns(2)
io_col1.text("Input sent to agent:")
io_col1.dataframe(data={'Channel': [channel_select],
                        'Quantity': [quantity_select]})
io_col1.text_area(label="Raw JSON Input", value=agent.getInputAsJsonString(), height=200)

agent.setMaxOpCycles(1)
agent.queryPatternCompletion()

output = agent.getOutputAsJsonString()
# Divided by 1e6 to display milliseconds; presumably getLastQueryTime()
# reports nanoseconds — NOTE(review): confirm against the agent API.
query_time_ms = agent.getLastQueryTime() / 1000000.0
io_col2.text("Agent Response: (" + str(query_time_ms) + " ms)")
io_col2.dataframe(data=json_to_output_df(output))
io_col2.text_area(label="Raw JSON Output:", value=output, height=300)
|
|