# Captured from a Hugging Face Space page (status: Running).
import io
import multiprocessing
import os
import uuid

import gradio as gr
import redis
import scipy
import scipy.io.wavfile
import scipy.signal
import torch
from datasets import load_dataset
from diffusers import (
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    FluxPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionPipeline,
)
from diffusers.utils import export_to_video
from dotenv import load_dotenv
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    BartForConditionalGeneration,
    BartTokenizer,
    MarianMTModel,
    MarianTokenizer,
    MusicgenForConditionalGeneration,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    pipeline,
)
# Load variables from a local .env file (if one exists) into os.environ.
load_dotenv()

# Shared Redis connection used by every store/load helper in this module.
# redis-py needs an integer port, so an unset or empty REDIS_PORT would fail;
# fall back to the standard local defaults instead. Responses are left as
# bytes (no decode_responses) because model weights are stored as raw bytes
# and text values are decoded by the callers.
redis_client = redis.Redis(
    host=os.getenv("REDIS_HOST") or "localhost",
    port=int(os.getenv("REDIS_PORT") or 6379),
    password=os.getenv("REDIS_PASSWORD"),
)

# Hugging Face access token read from the environment.
# NOTE(review): not passed to any from_pretrained() call visible in this file.
huggingface_token = os.getenv("HF_TOKEN")
def generate_unique_id():
    """Return a new random (version 4) UUID in canonical hyphenated text form."""
    identifier = uuid.uuid4()
    return f"{identifier}"
def store_special_tokens(tokenizer, model_name):
    """Persist the tokenizer's special tokens and their ids in a Redis hash.

    The hash key is ``tokenizer_special_tokens:{model_name}`` and its fields
    are ``pad_token``, ``pad_token_id``, ``eos_token``, ``eos_token_id``,
    ``unk_token``, ``unk_token_id``, ``bos_token`` and ``bos_token_id``.

    Attributes that are ``None`` on the tokenizer (for example a model with no
    pad token) are skipped, because Redis cannot store ``None`` values. When
    every attribute is ``None``, nothing is written.

    Args:
        tokenizer: Hugging Face tokenizer to read the special tokens from.
        model_name: Suffix for the Redis key.
    """
    fields = (
        "pad_token", "pad_token_id",
        "eos_token", "eos_token_id",
        "unk_token", "unk_token_id",
        "bos_token", "bos_token_id",
    )
    special_tokens = {}
    for field in fields:
        value = getattr(tokenizer, field)
        if value is not None:
            special_tokens[field] = value
    if not special_tokens:
        # hset() with an empty mapping is an error in redis-py.
        return
    # hset(mapping=...) replaces the deprecated hmset().
    redis_client.hset(f"tokenizer_special_tokens:{model_name}", mapping=special_tokens)
def load_special_tokens(tokenizer, model_name):
    """Apply special tokens saved by ``store_special_tokens`` to a tokenizer.

    Reads the Redis hash ``tokenizer_special_tokens:{model_name}``. For each
    kind (pad, eos, unk, bos) it sets ``<kind>_token`` and then
    ``<kind>_token_id`` on the tokenizer, but only for fields present in the
    hash. Missing fields are left untouched, and a missing hash leaves the
    tokenizer unmodified.

    Args:
        tokenizer: Hugging Face tokenizer, updated in place.
        model_name: Suffix of the Redis key the tokens were stored under.
    """

    def _as_text(value):
        # ``redis_client`` returns bytes for keys and values, so decode both
        # before looking fields up by their str names.
        return value.decode("utf-8") if isinstance(value, bytes) else value

    raw = redis_client.hgetall(f"tokenizer_special_tokens:{model_name}")
    if not raw:
        return
    stored = {_as_text(key): _as_text(value) for key, value in raw.items()}
    for kind in ("pad", "eos", "unk", "bos"):
        token = stored.get(f"{kind}_token")
        if token:
            setattr(tokenizer, f"{kind}_token", token)
        token_id = stored.get(f"{kind}_token_id")
        if token_id not in (None, ""):
            setattr(tokenizer, f"{kind}_token_id", int(token_id))
