rene-cartesia / app.py
archit11's picture
Update app.py
d073849 verified
raw
history blame contribute delete
No virus
1.06 kB
import spaces
import gradio as gr
from cartesia_pytorch import ReneLMHeadModel
from transformers import AutoTokenizer
#import subprocess
#subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
# Hub repositories for the Rene checkpoint and the tokenizer this app pairs
# with it (the OLMo tokenizer is loaded for the Rene model below).
_MODEL_REPO = "cartesia-ai/Rene-v0.1-1.3b-pytorch"
_TOKENIZER_REPO = "allenai/OLMo-1B-hf"

# Load once at import time: weights are cast to fp16 and moved to CUDA here,
# so request handlers only run generation.
model = ReneLMHeadModel.from_pretrained(_MODEL_REPO)
model = model.half()
model = model.cuda()
tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_REPO)
# Define the function to generate text
@spaces.GPU(duration=120)
def generate_text(input_text, max_new_tokens=50):
    """Continue a prompt with the Rene model and return the decoded text.

    Args:
        input_text: Prompt string from the Gradio textbox.
        max_new_tokens: How many tokens may be generated beyond the prompt.
            Defaults to 50.

    Returns:
        The first decoded sequence (prompt plus sampled continuation) with
        special tokens removed, or "" when the prompt is empty/None.
    """
    # An empty textbox gives generate() nothing to condition on (the
    # tokenizer may yield zero ids); answer immediately instead of running
    # the model.
    if not input_text:
        return ""

    input_ids = tokenizer([input_text], return_tensors="pt").input_ids.cuda()

    # ``max_length`` is the TOTAL sequence length (prompt + generated tokens)
    # in this generate() API, so a fixed 50 leaves no room for new text once
    # the prompt is ~50 tokens long. Budget the limit relative to the prompt.
    # NOTE(review): assumes mamba_ssm-style GenerationMixin semantics — confirm
    # against cartesia_pytorch if its generate() changes.
    max_length = input_ids.shape[1] + max_new_tokens

    # Sampling settings kept as in the original app.
    outputs = model.generate(input_ids, max_length=max_length, top_k=100, top_p=0.99)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
# Minimal Gradio UI around generate_text: a single text input for the
# prompt and a single text output for the generated continuation.
_ui_settings = {
    "fn": generate_text,
    "inputs": "text",
    "outputs": "text",
    "title": "ReneLM Text Generator",
    "description": "Generate text using ReneLMHeadModel from a prompt.",
}
interface = gr.Interface(**_ui_settings)

# Start serving the app.
interface.launch()