"""Gradio demo that writes, narrates and illustrates a short children's story.

Pipeline: a hosted Mistral-7B-Instruct endpoint writes the story, a local Bark
model turns it into speech sentence by sentence, and a remote SDXL Space
produces an illustration from the same prompt.
"""

import os

import gradio as gr
import nltk  # used to split the story into sentences
import numpy as np
import torch
from gradio_client import Client
from huggingface_hub import InferenceClient
from transformers import AutoProcessor, BarkModel

nltk.download("punkt")


def _grab_best_device(use_gpu=True):
    """Return ``"cuda"`` when a GPU is visible and allowed, otherwise ``"cpu"``."""
    if use_gpu and torch.cuda.device_count() > 0:
        return "cuda"
    return "cpu"


device = _grab_best_device()

SYST_PROMPT = (
    "You're the storyteller, crafting a short tale for young listeners. "
    "Please abide by these guidelines: "
    "- Keep your sentences short, concise and easy to understand. "
    "- There should be only the narrator speaking. "
    "If there are dialogues, they should be indirect."
)

# Alternative prompt:
# "A panda going on an adventure with a caterpillar. This is a story teaching a wonderful life lesson."
story_prompt = "A princess breaks free from a dragon's grip. This evocates women empowerement and freedom."

# Module-level sampling settings (previously assigned twice with the same values).
temperature = 0.9
top_p = 0.6
repetition_penalty = 1.2

TIMEOUT = int(os.environ.get("TIMEOUT", 45))

# TODO: requirements: accelerate optimum

text_client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.1",
    timeout=TIMEOUT,
)

image_client = Client("https://openskyml-fast-sdxl-stable-diffusion-xl.hf.space/--replicas/545bst2bq/")
image_negative_prompt = "ultrarealistic, soft lighting, 8k, ugly, text, blurry"
image_positive_prompt = ""
image_seed = 6

processor = AutoProcessor.from_pretrained("suno/bark")


def format_speaker_key(key):
    """Turn a Bark preset id such as ``v2/en_speaker_6`` into ``Speaker 6 (en)``."""
    parts = key.replace("v2/", "").split("_")
    return f"Speaker {parts[2]} ({parts[0]})"


# Only the English v2 presets are offered; display label -> Bark preset id.
voice_presets = [key for key in processor.speaker_embeddings.keys() if "v2/en" in key]
voice_presets_dict = {format_speaker_key(key): key for key in voice_presets}

# Half precision, Flash Attention 2 and accelerate-based CPU offload are GPU
# features, so they are only requested when a CUDA device was selected. On CPU
# the model is loaded with its default settings instead of failing at start-up.
_on_gpu = device == "cuda"
_model_kwargs = (
    {"torch_dtype": torch.float16, "use_flash_attention_2": True} if _on_gpu else {}
)
model = BarkModel.from_pretrained("suno/bark", **_model_kwargs).to(device)

sampling_rate = model.generation_config.sample_rate
silence = np.zeros(int(0.25 * sampling_rate))  # quarter second of silence

voice_preset = "v2/en_speaker_6"

BATCH_SIZE = 32

if _on_gpu:
    # enable CPU offload
    model.enable_cpu_offload()

# MISTRAL ONLY
default_system_understand_message = "I understand, I am a Mistral chatbot."
system_understand_message = os.environ.get(
    "SYSTEM_UNDERSTAND_MESSAGE", default_system_understand_message
)


# Mistral formatter
def format_prompt(message):
    """Build the Mistral instruct prompt for *message*.

    The system prompt and the canned acknowledgement form a first turn, and
    the user message follows as a second ``[INST]`` turn.
    """
    # NOTE(review): the original appended an empty string after the
    # acknowledgement; it may have been a special token lost in extraction
    # (e.g. an end-of-sequence marker) - confirm against the upstream file.
    system_turn = "[INST]" + SYST_PROMPT + "[/INST]" + system_understand_message
    return system_turn + f"[INST] {message} [/INST]"