def train_and_store_transformers_model(model_name, data):
    """Save a causal language model and its tokenizer for later generation.

    NOTE(review): despite the name, no training happens. ``data`` is unused
    and the stored weights are the pretrained ``model_name`` weights.

    Writes to Redis:
      * ``transformers_model:{model_name}:state_dict``: ``torch.save`` bytes
        of the model's state dict.
      * ``transformers_tokenizer:{model_name}``: name of the local directory
        the tokenizer files were saved to (``transformers_tokenizer``).
      * the tokenizer's special tokens, via ``store_special_tokens``.

    Args:
        model_name: Hugging Face model id or local path.
        data: Currently unused.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.train()  # only switches to training mode; no optimisation step runs
    store_special_tokens(tokenizer, model_name)

    # Serialise in memory rather than through a shared scratch file on disk.
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    redis_client.set(f"transformers_model:{model_name}:state_dict", buffer.getvalue())

    # save_pretrained() returns a tuple of file paths, which Redis cannot
    # store; record the directory the files were written to instead.
    tokenizer_dir = "transformers_tokenizer"
    tokenizer.save_pretrained(tokenizer_dir)
    redis_client.set(f"transformers_tokenizer:{model_name}", tokenizer_dir)
def generate_transformers_response_from_redis(model_name, prompt):
    """Generate text for *prompt*, using weights stored in Redis when present.

    The architecture is loaded from ``model_name``. If Redis holds a state
    dict under ``transformers_model:{model_name}:state_dict``, it replaces the
    pretrained weights; otherwise the pretrained weights are used. The
    tokenizer is loaded from the directory recorded under
    ``transformers_tokenizer:{model_name}`` when that directory exists, and
    from ``model_name`` otherwise. Any special tokens saved in Redis are then
    applied.

    The decoded response is also stored under ``transformers_response:<uuid>``.

    Args:
        model_name: Hugging Face model id or local path.
        prompt: Input text.

    Returns:
        The decoded response, with special tokens stripped. Generation uses
        ``max_length=50``.
    """
    unique_id = generate_unique_id()

    model = AutoModelForCausalLM.from_pretrained(model_name)
    model_data = redis_client.get(f"transformers_model:{model_name}:state_dict")
    if model_data is not None:
        # NOTE(review): torch.load unpickles its input; this is only safe
        # while the Redis instance is trusted.
        state_dict = torch.load(io.BytesIO(model_data), map_location="cpu")
        model.load_state_dict(state_dict)

    stored_dir = redis_client.get(f"transformers_tokenizer:{model_name}")
    tokenizer_dir = stored_dir.decode("utf-8", errors="ignore") if stored_dir else ""
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir if os.path.isdir(tokenizer_dir) else model_name
    )
    load_special_tokens(tokenizer, model_name)

    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.get("attention_mask"),
        max_length=50,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    redis_client.set(f"transformers_response:{unique_id}", response)
    return response
def train_and_store_diffusers_model(model_name, data):
    """Save a Flux pipeline to a local directory and record that directory in Redis.

    NOTE(review): no training happens and ``data`` is unused.

    Writes the pipeline to ``diffusers_model/`` and stores that directory
    name under ``diffusers_model:{model_name}``.

    Args:
        model_name: Hugging Face model id or local path of the Flux pipeline.
        data: Currently unused.
    """
    pipe = FluxPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16)
    # Pipelines have no train() method, and save_pretrained() writes a
    # directory of component files rather than a single .pt file, so store
    # the directory location instead of reading a nonexistent file.
    save_dir = "diffusers_model"
    pipe.save_pretrained(save_dir)
    redis_client.set(f"diffusers_model:{model_name}", save_dir)
def generate_diffusers_image_from_redis(model_name, prompt):
    """Generate an image with a Flux pipeline.

    The pipeline is loaded from the directory recorded under
    ``diffusers_model:{model_name}`` when that directory exists, and from
    ``model_name`` otherwise. Sampling is fixed: seed 0, 4 steps,
    ``guidance_scale=0.0``. The image is saved as
    ``images/diffusers_<uuid>.png`` (creating ``images/`` if needed), and the
    path is stored under ``diffusers_image:<uuid>``.

    Args:
        model_name: Hugging Face model id or local path.
        prompt: Text prompt.

    Returns:
        The generated PIL image.
    """
    unique_id = generate_unique_id()

    stored = redis_client.get(f"diffusers_model:{model_name}")
    saved_dir = stored.decode("utf-8", errors="ignore") if stored else ""
    source = saved_dir if os.path.isdir(saved_dir) else model_name

    pipe = FluxPipeline.from_pretrained(source, torch_dtype=torch.bfloat16)
    pipe.enable_model_cpu_offload()
    image = pipe(
        prompt,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]

    os.makedirs("images", exist_ok=True)
    image_path = f"images/diffusers_{unique_id}.png"
    image.save(image_path)
    redis_client.set(f"diffusers_image:{unique_id}", image_path)
    return image
def train_and_store_musicgen_model(model_name, data):
    """Save a MusicGen model and its processor for later generation.

    NOTE(review): no training happens and ``data`` is unused.

    Writes to Redis:
      * ``musicgen_model:{model_name}:state_dict``: ``torch.save`` bytes of
        the model's state dict.
      * ``musicgen_processor:{model_name}``: name of the local directory the
        processor was saved to (``musicgen_processor``).

    Args:
        model_name: Hugging Face model id or local path.
        data: Currently unused.
    """
    processor = AutoProcessor.from_pretrained(model_name)
    model = MusicgenForConditionalGeneration.from_pretrained(model_name)
    model.train()  # only switches to training mode; no optimisation step runs

    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    redis_client.set(f"musicgen_model:{model_name}:state_dict", buffer.getvalue())

    # save_pretrained() returns file paths, which Redis cannot store;
    # record the output directory instead.
    processor_dir = "musicgen_processor"
    processor.save_pretrained(processor_dir)
    redis_client.set(f"musicgen_processor:{model_name}", processor_dir)
def generate_musicgen_audio_from_redis(model_name, text_prompts):
    """Generate audio from text prompts with MusicGen and write it as a WAV file.

    Uses the state dict stored under ``musicgen_model:{model_name}:state_dict``
    when present, and the pretrained weights otherwise. The processor is
    loaded from the directory recorded under ``musicgen_processor:{model_name}``
    when that directory exists, and from ``model_name`` otherwise.

    The WAV file is written to ``audio/musicgen_<uuid>.wav`` (creating
    ``audio/`` if needed), and the path is stored under ``musicgen_audio:<uuid>``.

    Args:
        model_name: Hugging Face model id or local path.
        text_prompts: A prompt string or a list of prompt strings.

    Returns:
        Path of the written WAV file. Only the first channel of the first
        generated sample is written.
    """
    unique_id = generate_unique_id()

    model = MusicgenForConditionalGeneration.from_pretrained(model_name)
    model_data = redis_client.get(f"musicgen_model:{model_name}:state_dict")
    if model_data is not None:
        # NOTE(review): torch.load unpickles its input; trust Redis accordingly.
        model.load_state_dict(torch.load(io.BytesIO(model_data), map_location="cpu"))

    stored_dir = redis_client.get(f"musicgen_processor:{model_name}")
    processor_dir = stored_dir.decode("utf-8", errors="ignore") if stored_dir else ""
    processor = AutoProcessor.from_pretrained(
        processor_dir if os.path.isdir(processor_dir) else model_name
    )

    inputs = processor(text=text_prompts, padding=True, return_tensors="pt")
    # generate() returns a waveform tensor (batch, channels, samples), not a
    # dict; the sampling rate comes from the audio encoder's config.
    audio_values = model.generate(**inputs, max_new_tokens=256)
    sampling_rate = model.config.audio_encoder.sampling_rate

    os.makedirs("audio", exist_ok=True)
    audio_path = f"audio/musicgen_{unique_id}.wav"
    scipy.io.wavfile.write(audio_path, rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())
    redis_client.set(f"musicgen_audio:{unique_id}", audio_path)
    return audio_path
def train_and_store_stable_diffusion_model(model_name, data):
    """Save a Stable Diffusion pipeline locally and record its directory in Redis.

    NOTE(review): no training happens and ``data`` is unused.

    The pipeline, with its scheduler swapped to DPMSolverMultistepScheduler,
    is written to ``stable_diffusion_model/``. That directory name is stored
    under ``stable_diffusion_model:{model_name}``.

    Args:
        model_name: Hugging Face model id or local path.
        data: Currently unused.
    """
    pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    # Pipelines have no train() method, and save_pretrained() writes a
    # directory rather than a single .pt file. No GPU is needed just to save.
    save_dir = "stable_diffusion_model"
    pipe.save_pretrained(save_dir)
    redis_client.set(f"stable_diffusion_model:{model_name}", save_dir)
def generate_stable_diffusion_image_from_redis(model_name, prompt):
    """Generate an image with Stable Diffusion on CUDA.

    The pipeline is loaded from the directory recorded under
    ``stable_diffusion_model:{model_name}`` when that directory exists, and
    from ``model_name`` otherwise. The image is saved as
    ``images/stable_diffusion_<uuid>.png`` (creating ``images/`` if needed),
    and the path is stored under ``stable_diffusion_image:<uuid>``.

    Args:
        model_name: Hugging Face model id or local path.
        prompt: Text prompt.

    Returns:
        The generated PIL image.
    """
    unique_id = generate_unique_id()

    stored = redis_client.get(f"stable_diffusion_model:{model_name}")
    saved_dir = stored.decode("utf-8", errors="ignore") if stored else ""
    source = saved_dir if os.path.isdir(saved_dir) else model_name

    pipe = StableDiffusionPipeline.from_pretrained(source, torch_dtype=torch.float16)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("cuda")
    image = pipe(prompt).images[0]

    os.makedirs("images", exist_ok=True)
    image_path = f"images/stable_diffusion_{unique_id}.png"
    image.save(image_path)
    redis_client.set(f"stable_diffusion_image:{unique_id}", image_path)
    return image
def train_and_store_img2img_model(model_name, data):
    """Save a Stable Diffusion img2img pipeline locally and record its directory in Redis.

    NOTE(review): no training happens and ``data`` is unused.

    The pipeline is written to ``img2img_model/``, and that directory name is
    stored under ``img2img_model:{model_name}``.

    Args:
        model_name: Hugging Face model id or local path.
        data: Currently unused.
    """
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_name, torch_dtype=torch.float16)
    # Pipelines have no train() method, and save_pretrained() writes a
    # directory rather than a single .pt file. No GPU is needed just to save.
    save_dir = "img2img_model"
    pipe.save_pretrained(save_dir)
    redis_client.set(f"img2img_model:{model_name}", save_dir)
def generate_img2img_from_redis(model_name, init_image, prompt, strength=0.75):
    """Transform *init_image* according to *prompt* with Stable Diffusion img2img on CUDA.

    The pipeline is loaded from the directory recorded under
    ``img2img_model:{model_name}`` when that directory exists, and from
    ``model_name`` otherwise. The result is saved as
    ``images/img2img_<uuid>.png`` (creating ``images/`` if needed), and the
    path is stored under ``img2img_image:<uuid>``.

    Args:
        model_name: Hugging Face model id or local path.
        init_image: Path or file object for the source image, or an already
            opened ``PIL.Image.Image``.
        prompt: Text prompt.
        strength: How strongly to transform the source image (passed to the
            pipeline).

    Returns:
        The generated PIL image.
    """
    unique_id = generate_unique_id()

    stored = redis_client.get(f"img2img_model:{model_name}")
    saved_dir = stored.decode("utf-8", errors="ignore") if stored else ""
    source = saved_dir if os.path.isdir(saved_dir) else model_name

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(source, torch_dtype=torch.float16)
    pipe = pipe.to("cuda")

    if not isinstance(init_image, Image.Image):
        init_image = Image.open(init_image)
    init_image = init_image.convert("RGB")
    # The pipeline's source-image argument is ``image`` (``init_image`` is
    # the old, removed name).
    image = pipe(prompt=prompt, image=init_image, strength=strength).images[0]

    os.makedirs("images", exist_ok=True)
    image_path = f"images/img2img_{unique_id}.png"
    image.save(image_path)
    redis_client.set(f"img2img_image:{unique_id}", image_path)
    return image
def train_and_store_marianmt_model(model_name, data):
    """Save a MarianMT translation model and its tokenizer for later use.

    NOTE(review): no training happens and ``data`` is unused.

    Writes to Redis:
      * ``marianmt_model:{model_name}:state_dict``: ``torch.save`` bytes of
        the model's state dict.
      * ``marianmt_tokenizer:{model_name}``: name of the local directory the
        tokenizer was saved to (``marianmt_tokenizer``).

    Args:
        model_name: Hugging Face model id or local path.
        data: Currently unused.
    """
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    model.train()  # only switches to training mode; no optimisation step runs

    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    redis_client.set(f"marianmt_model:{model_name}:state_dict", buffer.getvalue())

    # save_pretrained() returns a tuple of file paths, which Redis cannot
    # store; record the output directory instead.
    tokenizer_dir = "marianmt_tokenizer"
    tokenizer.save_pretrained(tokenizer_dir)
    redis_client.set(f"marianmt_tokenizer:{model_name}", tokenizer_dir)
def translate_text_from_redis(model_name, text, src_lang, tgt_lang):
    """Translate *text* with a MarianMT model, using stored weights when present.

    Uses the state dict stored under ``marianmt_model:{model_name}:state_dict``
    when present, and the pretrained weights otherwise. The tokenizer is loaded
    from the directory recorded under ``marianmt_tokenizer:{model_name}`` when
    that directory exists, and from ``model_name`` otherwise. The translation
    is stored under ``marianmt_translation:<uuid>``.

    Args:
        model_name: Hugging Face model id or local path (e.g. ``opus-mt-en-es``).
        text: Text to translate.
        src_lang: Kept for interface compatibility. MarianTokenizer does not
            recognise it as an argument, and the source language is implied
            by the checkpoint.
        tgt_lang: If the tokenizer defines a ``>>{tgt_lang}<<`` language code
            (multi-target checkpoints), that code is prepended to select the
            target language; otherwise it is ignored.

    Returns:
        The decoded translation.
    """
    unique_id = generate_unique_id()

    model = MarianMTModel.from_pretrained(model_name)
    model_data = redis_client.get(f"marianmt_model:{model_name}:state_dict")
    if model_data is not None:
        # NOTE(review): torch.load unpickles its input; trust Redis accordingly.
        model.load_state_dict(torch.load(io.BytesIO(model_data), map_location="cpu"))

    stored_dir = redis_client.get(f"marianmt_tokenizer:{model_name}")
    tokenizer_dir = stored_dir.decode("utf-8", errors="ignore") if stored_dir else ""
    tokenizer = MarianTokenizer.from_pretrained(
        tokenizer_dir if os.path.isdir(tokenizer_dir) else model_name
    )

    target_code = f">>{tgt_lang}<<"
    if target_code in getattr(tokenizer, "supported_language_codes", ()):
        text = f"{target_code} {text}"
    inputs = tokenizer(text, return_tensors="pt")
    translated_tokens = model.generate(**inputs)
    translation = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
    redis_client.set(f"marianmt_translation:{unique_id}", translation)
    return translation
def train_and_store_bart_model(model_name, data):
    """Save a BART summarisation model and its tokenizer for later use.

    NOTE(review): no training happens and ``data`` is unused.

    Writes to Redis:
      * ``bart_model:{model_name}:state_dict``: ``torch.save`` bytes of the
        model's state dict.
      * ``bart_tokenizer:{model_name}``: name of the local directory the
        tokenizer was saved to (``bart_tokenizer``).

    Args:
        model_name: Hugging Face model id or local path.
        data: Currently unused.
    """
    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = BartForConditionalGeneration.from_pretrained(model_name)
    model.train()  # only switches to training mode; no optimisation step runs

    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    redis_client.set(f"bart_model:{model_name}:state_dict", buffer.getvalue())

    # save_pretrained() returns a tuple of file paths, which Redis cannot
    # store; record the output directory instead.
    tokenizer_dir = "bart_tokenizer"
    tokenizer.save_pretrained(tokenizer_dir)
    redis_client.set(f"bart_tokenizer:{model_name}", tokenizer_dir)
def summarize_text_from_redis(model_name, text):
    """Summarise *text* with a BART model, using stored weights when present.

    Uses the state dict stored under ``bart_model:{model_name}:state_dict``
    when present, and the pretrained weights otherwise. The tokenizer is loaded
    from the directory recorded under ``bart_tokenizer:{model_name}`` when that
    directory exists, and from ``model_name`` otherwise. Any special tokens
    stored for ``model_name`` are applied. The summary is stored under
    ``bart_summary:<uuid>``.

    Args:
        model_name: Hugging Face model id or local path.
        text: Text to summarise (truncated to the tokenizer's maximum length).

    Returns:
        The decoded summary. Beam search uses 4 beams and a length of 40-150
        tokens.
    """
    unique_id = generate_unique_id()

    model = BartForConditionalGeneration.from_pretrained(model_name)
    model_data = redis_client.get(f"bart_model:{model_name}:state_dict")
    if model_data is not None:
        # NOTE(review): torch.load unpickles its input; trust Redis accordingly.
        model.load_state_dict(torch.load(io.BytesIO(model_data), map_location="cpu"))

    stored_dir = redis_client.get(f"bart_tokenizer:{model_name}")
    tokenizer_dir = stored_dir.decode("utf-8", errors="ignore") if stored_dir else ""
    tokenizer = BartTokenizer.from_pretrained(
        tokenizer_dir if os.path.isdir(tokenizer_dir) else model_name
    )
    load_special_tokens(tokenizer, model_name)

    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    summary_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        max_length=150,
        min_length=40,
        length_penalty=2.0,
        num_beams=4,
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    redis_client.set(f"bart_summary:{unique_id}", summary)
    return summary
def auto_train_and_store(model_name, task, data):
    """Dispatch to the train-and-store function that matches *task*.

    Args:
        model_name: Model id forwarded to the selected function.
        task: One of "text-generation", "diffusers", "musicgen",
            "stable-diffusion", "img2img", "translation", "summarization".
        data: Forwarded unchanged to the selected function.

    Raises:
        ValueError: If *task* is not one of the supported names. This is
            raised rather than silently doing nothing.
    """
    if task == "text-generation":
        train_and_store_transformers_model(model_name, data)
    elif task == "diffusers":
        train_and_store_diffusers_model(model_name, data)
    elif task == "musicgen":
        train_and_store_musicgen_model(model_name, data)
    elif task == "stable-diffusion":
        train_and_store_stable_diffusion_model(model_name, data)
    elif task == "img2img":
        train_and_store_img2img_model(model_name, data)
    elif task == "translation":
        train_and_store_marianmt_model(model_name, data)
    elif task == "summarization":
        train_and_store_bart_model(model_name, data)
    else:
        raise ValueError(
            f"Unsupported task {task!r}; expected one of: text-generation, diffusers, "
            "musicgen, stable-diffusion, img2img, translation, summarization"
        )
def transcribe_audio_from_redis(audio_file):
    """Transcribe a WAV recording with ``openai/whisper-small``.

    Args:
        audio_file: Path to a WAV file (what ``gr.Audio(type="filepath")``
            supplies) or the raw bytes of a WAV file. ``None`` or an empty
            value returns ``""``.

    Returns:
        The transcription text.

    NOTE(review): decoding uses ``scipy.io.wavfile``, so only WAV input is
    supported; other formats raise ``ValueError`` from SciPy. Despite the
    function's name, nothing is read from or written to Redis.
    """
    if not audio_file:
        return ""

    # Whisper's processor needs a waveform, not a path or encoded bytes, so
    # decode the WAV in memory instead of round-tripping through a scratch file.
    source = io.BytesIO(audio_file) if isinstance(audio_file, (bytes, bytearray)) else audio_file
    rate, samples = scipy.io.wavfile.read(source)

    # Normalise integer PCM to floats in [-1, 1).
    if samples.dtype.kind == "u":  # 8-bit WAV PCM is unsigned, centred on 128
        samples = (samples.astype("float32") - 128.0) / 128.0
    elif samples.dtype.kind == "i":
        samples = samples.astype("float32") / float(2 ** (8 * samples.dtype.itemsize - 1))
    else:
        samples = samples.astype("float32")
    if samples.ndim > 1:
        samples = samples.mean(axis=1)  # (frames, channels) -> mono

    target_rate = 16000  # the rate this function hands to the Whisper processor
    if rate != target_rate:
        samples = scipy.signal.resample_poly(samples, target_rate, rate).astype("float32")

    processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    model.config.forced_decoder_ids = None
    input_features = processor(samples, sampling_rate=target_rate, return_tensors="pt").input_features
    predicted_ids = model.generate(input_features)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    return transcription[0]
def generate_image_from_redis(model_name, prompt, model_type):
    """Generate an image with the backend selected by *model_type*.

    Args:
        model_name: Model id forwarded to the backend.
        prompt: Text prompt.
        model_type: "diffusers", "stable-diffusion" or "img2img". The img2img
            backend reads its source image from ``init_image.png``.

    Returns:
        The generated PIL image.

    Raises:
        ValueError: If *model_type* is not supported. Previously this case
            surfaced as an UnboundLocalError.
    """
    if model_type == "diffusers":
        return generate_diffusers_image_from_redis(model_name, prompt)
    if model_type == "stable-diffusion":
        return generate_stable_diffusion_image_from_redis(model_name, prompt)
    if model_type == "img2img":
        return generate_img2img_from_redis(model_name, "init_image.png", prompt)
    raise ValueError(
        f"Unsupported model_type {model_type!r}; expected 'diffusers', "
        "'stable-diffusion' or 'img2img'"
    )
def generate_video_from_redis(prompt):
    """Generate a short video for *prompt* with ``damo-vilab/text-to-video-ms-1.7b``.

    The video file written by ``export_to_video`` is recorded in Redis under
    ``video_<uuid>``.

    Args:
        prompt: Text prompt.

    Returns:
        Path of the exported video file.
    """
    pipe = DiffusionPipeline.from_pretrained(
        "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_model_cpu_offload()
    video_frames = pipe(prompt, num_inference_steps=25).frames
    # Newer diffusers releases return a batch of videos (a 5-D array/tensor,
    # or a list of frame lists); older ones return the frames directly. Take
    # the single generated video in the batched case.
    is_batched_array = getattr(video_frames, "ndim", 0) == 5
    is_nested_list = (
        isinstance(video_frames, list)
        and len(video_frames) > 0
        and isinstance(video_frames[0], (list, tuple))
    )
    if is_batched_array or is_nested_list:
        video_frames = video_frames[0]
    video_path = export_to_video(video_frames)
    unique_id = generate_unique_id()
    redis_client.set(f"video_{unique_id}", video_path)
    return video_path
def generate_random_response(prompts, generator):
    """Run *generator* on every prompt and return the generated strings in order.

    Args:
        prompts: Iterable of prompt strings.
        generator: Callable such as a Hugging Face text pipeline. It is
            called as ``generator(prompt, max_length=50)`` and must return a
            list whose first element has a ``'generated_text'`` entry.

    Returns:
        A list with one generated string per prompt.
    """
    return [generator(text, max_length=50)[0]['generated_text'] for text in prompts]
def _invoke_task(task):
    """Call *task* with no arguments; defined at module level so it can be pickled."""
    return task()


def process_parallel(tasks, processes=None):
    """Run zero-argument callables in a process pool and return their results in order.

    Args:
        tasks: Iterable of picklable zero-argument callables (e.g. module-level
            functions or builtins). Lambdas and closures cannot be sent to
            worker processes.
        processes: Pool size. Defaults to ``min(len(tasks), os.cpu_count())``,
            so no idle workers are spawned for small task lists.

    Returns:
        List of each task's return value, in the order of *tasks*. An empty
        *tasks* returns ``[]`` without starting a pool.
    """
    tasks = list(tasks)
    if not tasks:
        return []
    if processes is None:
        processes = min(len(tasks), os.cpu_count() or 1)
    # Pool.map pickles the function it dispatches; the original lambda could
    # not be pickled, so a module-level trampoline is used instead.
    with multiprocessing.Pool(processes=processes) as pool:
        return pool.map(_invoke_task, tasks)
def generate_response_from_prompt(prompt, model_name="google/flan-t5-xl"):
    """Generate a reply to *prompt* with a Hugging Face text pipeline.

    Encoder-decoder checkpoints, such as the default FLAN-T5, cannot be
    loaded by the ``text-generation`` task, which expects a causal LM. They
    are routed to ``text2text-generation`` instead; all other models use
    ``text-generation``.

    Args:
        prompt: Input text.
        model_name: Hugging Face model id used for both model and tokenizer.

    Returns:
        The generated text, limited to ``max_length=50``.
    """
    from transformers import AutoConfig  # only this function needs it

    config = AutoConfig.from_pretrained(model_name)
    is_seq2seq = getattr(config, "is_encoder_decoder", False)
    task = "text2text-generation" if is_seq2seq else "text-generation"
    generator = pipeline(task, model=model_name, tokenizer=model_name)
    responses = generate_random_response([prompt], generator)
    return responses[0]
def generate_image_from_prompt(prompt, image_type, model_name="CompVis/stable-diffusion-v1-4"):
    """Generate an image for *prompt* with the backend named by *image_type* (used by the UI).

    Args:
        prompt: Text prompt.
        image_type: "diffusers", "stable-diffusion" or "img2img". The img2img
            backend reads its source image from ``init_image.png``.
        model_name: Model id forwarded to the backend.

    Returns:
        The generated PIL image.

    Raises:
        ValueError: If *image_type* is not supported. Previously this case
            surfaced as an UnboundLocalError.
    """
    if image_type == "diffusers":
        return generate_diffusers_image_from_redis(model_name, prompt)
    if image_type == "stable-diffusion":
        return generate_stable_diffusion_image_from_redis(model_name, prompt)
    if image_type == "img2img":
        return generate_img2img_from_redis(model_name, "init_image.png", prompt)
    raise ValueError(
        f"Unsupported image_type {image_type!r}; expected 'diffusers', "
        "'stable-diffusion' or 'img2img'"
    )
def gradio_app():
    """Build the Gradio Blocks UI, wire each tab to its handler and launch the server.

    Tabs: text, image, video, audio (MusicGen), transcription (Whisper),
    translation (MarianMT) and summary (BART). UI strings are in Spanish.
    """
    with gr.Blocks() as app:
        gr.Markdown(
            """
            # IA Generativa con Transformers y Diffusers
            Explora diferentes modelos de IA para generar texto, imágenes, audio, video y más.
            """
        )

        with gr.Tab("Texto"):
            with gr.Row():
                with gr.Column():
                    prompt_text = gr.Textbox(label="Texto de Entrada", placeholder="Ingresa tu prompt de texto aquí...")
                    text_button = gr.Button("Generar Texto", variant="primary")
                with gr.Column():
                    text_output = gr.Textbox(label="Respuesta")
            text_button.click(generate_response_from_prompt, inputs=prompt_text, outputs=text_output)

        with gr.Tab("Imagen"):
            with gr.Row():
                with gr.Column():
                    prompt_image = gr.Textbox(label="Prompt de Imagen",
                                              placeholder="Ingresa tu prompt de imagen aquí...")
                    image_type = gr.Dropdown(["diffusers", "stable-diffusion", "img2img"], label="Tipo de Modelo",
                                             value="stable-diffusion")
                    model_name_image = gr.Textbox(label="Nombre del Modelo",
                                                  value="CompVis/stable-diffusion-v1-4")
                    image_button = gr.Button("Generar Imagen", variant="primary")
                with gr.Column():
                    image_output = gr.Image(label="Imagen Generada")
            image_button.click(generate_image_from_prompt, inputs=[prompt_image, image_type, model_name_image],
                               outputs=image_output)

        with gr.Tab("Video"):
            with gr.Row():
                with gr.Column():
                    prompt_video = gr.Textbox(label="Prompt de Video", placeholder="Ingresa tu prompt de video aquí...")
                    video_button = gr.Button("Generar Video", variant="primary")
                with gr.Column():
                    video_output = gr.Video(label="Video Generado")
            video_button.click(generate_video_from_redis, inputs=prompt_video, outputs=video_output)

        with gr.Tab("Audio"):
            with gr.Row():
                with gr.Column():
                    model_name_audio = gr.Textbox(label="Nombre del Modelo", value="facebook/musicgen-small")
                    text_prompts_audio = gr.Textbox(label="Prompts de Audio",
                                                    placeholder="Ingresa tus prompts de audio aquí...")
                    audio_button = gr.Button("Generar Audio", variant="primary")
                with gr.Column():
                    audio_output = gr.Audio(label="Audio Generado")
            audio_button.click(generate_musicgen_audio_from_redis, inputs=[model_name_audio, text_prompts_audio],
                               outputs=audio_output)

        with gr.Tab("Transcripción"):
            with gr.Row():
                with gr.Column():
                    audio_file = gr.Audio(type="filepath", label="Archivo de Audio")
                    # Distinct name so it no longer rebinds the Audio tab's button variable.
                    transcribe_button = gr.Button("Transcribir Audio", variant="primary")
                with gr.Column():
                    transcription_output = gr.Textbox(label="Transcripción")
            transcribe_button.click(transcribe_audio_from_redis, inputs=audio_file, outputs=transcription_output)

        with gr.Tab("Traducción"):
            with gr.Row():
                with gr.Column():
                    model_name_translate = gr.Textbox(label="Nombre del Modelo", value="Helsinki-NLP/opus-mt-en-es")
                    text_input = gr.Textbox(label="Texto a Traducir", placeholder="Ingresa el texto a traducir...")
                    src_lang_input = gr.Textbox(label="Idioma de Origen", value="en")
                    tgt_lang_input = gr.Textbox(label="Idioma de Destino", value="es")
                    translate_button = gr.Button("Traducir Texto", variant="primary")
                with gr.Column():
                    translation_output = gr.Textbox(label="Traducción")
            translate_button.click(translate_text_from_redis,
                                   inputs=[model_name_translate, text_input, src_lang_input, tgt_lang_input],
                                   outputs=translation_output)

        with gr.Tab("Resumen"):
            with gr.Row():
                with gr.Column():
                    model_name_summarize = gr.Textbox(label="Nombre del Modelo", value="facebook/bart-large-cnn")
                    text_to_summarize = gr.Textbox(label="Texto para Resumir",
                                                   placeholder="Ingresa el texto a resumir...")
                    summarize_button = gr.Button("Generar Resumen", variant="primary")
                with gr.Column():
                    summary_output = gr.Textbox(label="Resumen")
            summarize_button.click(summarize_text_from_redis, inputs=[model_name_summarize, text_to_summarize],
                                   outputs=summary_output)

    app.launch()
if "__main__" == __name__:
    # Build and serve the Gradio interface only when executed as a script.
    gradio_app()