Spaces:
Sleeping
Sleeping
File size: 4,264 Bytes
1a0cf07 8ff6b24 1a0cf07 8ff6b24 1a0cf07 8ff6b24 118d254 8ff6b24 1a0cf07 8ff6b24 1a0cf07 118d254 1a0cf07 8ff6b24 1a0cf07 8ff6b24 1a0cf07 8ff6b24 1a0cf07 8ff6b24 1a0cf07 474fca3 1a0cf07 8ff6b24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import json
import logging
import multiprocessing
import os
import gradio as gr
from swiftsage.agents import SwiftSage
from swiftsage.utils.commons import PromptTemplate, api_configs, setup_logging
from pkg_resources import resource_filename
def _parse_int_setting(value, label):
    """Return ``int(value)``, or raise a user-facing ``gr.Error`` naming ``label``.

    Gradio textboxes deliver strings, so numeric settings must be converted.
    Raising ``gr.Error`` (instead of letting ``ValueError`` escape) makes the
    Gradio UI display a readable message for bad input.
    """
    try:
        return int(value)
    except (TypeError, ValueError) as err:
        raise gr.Error(f"{label} must be an integer, got {value!r}.") from err


def solve_problem(problem, max_iterations, reward_threshold, swift_model_id, sage_model_id, feedback_model_id, use_retrieval, start_with_sage):
    """Run SwiftSage on one problem and return its reasoning and final answer.

    Used as the click handler of the "Solve Problem" button, so every argument
    is the current value of a UI component.

    Args:
        problem: Problem text to solve.
        max_iterations: Text that must parse as an integer; passed to
            ``SwiftSage.solve``.
        reward_threshold: Text that must parse as an integer; passed to
            ``SwiftSage.solve``. NOTE(review): presumably the feedback score
            at which iteration stops — confirm in ``SwiftSage.solve``.
        swift_model_id: Model ID for the Swift agent.
        sage_model_id: Model ID for the Sage agent.
        feedback_model_id: Model ID for the feedback (reward) model.
        use_retrieval: Forwarded to ``SwiftSage(use_retrieval=...)``.
        start_with_sage: Forwarded to ``SwiftSage(start_with_sage=...)``.

    Returns:
        ``(reasoning, solution)`` as returned by ``SwiftSage.solve``, except
        that every occurrence of ``"Answer (from running the code):\\n "`` in
        ``solution`` is replaced by a single space.

    Raises:
        gr.Error: If ``max_iterations`` or ``reward_threshold`` is not an integer.
    """
    max_iterations = _parse_int_setting(max_iterations, "Max Iterations")
    reward_threshold = _parse_int_setting(reward_threshold, "Reward Threshold")

    # All three models use the same 'Together' API configuration.
    together_api = api_configs['Together']
    swift_config = {"model_id": swift_model_id, "api_config": together_api}
    reward_config = {"model_id": feedback_model_id, "api_config": together_api}
    sage_config = {"model_id": sage_model_id, "api_config": together_api}

    # Prompt templates are resolved from the installed swiftsage package rather
    # than a path relative to the working directory.
    prompt_template_dir = resource_filename('swiftsage', 'prompt_templates')

    dataset = []
    embeddings = []  # TODO: for retrieval augmentation (not implemented yet now)

    s2 = SwiftSage(
        dataset,
        embeddings,
        prompt_template_dir,
        swift_config,
        sage_config,
        reward_config,
        use_retrieval=use_retrieval,
        start_with_sage=start_with_sage,
    )
    reasoning, solution = s2.solve(problem, max_iterations, reward_threshold)
    solution = solution.replace("Answer (from running the code):\n ", " ")
    return reasoning, solution
# Build the Gradio UI. Components render in the order they are created inside
# each context manager, so statement order here defines the page layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # gr.Markdown("## SwiftSage: A Multi-Agent Framework for Reasoning")
    # use the html and center the title
    gr.HTML("<h1 style='text-align: center;'>SwiftSage: A Multi-Agent Framework for Reasoning</h1>")
    # One row with the model ID for each of the three roles.
    with gr.Row():
        swift_model_id = gr.Textbox(label="😄 Swift Model ID", value="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo")
        feedback_model_id = gr.Textbox(label="🤔 Feedback Model ID", value="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo")
        sage_model_id = gr.Textbox(label="😎 Sage Model ID", value="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo")
    # NOTE(review): an earlier note here read "the following two should have a
    # smaller width"; it is unclear which two components it meant — confirm.
    # Collapsed-by-default panel with the tuning settings.
    with gr.Accordion(label="⚙️ Advanced Options", open=False):
        with gr.Row():
            with gr.Column():
                # Entered as text; solve_problem converts these with int().
                max_iterations = gr.Textbox(label="Max Iterations", value="5")
                reward_threshold = gr.Textbox(label="Reward Threshold", value="8")
            # TODO: add top-p and temperature for each module for controlling
            # NOTE: the top-p / temperature boxes below are displayed but are
            # NOT in the click handler's inputs, so editing them has no effect.
            with gr.Column():
                top_p_swift = gr.Textbox(label="Top-p for Swift", value="0.9")
                temperature_swift = gr.Textbox(label="Temperature for Swift", value="0.7")
            with gr.Column():
                top_p_sage = gr.Textbox(label="Top-p for Sage", value="0.9")
                temperature_sage = gr.Textbox(label="Temperature for Sage", value="0.7")
            with gr.Column():
                top_p_reward = gr.Textbox(label="Top-p for Feedback", value="0.9")
                temperature_reward = gr.Textbox(label="Temperature for Feedback", value="0.7")
        # Hidden (visible=False) so users cannot change them, but they are
        # still passed to the handler, which therefore always receives False.
        use_retrieval = gr.Checkbox(label="Use Retrieval Augmentation", value=False, visible=False)
        start_with_sage = gr.Checkbox(label="Start with Sage", value=False, visible=False)
    problem = gr.Textbox(label="Input your problem", value="How many letter r are there in the sentence 'My strawberry is so ridiculously red.'?", lines=2)
    solve_button = gr.Button("🚀 Solve Problem")
    # Read-only outputs filled from the (reasoning, solution) pair returned by solve_problem.
    reasoning_output = gr.Textbox(label="Reasoning steps with Code", interactive=False)
    solution_output = gr.Textbox(label="Final answer", interactive=False)
    # The inputs list order must match solve_problem's positional parameters.
    solve_button.click(
        solve_problem,
        inputs=[problem, max_iterations, reward_threshold, swift_model_id, sage_model_id, feedback_model_id, use_retrieval, start_with_sage],
        outputs=[reasoning_output, solution_output]
    )
if __name__ == '__main__':
    # Create the log directory if needed. exist_ok=True replaces the racy
    # "check then create" pattern: no error if it already exists as a directory
    # (a plain file named 'logs' now fails fast with FileExistsError).
    os.makedirs('logs', exist_ok=True)
    # The start method may be set at most once per program, which is why it
    # lives under the __main__ guard.
    multiprocessing.set_start_method('spawn')
    # Serve locally without a public share link and without the API docs link.
    demo.launch(share=False, show_api=False)
|