# DialogGen / app.py
# Source: Hugging Face Space file view (author: anilbhatt1)
# Commit: "Initial commit" (54200b7), file size 2.8 kB
import gradio as gr
def _checkpoint_dir_for(model_name):
    """Return the checkpoint directory configured for *model_name*.

    Raises:
        ValueError: if *model_name* is not one of the supported model names.
    """
    from pathlib import Path

    # Both 160m variants load from the same checkpoint directory.
    checkpoint_dirs = {
        "pythia_160m_deduped_custom": "checkpoints/EleutherAI/pythia-160m-deduped",
        "pythia_160m_deduped_huggingface": "checkpoints/EleutherAI/pythia-160m-deduped",
        "pythia_70m_deduped": "checkpoints/EleutherAI/pythia-70m-deduped",
        "pythia_410m_deduped": "checkpoints/EleutherAI/pythia-410m-deduped",
    }
    try:
        return Path(checkpoint_dirs[model_name])
    except KeyError:
        supported = ", ".join(sorted(checkpoint_dirs))
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of: {supported}"
        ) from None


def generate_text(context, num_samples, context_length, model_name):
    """Generate text samples for a prompt and format them as one string.

    Args:
        context: Prompt text; converted with ``str`` and passed to ``main``
            as ``prompt``.
        num_samples: Number of samples to request; converted with ``int``.
        context_length: Converted with ``int`` and passed to ``main`` as
            ``max_new_tokens``.
        model_name: Name of the model to use; selects the checkpoint
            directory and is forwarded to ``main``.

    Returns:
        A single string containing, for every message returned by ``main``,
        a numbered title line, a blank line, the message and a blank line.

    Raises:
        ValueError: if *model_name* is not a supported model (previously this
            surfaced as an ``UnboundLocalError`` on ``ckpt_dir``).
    """
    model_name = str(model_name)
    # Validate the model before importing the generation backend so that a
    # bad selection fails fast with a readable message.
    ckpt_dir = _checkpoint_dir_for(model_name)

    # Imported lazily, as in the original, so the backend is only loaded
    # when a generation is actually requested.
    from base import main

    output_msg_list = main(
        prompt=str(context),
        checkpoint_dir=ckpt_dir,
        model_name=model_name,
        num_samples=int(num_samples),
        max_new_tokens=int(context_length),
    )

    # Per message: title line, blank line, message, blank line.
    return "".join(
        f"--Generated message : {idx} using the model : {model_name}--\n\n{msg}\n\n"
        for idx, msg in enumerate(output_msg_list, start=1)
    )
def gradio_fn(context, num_samples, context_length, model_name):
    """Gradio callback: forward the UI inputs to ``generate_text``.

    Returns the formatted text produced by ``generate_text`` unchanged.
    """
    return generate_text(context, num_samples, context_length, model_name)
# Markdown text passed to gr.Interface as its `description`.
# The trailing backslash inside the string is a line continuation, so the
# (a)/(b) and (c)/(d) model names end up on one line of the rendered text.
markdown_description = """
- This is a Gradio app that generates text based on
- given text context
- for given character length
- number of Samples
- using Selected GPT model
- Currently following models are available :
- **(a)** pythia_160m_deduped_huggingface **(b)** pythia_160m_deduped_custom \
**(c)** pythia_410m_deduped **(d)** pythia_70m_deduped
"""
# Input widgets, in the positional order expected by gradio_fn:
# context, num_samples, context_length, model_name.
input_components = [
    gr.Textbox(info="Start my passage with: 'I would like to'"),
    gr.Number(
        value=1,
        minimum=1,
        maximum=5,
        info="Number of samples to be generated min=1, max=5",
    ),
    gr.Slider(
        value=50,
        minimum=50,
        maximum=250,
        info="Num characters for passage min=50, max=250",
    ),
    gr.Dropdown(
        [
            "pythia_160m_deduped_huggingface",
            "pythia_160m_deduped_custom",
            "pythia_410m_deduped",
            "pythia_70m_deduped",
        ],
        multiselect=False,
        label="Model-Name",
        info="Pretrained model to be used for text generation",
    ),
]

# Wire the widgets to the callback and describe the app.
demo = gr.Interface(
    fn=gradio_fn,
    inputs=input_components,
    outputs=gr.Textbox(),
    title="DialogGen - Dialogue Generator",
    description=markdown_description,
    article=" **Credits** : https://github.com/Lightning-AI/lit-gpt ",
)

# Start the app; share=True asks Gradio for a shareable link.
demo.launch(share=True)