import gradio as gr


def generate_text(context, num_samples, context_length, model_name):
    # Imported lazily inside the handler; `main` is the local generation
    # entry point (see the lit-gpt credits below).
    from pathlib import Path

    from base import main

    # Map the UI model name to its checkpoint directory.
    if model_name in ("pythia_160m_deduped_custom", "pythia_160m_deduped_huggingface"):
        ckpt_dir = Path('checkpoints/EleutherAI/pythia-160m-deduped')
    elif model_name == "pythia_70m_deduped":
        ckpt_dir = Path('checkpoints/EleutherAI/pythia-70m-deduped')
    elif model_name == "pythia_410m_deduped":
        ckpt_dir = Path('checkpoints/EleutherAI/pythia-410m-deduped')
    else:
        raise gr.Error(f"Unknown model name: {model_name}")

    context = str(context)
    num_samples = int(num_samples)
    context_length = int(context_length)
    model_name = str(model_name)

    output_msg_list = main(prompt=context, checkpoint_dir=ckpt_dir, model_name=model_name,
                           num_samples=num_samples, max_new_tokens=context_length)

    # Concatenate the generated samples into one readable text block.
    output_msg = ""
    for idx, msg in enumerate(output_msg_list):
        output_msg += f"-- Generated message {idx + 1} using model: {model_name} --\n\n"
        output_msg += f"{msg}\n\n"
    return output_msg
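
# Example direct call (a hypothetical invocation; assumes the pythia-70m-deduped
# checkpoint has been downloaded under ./checkpoints as mapped above). It returns
# a single formatted string containing all generated samples:
#
#   generate_text("I would like to", 1, 50, "pythia_70m_deduped")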


def gradio_fn(context, num_samples, context_length, model_name):
    # Thin wrapper whose signature matches the four Interface inputs below.
    output_txt_msg = generate_text(context, num_samples, context_length, model_name)
    return output_txt_msg


markdown_description = """
- This is a Gradio app that generates text based on:
    - a given text context (prompt)
    - a given number of new tokens to generate
    - a given number of samples
    - a selected GPT model
- Currently the following models are available:
    - **(a)** pythia_160m_deduped_huggingface **(b)** pythia_160m_deduped_custom
      **(c)** pythia_410m_deduped **(d)** pythia_70m_deduped
"""

demo = gr.Interface(
    fn=gradio_fn,
    inputs=[
        gr.Textbox(label="Context",
                   info="Start my passage with: 'I would like to'"),
        gr.Number(value=1, minimum=1, maximum=5, label="Num-Samples",
                  info="Number of samples to generate (min=1, max=5)"),
        gr.Slider(value=50, minimum=50, maximum=250, label="Max-New-Tokens",
                  info="Number of new tokens to generate (min=50, max=250)"),
        gr.Dropdown(["pythia_160m_deduped_huggingface", "pythia_160m_deduped_custom",
                     "pythia_410m_deduped", "pythia_70m_deduped"],
                    multiselect=False, label="Model-Name",
                    info="Pretrained model to use for text generation"),
    ],
    outputs=gr.Textbox(label="Generated text"),
    title="DialogGen - Dialogue Generator",
    description=markdown_description,
    article="**Credits** : https://github.com/Lightning-AI/lit-gpt")

demo.launch(share=True)
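
# Usage sketch (assumptions: this file is saved as app.py and the checkpoints
# listed above have been downloaded into ./checkpoints, e.g. with lit-gpt's
# download/convert scripts):
#
#   python app.py
#
# Gradio then prints a local URL and, because share=True, a temporary public
# *.gradio.live link.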