# Spaces:
# Sleeping
# Sleeping
import random
import re

import gradio as gr
import torch
from optimum.intel.openvino import OVStableDiffusionPipeline
from transformers import (
    AutoModelForCausalLM,
    AutoModelWithLMHead,
    AutoTokenizer,
    pipeline,
    set_seed,
)

from share_btn import community_icon_html, loading_icon_html, share_js
# --- Model setup (runs once at import time; downloads/loads all three models) ---

# Horoscope text model, prompted in fn() with "<|category|> ... <|horoscope|> ...".
horoscope_model_id = "shahp7575/gpt2-horoscopes"
tokenizer = AutoTokenizer.from_pretrained(horoscope_model_id)
# AutoModelWithLMHead is deprecated in transformers; AutoModelForCausalLM is the
# replacement for decoder-only (GPT-2 style) checkpoints like this one.
model = AutoModelForCausalLM.from_pretrained(horoscope_model_id)

# MagicPrompt turns the horoscope text into a Stable Diffusion prompt. It is
# paired with the stock "gpt2" tokenizer.
text_generation_pipe = pipeline(
    "text-generation",
    model="Gustavosta/MagicPrompt-Stable-Diffusion",
    tokenizer="gpt2",
)

# OpenVINO Stable Diffusion. Load uncompiled, fix a static input shape, then
# compile once: compiling at load time and again after reshape() would do the
# work twice.
stable_diffusion_pipe = OVStableDiffusionPipeline.from_pretrained(
    "echarlaix/stable-diffusion-v1-5-openvino", revision="fp16", compile=False
)
# Output image size in pixels; fn() passes these same values at inference time
# so the request matches the static shape below.
height = 128
width = 128
stable_diffusion_pipe.reshape(batch_size=1, height=height, width=width, num_images_per_prompt=1)
stable_diffusion_pipe.compile()
def fn(sign, cat):
    """Generate a horoscope for a star sign/category plus an image for it.

    Pipeline: the horoscope model writes the text, MagicPrompt extends that
    text into an image prompt, and the OpenVINO Stable Diffusion pipeline
    renders the prompt.

    Args:
        sign: Star sign selected in the UI, e.g. ``"Aries"``.
        cat: Horoscope category selected in the UI, e.g. ``"Love"``.

    Returns:
        ``[image, starting_text]``: the first image returned by the Stable
        Diffusion pipeline and the horoscope text it was derived from.
    """
    # Conditioning prompt for the horoscope model: category tag, then sign tag.
    prompt = f"<|category|> {cat} <|horoscope|> {sign}"
    # unsqueeze(0) adds the batch dimension: a batch of one token sequence.
    prompt_encoded = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    sample_outputs = model.generate(
        prompt_encoded,
        do_sample=True,
        top_k=40,
        max_length=300,  # total length in tokens, prompt included
        top_p=0.95,
        temperature=0.95,
        num_beams=4,  # with do_sample=True this is beam-search sampling
        num_return_sequences=1,
    )
    final_out = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    # Drop the first four space-separated pieces of the decoded text --
    # presumably the echoed category/sign prompt left after the special tokens
    # are stripped. NOTE(review): confirm against actual decode output.
    starting_text = " ".join(final_out.split(" ")[4:])

    # Draw a fresh seed and apply it before the sampling steps below.
    seed = random.randint(100, 1000000)
    set_seed(seed)

    # Ask MagicPrompt for 60-90 *new* tokens after the prompt. The budget is
    # expressed with max_new_tokens for two reasons: max_length would also
    # count the prompt's tokens, and len(starting_text) is a character count,
    # not a token count. Adding the two let generation run far longer than
    # intended on long horoscopes.
    response = text_generation_pipe(
        starting_text + " " + sign + " art",
        max_new_tokens=random.randint(60, 90),
        num_return_sequences=1,
    )
    # NOTE(review): the Stable Diffusion text encoder presumably truncates long
    # prompts, so much of a long horoscope may not reach the image model.
    image = stable_diffusion_pipe(
        response[0]["generated_text"],
        height=height,  # same module-level size the pipeline was reshaped to
        width=width,
        num_inference_steps=30,
    ).images[0]
    return [image, starting_text]
# --- Gradio UI ---
block = gr.Blocks(css="./css.css")

with block:
    with gr.Group():
        with gr.Box():
            with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                # Star-sign selector; passed to fn() as `sign`.
                text = gr.Dropdown(
                    label="Star Sign",
                    choices=["Aries", "Taurus", "Gemini", "Cancer", "Leo", "Virgo", "Libra", "Scorpio", "Sagittarius", "Capricorn", "Aquarius", "Pisces"],
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    elem_id="prompt-text-input",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                # Category selector; passed to fn() as `cat`.
                text2 = gr.Dropdown(
                    choices=["Love", "Career", "Wellness"],
                    label="Category",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    elem_id="prompt-text-input",
                ).style(
                    border=(True, True, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                    full_width=False,
                )
        gallery = gr.Image(
            interactive=False,
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")
        # Output box for the horoscope text. It gets its own name so that `text`
        # keeps referring to the star-sign dropdown used as a click input below.
        # Reusing `text` here would make the button read this box instead of
        # the selected sign.
        horoscope_box = gr.Textbox("Text")
        with gr.Group(elem_id="container-advanced-btns"):
            with gr.Group(elem_id="share-btn-container"):
                community_icon = gr.HTML(community_icon_html)
                loading_icon = gr.HTML(loading_icon_html)
                share_button = gr.Button("Share to community", elem_id="share-btn")

    # fn(sign, cat) -> [image, horoscope text]
    btn.click(fn=fn, inputs=[text, text2], outputs=[gallery, horoscope_box])
    # Sharing runs purely client-side via the share_js snippet.
    share_button.click(
        None,
        [],
        [],
        _js=share_js,
    )

# NOTE(review): up to 40 concurrent fn() calls share the global models and call
# set_seed() on shared RNG state, so seeds of overlapping requests may interleave.
block.queue(concurrency_count=40, max_size=20).launch(max_threads=150)