def generate_story(
    story_prompt,
    temperature=0.9,
    max_new_tokens=1024,
    top_p=0.95,
    repetition_penalty=1.0,
):
    """Ask the hosted Mistral endpoint for a story about *story_prompt*.

    Args:
        story_prompt: Free-text description of the story to write.
        temperature: Sampling temperature; values below ``1e-2`` are raised to ``1e-2``.
        max_new_tokens: Upper bound on generated tokens.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Repetition penalty forwarded to the endpoint.

    Returns:
        The generated text, or a fixed apology string when the request fails
        (a Gradio warning is raised in the UI in that case).
    """
    generate_kwargs = dict(
        temperature=max(float(temperature), 1e-2),
        max_new_tokens=max_new_tokens,
        top_p=float(top_p),
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    try:
        return text_client.text_generation(
            format_prompt(story_prompt),
            **generate_kwargs,
            details=False,
            return_full_text=False,
        )
    except Exception as e:  # UI boundary: degrade to a message instead of crashing
        error_text = str(e)
        if "Too Many Requests" in error_text:
            print("ERROR: Too many requests on mistral client")
            gr.Warning("Unfortunately Mistral is unable to process")
            return "Unfortuanately I am not able to process your request now, too many people are asking me !"
        if "Model not loaded on the server" in error_text:
            print("ERROR: Mistral server down")
            gr.Warning("Unfortunately Mistral LLM is unable to process")
            return "Unfortuanately I am not able to process your request now, I have problem with Mistral!"
        print("Unhandled Exception: ", error_text)
        gr.Warning("Unfortunately Mistral is unable to process")
        return "I do not know what happened but I could not understand you."


def generate_audio_and_image(story_prompt, voice_preset="Speaker 3 (en)"):
    """Generate a story, its narration and an illustration.

    Args:
        story_prompt: Free-text description of the story.
        voice_preset: Display label of the speaker; must be a key of
            ``voice_presets_dict``.

    Returns:
        A ``(story, (sampling_rate, waveform), image)`` tuple. ``image`` is
        ``None`` when the image job fails.

    Raises:
        KeyError: If *voice_preset* is not a known speaker label.
    """
    story = generate_story(story_prompt)
    print(story)

    sentences = nltk.sent_tokenize(story.replace("\n", " ").strip())

    print("text generated - now calling for image")
    # Submitted asynchronously so the image renders while audio is generated.
    job_img = image_client.submit(
        story_prompt + image_positive_prompt,  # str in 'parameter_11' Textbox component
        image_negative_prompt,  # str in 'parameter_12' Textbox component
        # NOTE(review): presumably steps, guidance scale, width, height - confirm
        # against the Space's API page.
        25,
        7,
        1024,
        1024,
        image_seed,
        fn_index=0,
    )

    print("image called - now generating audio")
    bark_preset = voice_presets_dict[voice_preset]
    pieces = []
    for start in range(0, len(sentences), BATCH_SIZE):
        batch = sentences[start:start + BATCH_SIZE]  # slicing clamps at the end
        inputs = processor(batch, voice_preset=bark_preset)

        speech_output, output_lengths = model.generate(
            **inputs.to(device), return_output_lengths=True, min_eos_p=0.2
        )
        # Trim each waveform to its reported length before moving it to NumPy.
        speech_output = [
            waveform[:length].cpu().numpy()
            for waveform, length in zip(speech_output, output_lengths)
        ]
        print(f"{start}-th part generated")
        pieces += [*speech_output, silence.copy()]

    # np.concatenate rejects an empty list, which happens when the story has no
    # sentences; return the short silence clip instead of raising.
    audio = np.concatenate(pieces) if pieces else silence.copy()

    print("Calling image")
    try:
        img = job_img.result()
    except Exception as e:  # best effort: the story and audio are still returned
        print("Unhandled Exception: ", str(e))
        gr.Warning("Unfortunately there was an issue when generating the image with SDXL.")
        img = None

    return story, (sampling_rate, audio), img


# Gradio blocks demo
with gr.Blocks() as demo_blocks:
    # NOTE(review): both headings appear to have lost their markup in
    # extraction; the text is kept exactly as found.
    gr.Markdown("\n\n🐶Children story\n\n")
    gr.HTML("\n\nLet Mistral tell you a story\n\n")

    with gr.Group():
        with gr.Row():
            inp_text = gr.Textbox(label="Story prompt", info="Enter text here")
        with gr.Row():
            with gr.Accordion("Advanced settings", open=False):
                # Rebinds the module-level ``voice_preset`` name to the widget.
                voice_preset = gr.Dropdown(
                    list(voice_presets_dict),
                    value="Speaker 6 (en)",
                    label="Available speakers",
                )
        with gr.Row():
            btn = gr.Button("Create a story")

    with gr.Row():
        with gr.Column(scale=1):
            image_output = gr.Image(elem_id="gallery")
        with gr.Row():
            out_audio = gr.Audio(streaming=False, autoplay=True)
            out_text = gr.Text()

    btn.click(
        generate_audio_and_image,
        [inp_text, voice_preset],
        [out_text, out_audio, image_output],
    )

    with gr.Row():
        gr.Examples(
            examples=[
                "A panda going on an adventure with a caterpillar. This is a story teaching a wonderful life lesson.",
                "A princess breaks free from a dragon's grip. This evocates women empowerement and freedom.",
                "Tell me about the wonders of the world.",
            ],
            inputs=[inp_text],
            outputs=[out_text, out_audio, image_output],
            fn=generate_audio_and_image,
            cache_examples=True,
        )

    with gr.Row():
        gr.Markdown(
            " This Space uses **[Bark](https://huggingface.co/docs/transformers/main/en/model_doc/bark)**, "
            "[Mistral-7b-instruct](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) and "
            "[Fast SD-XL](https://huggingface.co/spaces/openskyml/fast-sdxl-stable-diffusion-xl)! "
        )

demo_blocks.queue().launch(debug=True)