import gradio as gr
from llm_inference import LLMInferenceNode
import random
# Header text shown at the top of the Blocks UI (passed to gr.HTML).
# Built from adjacent string literals; the value is one newline-delimited string
# with a leading and a trailing newline.
title = (
    "\n"
    "Random Prompt Generator\n"
    "[X gokaygokay]\n"
    "[Github gokayfem]\n"
    "[comfyui_dagthomas]\n"
    "Generate random prompts using powerful LLMs from Hugging Face, Groq, and SambaNova.\n"
)

# Module-level record of the prompt type most recently chosen in the UI.
# Starts out as "Long", matching the dropdown's initial value.
selected_prompt_type = "Long"
def create_interface():
    """Build the Gradio Blocks UI for the random prompt generator.

    The UI collects prompt-generation options (custom input, prompt type) and
    LLM settings (long talk, compression, base prompt, provider, API key,
    model). A single button runs a handler that first asks
    ``LLMInferenceNode.generate_prompt`` for a prompt and then passes that
    prompt to ``LLMInferenceNode.generate``, showing the result in a textbox.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    llm_node = LLMInferenceNode()

    # Models offered for each provider. This one table feeds both the initial
    # Model dropdown and the provider-change handler so they cannot drift apart.
    provider_models = {
        "Hugging Face": ["meta-llama/Meta-Llama-3.1-70B-Instruct", "another-model-hf"],
        "Groq": ["llama-3.1-70b-versatile", "mixtral-8x7b-32768", "gemma2-9b-it"],
        "SambaNova": [
            "Meta-Llama-3.1-70B-Instruct",
            "Meta-Llama-3.1-405B-Instruct",
            "Meta-Llama-3.1-8B-Instruct",
        ],
    }
    default_provider = "Hugging Face"
    default_models = provider_models[default_provider]

    with gr.Blocks(theme='bethecloud/storj_theme') as demo:
        gr.HTML(title)
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Accordion("Settings"):
                    custom = gr.Textbox(label="Custom Input Prompt (optional)")
                with gr.Accordion("Prompt Generation Options", open=False):
                    prompt_type = gr.Dropdown(
                        choices=["Long", "Short", "Medium", "OnlyObjects", "NoFigure", "Landscape", "Fantasy"],
                        label="Prompt Type",
                        value="Long",
                        interactive=True,
                    )

                def update_prompt_type(value):
                    """Mirror the dropdown's current value into the module-level global."""
                    global selected_prompt_type
                    selected_prompt_type = value
                    print(f"Updated prompt type: {selected_prompt_type}")

                # No outputs: writing the value back into the component that
                # triggered the event is unnecessary and can re-fire `.change`.
                prompt_type.change(update_prompt_type, inputs=[prompt_type], outputs=None)

            with gr.Column(scale=2):
                with gr.Accordion("LLM Prompt Generation", open=False):
                    long_talk = gr.Checkbox(label="Long Talk", value=True)
                    compress = gr.Checkbox(label="Compress", value=True)
                    compression_level = gr.Dropdown(
                        choices=["soft", "medium", "hard"],
                        label="Compression Level",
                        value="hard",
                    )
                    custom_base_prompt = gr.Textbox(label="Custom Base Prompt", lines=5)
                    # LLM provider selection
                    llm_provider = gr.Dropdown(
                        choices=list(provider_models),
                        label="LLM Provider",
                        value=default_provider,
                    )
                    api_key = gr.Textbox(label="API Key", type="password", visible=False)
                    # Seed the model list from the default provider so a generate
                    # click before any provider change still sends a real model.
                    model = gr.Dropdown(
                        label="Model",
                        choices=default_models,
                        value=default_models[0] if default_models else None,
                    )
                # Single button that generates the prompt and the LLM text.
                generate_button = gr.Button("Generate Random Prompt with LLM")
                text_output = gr.Textbox(label="LLM Generated Text", lines=10, show_copy_button=True)

        def update_model_choices(provider):
            """Replace the Model dropdown's choices with the selected provider's models."""
            models = provider_models.get(provider, [])
            # gr.update works across Gradio versions; gr.Dropdown.update was
            # removed in Gradio 4.
            return gr.update(choices=models, value=models[0] if models else None)

        def update_api_key_visibility(provider):
            """Keep the API key field hidden; no listed provider asks for one here."""
            return gr.update(visible=False)

        llm_provider.change(update_model_choices, inputs=[llm_provider], outputs=[model])
        llm_provider.change(update_api_key_visibility, inputs=[llm_provider], outputs=[api_key])

        # The parameter names below deliberately keep their original spelling
        # (some shadow component variables above): with `api_name` set, they
        # form the endpoint's public signature.
        def generate_random_prompt_with_llm(
            custom_input,
            prompt_type,
            long_talk,
            compress,
            compression_level,
            custom_base_prompt,
            provider,
            api_key,
            model_selected,
        ):
            """Generate a random prompt, then expand it with the selected LLM.

            Args mirror the UI components in order: custom input text, prompt
            type, long-talk flag, compress flag, compression level, custom base
            prompt, provider name, API key, and model name.

            Returns:
                The text returned by ``llm_node.generate``. If anything raises,
                a human-readable error string is returned instead, so the error
                shows up in the output textbox.
            """
            try:
                # Step 1: generate a prompt with a fresh seed on every click.
                dynamic_seed = random.randint(0, 1000000)
                prompt = llm_node.generate_prompt(dynamic_seed, prompt_type, custom_input)
                print(f"Generated Prompt: {prompt}")

                # Step 2: expand the prompt with the chosen LLM.
                poster = False  # poster mode is not exposed in this UI
                result = llm_node.generate(
                    input_text=prompt,
                    long_talk=long_talk,
                    compress=compress,
                    compression_level=compression_level,
                    poster=poster,
                    # Use the dropdown value Gradio passed in. Reading the module
                    # global here and assigning to it later made it a local name,
                    # which raised UnboundLocalError on every call.
                    prompt_type=prompt_type,
                    custom_base_prompt=custom_base_prompt,
                    provider=provider,
                    api_key=api_key,
                    model=model_selected,
                )
                print(f"Generated Text: {result}")
                return result
            except Exception as e:
                # UI boundary: report the failure in the output box instead of
                # surfacing a raw traceback to the browser.
                print(f"An error occurred: {e}")
                return f"Error occurred while processing the request: {str(e)}"

        # Connect the unified handler to the single button.
        generate_button.click(
            generate_random_prompt_with_llm,
            inputs=[custom, prompt_type, long_talk, compress, compression_level, custom_base_prompt, llm_provider, api_key, model],
            outputs=[text_output],
            api_name="generate_random_prompt_with_llm",
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it.
    demo = create_interface()
    launch_options = {"share": True}  # also request a public Gradio share link
    demo.launch(**launch_options)