import gradio as gr
from llm_inference import LLMInferenceNode
import random
# Header text shown at the top of the UI (passed to gr.HTML in create_interface).
# NOTE(review): gr.HTML renders this as HTML, so the newlines below presumably
# collapse and the lines run together; the bracketed names also look like link
# text whose URLs were lost — confirm the intended markup.
title = """
Random Prompt Generator
[X gokaygokay]
[Github gokayfem]
[comfyui_dagthomas]
Generate random prompts using powerful LLMs from Hugging Face, Groq, and SambaNova.
"""
# Prompt type most recently selected in the "Prompt Type" dropdown. It is
# written by the dropdown's change handler and read by the
# "Generate Prompt with LLM" handler inside create_interface().
# NOTE: this is module-level (process-wide) state, so all browser sessions
# served by this process share it.
selected_prompt_type = "Long"  # default; matches the dropdown's initial value
def create_interface():
    """Build and return the Gradio Blocks UI for the random prompt generator.

    The layout is one row with three columns:

    * basic settings (custom input prompt) and the "Prompt Type" dropdown;
    * the "Generate Prompt" button and the generated-prompt textbox;
    * LLM settings (long talk, compression, custom base prompt, provider,
      model) with the "Generate Prompt with LLM" button and its output box.

    All handlers delegate to a single ``LLMInferenceNode`` created here.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` app.
    """
    llm_node = LLMInferenceNode()

    # Models offered for each provider; the first entry is the default
    # selection whenever that provider is chosen.
    # NOTE(review): "another-model-hf" looks like a placeholder, not a real
    # Hugging Face model id — confirm or replace.
    provider_models = {
        "Hugging Face": ["meta-llama/Meta-Llama-3.1-70B-Instruct", "another-model-hf"],
        "Groq": ["llama-3.1-70b-versatile", "mixtral-8x7b-32768", "gemma2-9b-it"],
        "SambaNova": ["Meta-Llama-3.1-70B-Instruct", "Meta-Llama-3.1-405B-Instruct", "Meta-Llama-3.1-8B-Instruct"],
    }
    default_provider = "Hugging Face"
    default_models = provider_models[default_provider]

    with gr.Blocks(theme='bethecloud/storj_theme') as demo:
        gr.HTML(title)
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Accordion("Basic Settings"):
                    custom = gr.Textbox(label="Custom Input Prompt (optional)")
                with gr.Accordion("Prompt Generation Options", open=False):
                    # Each option listed once (the list previously repeated "Long").
                    prompt_type = gr.Dropdown(
                        choices=["Long", "Short", "Medium"],
                        label="Prompt Type",
                        value="Long",
                        interactive=True,
                    )

                def update_prompt_type(value):
                    """Store the dropdown selection in the module-level global."""
                    global selected_prompt_type
                    selected_prompt_type = value
                    print(f"Updated prompt type: {selected_prompt_type}")

                # No outputs: writing the value back into the same dropdown was a
                # needless round-trip to the browser.
                prompt_type.change(update_prompt_type, inputs=[prompt_type], outputs=None)

            with gr.Column(scale=2):
                generate_button = gr.Button("Generate Prompt")
                with gr.Accordion("Generated Prompt", open=True):
                    output = gr.Textbox(label="Generated Prompt", lines=4, show_copy_button=True)
                # A second "LLM Generated Text" textbox used to live here. It was
                # shadowed by the `text_output` defined below, was never wired to
                # any event, and therefore always stayed empty.

            with gr.Column(scale=2):
                with gr.Accordion("LLM Prompt Generation", open=False):
                    long_talk = gr.Checkbox(label="Long Talk", value=True)
                    compress = gr.Checkbox(label="Compress", value=True)
                    compression_level = gr.Dropdown(
                        choices=["soft", "medium", "hard"],
                        label="Compression Level",
                        value="hard",
                    )
                    custom_base_prompt = gr.Textbox(label="Custom Base Prompt", lines=5)
                    # LLM provider selection
                    llm_provider = gr.Dropdown(
                        choices=list(provider_models),
                        label="LLM Provider",
                        value=default_provider,
                    )
                    # Hidden for every provider (see update_api_key_visibility);
                    # its value (empty unless set programmatically) is still
                    # forwarded to llm_node.generate.
                    api_key = gr.Textbox(label="API Key", type="password", visible=False)
                    # Pre-populate with the default provider's models. The change
                    # handler below only fires when the provider changes, so
                    # without this the initial model would be empty.
                    model = gr.Dropdown(
                        label="Model",
                        choices=default_models,
                        value=default_models[0],
                    )
                    generate_text_button = gr.Button("Generate Prompt with LLM")
                    text_output = gr.Textbox(label="Generated Text", lines=10, show_copy_button=True)

        def update_model_choices(provider):
            """Return a Dropdown update listing *provider*'s models, first one selected."""
            models = provider_models.get(provider, [])
            # gr.update replaces gr.Dropdown.update, which was removed in Gradio 4.
            return gr.update(choices=models, value=models[0] if models else None)

        def update_api_key_visibility(provider):
            """Keep the API key box hidden; no visible provider requires one here."""
            return gr.update(visible=False)

        llm_provider.change(update_model_choices, inputs=[llm_provider], outputs=[model])
        llm_provider.change(update_api_key_visibility, inputs=[llm_provider], outputs=[api_key])

        def generate_prompt(prompt_type, custom_input):
            """Generate a prompt via llm_node using a fresh random seed per click."""
            dynamic_seed = random.randint(0, 1000000)
            return llm_node.generate_prompt(dynamic_seed, prompt_type, custom_input)

        generate_button.click(
            generate_prompt,
            inputs=[prompt_type, custom],
            outputs=[output],
        )

        # Parameter names are kept as-is: Gradio may expose them on the
        # "generate_text" API endpoint.
        def generate_text_with_llm(output_prompt, long_talk, compress, compression_level, custom_base_prompt, provider, api_key, model_selected):
            """Expand the generated prompt with the selected LLM provider/model.

            The prompt type is read from the module-level ``selected_prompt_type``
            kept in sync by the dropdown's change handler.
            NOTE(review): that global is shared by all sessions in this process;
            passing ``prompt_type`` as an input would isolate users but changes
            the public "generate_text" API signature.
            """
            # The global is intentionally no longer reset to "Long" after each
            # call: the reset made later generations ignore the dropdown's
            # current (unchanged) selection.
            return llm_node.generate(
                input_text=output_prompt,
                long_talk=long_talk,
                compress=compress,
                compression_level=compression_level,
                prompt_type=selected_prompt_type,
                custom_base_prompt=custom_base_prompt,
                provider=provider,
                api_key=api_key,
                model=model_selected,
            )

        generate_text_button.click(
            generate_text_with_llm,
            inputs=[output, long_talk, compress, compression_level, custom_base_prompt, llm_provider, api_key, model],
            outputs=[text_output],
            api_name="generate_text",
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks UI and start the Gradio server.
    # NOTE(review): keep the module-level name `demo` — Gradio's reload mode
    # (`gradio app.py`) looks for a Blocks object by that name by default.
    demo = create_interface()
    demo.launch()