|
import streamlit as st |
|
import openai |
|
import os |
|
import sys |
|
import argparse |
|
sys.path.append('./lats') |
|
from lats_main import lats_main |
|
|
|
# Use the wide page layout for the whole app.
st.set_page_config(layout="wide")

# Seed the session slot once; later reruns keep whatever value is already stored.
if "response_content" not in st.session_state:
    st.session_state["response_content"] = None
|
|
|
|
|
# Main-page container holding the title, the intro text and (later) the chat widgets.
chat_col = st.container()
chat_col.title("SambaLATS")

# Raw string on purpose: the sample problem contains literal "\[" and "\]"
# sequences, which are invalid escapes in a regular string literal. Using a raw
# literal keeps the rendered text the same while avoiding the escape warning.
# The value bounds in the sample problem are written as 10^6 (the exponent had
# been lost, leaving "106").
description = r"""This demo is an implementation of Language Agent Tree Search (LATS) (https://arxiv.org/abs/2310.04406) with Samba-1 in the backend. Thank you to the original authors of demo on which this is based from [Lapis Labs](https://lapis.rocks/)!

Given Samba-1's lightning quick inference, not only can we accelerate our system's speeds but also improve our system's accuracy. Using many inference calls in this LATS style, we can solve programming questions with higher accuracy. In fact, this system reaches **GPT-3.5 accuracy on HumanEval Python**, 74% accuracy, with LLaMa 3 8B, taking 8 seconds on average. This is a 15.5% boost on LLaMa 3 8B alone.

Listed below is an example programming problem (https://leetcode.com/problems/median-of-two-sorted-arrays/description/) to get started with.

```python
Given two sorted arrays `nums1` and `nums2` of size `m` and `n` respectively, return **the median** of the two sorted arrays. The overall run time complexity should be `O(log (m+n))`. **Example 1:** **Input:** nums1 = \[1,3\], nums2 = \[2\] **Output:** 2.00000 **Explanation:** merged array = \[1,2,3\] and median is 2. **Example 2:** **Input:** nums1 = \[1,2\], nums2 = \[3,4\] **Output:** 2.50000 **Explanation:** merged array = \[1,2,3,4\] and median is (2 + 3) / 2 = 2.5. **Constraints:** * `nums1.length == m` * `nums2.length == n` * `0 <= m <= 1000` * `0 <= n <= 1000` * `1 <= m + n <= 2000` * `-10^6 <= nums1[i], nums2[i] <= 10^6`
```
"""

chat_col.markdown(description)
|
sidebar = st.sidebar
sidebar.title("From SambaNova Systems")

# Search parameters live in an expander that starts collapsed. The min/max
# bounds cap what the user can select for each value.
parameters_section = sidebar.expander("Parameters", expanded=False)
tree_width = parameters_section.number_input("Tree Width", min_value=1, max_value=5, value=1)
tree_depth = parameters_section.number_input("Tree Depth", min_value=1, max_value=8, value=3)
iterations = parameters_section.number_input("Iterations", min_value=1, max_value=4, value=2)

sidebar.markdown('<hr style="margin-top: 0.5rem; margin-bottom: 0.5rem;">', unsafe_allow_html=True)

# Sidebar container that receives the progress messages of a run. It is created
# exactly once, in the sidebar; binding the same name to a main-page container
# beforehand would be a dead store, because this assignment replaces it.
runtime_container = sidebar.container()
# NOTE(review): .empty() adds a placeholder element and the handle is discarded;
# it does not appear to clear the container. Kept as-is; confirm the intent.
runtime_container.empty()

# Not referenced anywhere else in the visible part of this file.
runtime_messages = []
|
|
|
def make_args(instruction: str, tree_depth: int, tree_width: int, iterations: int) -> argparse.Namespace:
    """Build the argument namespace for a single LATS run.

    The values chosen in the UI are installed as argparse defaults, and the
    parser is then run against an empty argument list, so the returned
    namespace always reflects the function parameters.

    Args:
        instruction: Problem text; stored as ``instruction``.
        tree_depth: Stored as ``depth``.
        tree_width: Stored as ``n_samples`` (nodes added during expansion).
        iterations: Stored as ``max_iters``.

    Returns:
        A namespace with ``strategy="mcts"``, ``language="py"``,
        ``verbose=False``, ``is_leetcode=False`` and the values above.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--strategy", default="mcts", help="Strategy to use")
    parser.add_argument("--language", default="py", help="Programming language")
    parser.add_argument("--max_iters", default=iterations, help="Maximum iterations")
    parser.add_argument("--instruction", default=instruction, help="Instruction text")
    parser.add_argument("--verbose", action="store_true", help="Verbose output")
    parser.add_argument("--is_leetcode", action="store_true",
                        help="To run the leetcode benchmark")
    parser.add_argument("--n_samples", type=int,
                        help="The number of nodes added during expansion", default=tree_width)
    parser.add_argument("--depth", type=int,
                        help="Tree depth", default=tree_depth)

    # Parse an explicitly empty argv. With no argument, parse_args() reads
    # sys.argv, so flags given to the hosting process could override the UI
    # values, and an unknown flag would abort the run with SystemExit.
    return parser.parse_args([])
|
|
|
def run_querry():
    """Run LATS on the current message and render the result in the chat column.

    Reads the module-level ``user_input``, ``tree_depth``, ``tree_width`` and
    ``iterations`` values. While ``lats_main`` runs, ``sys.stdout`` is pointed
    at ``runtime_container`` so that text written to stdout goes through the
    container's ``write`` method.

    Returns:
        The value returned by ``lats_main``, or ``None`` when ``user_input``
        is empty (nothing is run in that case).
    """
    if not user_input:
        return None

    runtime_container.write("Initiating process...")

    # Build the arguments before touching stdout, so that a failure here
    # leaves the process state unchanged.
    args = make_args(user_input, tree_depth, tree_width, iterations)
    args.model = 'samba'

    old_stdout = sys.stdout
    sys.stdout = runtime_container
    try:
        with chat_col:
            with st.spinner('Running...'):
                response = lats_main(args)
    finally:
        # Always restore the real stdout; otherwise an exception in lats_main
        # would leave every later print routed to the Streamlit container.
        sys.stdout = old_stdout

    runtime_container.write("Response fetched.")
    chat_col.markdown('<hr style="margin-top: 0.5rem; margin-bottom: 0.5rem;">', unsafe_allow_html=True)
    chat_col.write(f"```python\n{response} \n")

    return response
|
|
|
|
|
with chat_col:
    # The label is collapsed, so only the placeholder text is shown to the user.
    user_input = st.text_area("Enter your message here:", placeholder="Type your message here...", label_visibility="collapsed")
    button = st.button("Send")

    if button:
        # Treat whitespace-only text as missing too; otherwise a blank message
        # would start a full run with no actual problem statement.
        if not user_input.strip():
            st.warning("Missing a coding problem")
        else:
            run_querry()
|
|
|
|