import gradio as gr
from llm_inference import LLMInferenceNode
import random
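# Assumed interface of LLMInferenceNode, inferred from how it is called in
# this file (not verified against the llm_inference module itself):
#   generate_prompt(seed, prompt_type, custom_input) -> str
#   generate(input_text, long_talk, compress, compression_level, poster,
#            prompt_type, custom_base_prompt, provider, api_key, model) -> str
#   generate_image(text, seed=...) -> image (a file path or PIL.Image)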

title = """<h1 align="center">Random Prompt Generator</h1>
<p><center>
<a href="https://x.com/gokayfem" target="_blank">[X gokaygokay]</a>
<a href="https://github.com/gokayfem" target="_blank">[Github gokayfem]</a>
<p align="center">Generate random prompts using powerful LLMs from Hugging Face, Groq, and SambaNova.</p>
</center></p>
"""

def create_interface():
    llm_node = LLMInferenceNode()

    with gr.Blocks(theme='bethecloud/storj_theme') as demo:
        
        gr.HTML(title)

        with gr.Row():
            with gr.Column(scale=2):
                custom = gr.Textbox(label="Custom Input Prompt (optional)", lines=3)

                prompt_types = ["Random", "Long", "Short", "Medium", "OnlyObjects", "NoFigure", "Landscape", "Fantasy"]
                prompt_type = gr.Dropdown(
                    choices=prompt_types,
                    label="Prompt Type",
                    value="Random",
                    interactive=True
                )
                
                # State holding the concrete type last resolved from "Random".
                # (The generate handler below re-resolves "Random" on each
                # click, so this state is informational.)
                prompt_type_state = gr.State("Random")

                # Resolve "Random" to a concrete type for the state while the
                # dropdown itself keeps showing "Random"
                def update_prompt_type(value):
                    if value == "Random":
                        new_value = random.choice([t for t in prompt_types if t != "Random"])
                        print(f"Random prompt type selected: {new_value}")
                        return value, new_value
                    print(f"Updated prompt type: {value}")
                    return value, value

                prompt_type.change(
                    update_prompt_type,
                    inputs=[prompt_type],
                    outputs=[prompt_type, prompt_type_state]
                )

            with gr.Column(scale=2):
                with gr.Accordion("LLM Prompt Generation", open=False):
                    long_talk = gr.Checkbox(label="Long Talk", value=True)
                    compress = gr.Checkbox(label="Compress", value=True)
                    compression_level = gr.Dropdown(
                        choices=["soft", "medium", "hard"],
                        label="Compression Level",
                        value="hard"
                    )

                    custom_base_prompt = gr.Textbox(label="Custom Base Prompt", lines=5)

                # LLM Provider Selection
                llm_provider = gr.Dropdown(
                    choices=["Hugging Face", "Groq", "SambaNova"],
                    label="LLM Provider",
                    value="Groq"
                )
                api_key = gr.Textbox(label="API Key", type="password", visible=False)
                model = gr.Dropdown(label="Model", choices=["llama-3.1-70b-versatile", "mixtral-8x7b-32768", "llama-3.2-90b-text-preview"], value="llama-3.2-90b-text-preview")
        with gr.Row():
            # Single button that generates the prompt and runs the LLM text step
            generate_button = gr.Button("Generate Prompt")
        with gr.Row():
            text_output = gr.Textbox(label="LLM Generated Text", lines=10, show_copy_button=True)
            image_output = gr.Image(label="Generated Image", type="pil")

        # Update the model choices when the provider changes
        def update_model_choices(provider):
            provider_models = {
                "Hugging Face": [
                    "Qwen/Qwen2.5-72B-Instruct",
                    "meta-llama/Meta-Llama-3.1-70B-Instruct",
                    "mistralai/Mixtral-8x7B-Instruct-v0.1",
                    "mistralai/Mistral-7B-Instruct-v0.3"
                ],
                "Groq": [
                    "llama-3.1-70b-versatile",
                    "mixtral-8x7b-32768",
                    "llama-3.2-90b-text-preview"
                ],
                "SambaNova": [
                    "Meta-Llama-3.1-70B-Instruct",
                    "Meta-Llama-3.1-405B-Instruct",
                    "Meta-Llama-3.1-8B-Instruct"
                ],
            }
            models = provider_models.get(provider, [])
            return gr.Dropdown(choices=models, value=models[0] if models else "")
        
        def update_api_key_visibility(provider):
            return gr.update(visible=False)  # No API key required for selected providers
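        # If a keyed provider is added later, this handler could toggle
        # visibility per provider instead, e.g. (hypothetical provider name):
        #   return gr.update(visible=(provider == "SomeKeyedProvider"))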

        llm_provider.change(
            update_model_choices, 
            inputs=[llm_provider], 
            outputs=[model]
        )
        llm_provider.change(
            update_api_key_visibility, 
            inputs=[llm_provider], 
            outputs=[api_key]
        )

        # Unified handler: resolve the prompt type, generate a prompt, then run the LLM
        def generate_random_prompt_with_llm(custom_input, prompt_type, long_talk, compress,
                                            compression_level, custom_base_prompt, provider,
                                            api_key, model_selected):
            try:
                # Step 1: Generate Prompt
                dynamic_seed = random.randint(0, 1000000)
                
                # Re-resolve "Random" to a concrete type on each click, so
                # every generation can draw a different type
                if prompt_type == "Random":
                    prompt_type = random.choice([t for t in prompt_types if t != "Random"])
                    print(f"Random prompt type selected: {prompt_type}")
                
                if custom_input and custom_input.strip():
                    prompt = llm_node.generate_prompt(dynamic_seed, prompt_type, custom_input)
                    print(f"Using Custom Input Prompt.")
                else:
                    prompt = llm_node.generate_prompt(dynamic_seed, prompt_type, f"Create a random prompt based on the '{prompt_type}' type.")
                    print(f"No Custom Input Prompt provided. Generated prompt based on prompt_type: {prompt_type}")

                print(f"Generated Prompt: {prompt}")

                # Step 2: Generate Text with LLM
                poster = False  # Not exposed in the UI; pass a fixed default
                result = llm_node.generate(
                    input_text=prompt,
                    long_talk=long_talk,
                    compress=compress,
                    compression_level=compression_level,
                    poster=poster,
                    prompt_type=prompt_type,  # Use the updated prompt_type here
                    custom_base_prompt=custom_base_prompt,
                    provider=provider,
                    api_key=api_key,
                    model=model_selected
                )
                print(f"Generated Text: {result}")

                return result

            except Exception as e:
                print(f"An error occurred: {e}")
                return f"Error occurred while processing the request: {str(e)}"

        # Wire the unified handler to the generate button
        generate_button.click(
            generate_random_prompt_with_llm,
            inputs=[custom, prompt_type, long_talk, compress, compression_level, custom_base_prompt, llm_provider, api_key, model],
            outputs=[text_output],
            api_name="generate_random_prompt_with_llm"
        )

        # Button to render the generated text as an image
        generate_image_button = gr.Button("Generate Image")

        # Generate an image from the LLM output via the inference node
        def generate_image(text):
            try:
                seed = random.randint(0, 1000000)
                image_path = llm_node.generate_image(text, seed=seed)
                return image_path
            except Exception as e:
                print(f"An error occurred while generating the image: {e}")
                return None

        # Connect the image generation button
        generate_image_button.click(
            generate_image,
            inputs=[text_output],
            outputs=[image_output]
        )

    return demo

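# Assumed local-run notes (inferred, not part of the original file): the app
# needs `gradio` installed and the local `llm_inference` module importable.
#   pip install gradio
#   python app.py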
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)