# train / app.py
# Author: jhansi1 — commit db1baea ("Update app.py", verified)
import gradio as gr
import streamlit as st
from transformers import pipeline
from datasets import load_dataset
from huggingface_hub import hf_hub_download
import subprocess
import os
# Clone the dataset repository if not already cloned
repo_url = "https://huggingface.co/datasets/BEE-spoke-data/survivorslib-law-books"
repo_dir = "./survivorslib-law-books"
if not os.path.exists(repo_dir):
    # Clone into repo_dir explicitly so the checkout location always matches
    # what dataset_path below expects, even if repo_dir is changed later
    # (otherwise git names the directory after the last URL component).
    # NOTE(review): large files on the Hub are usually stored via Git LFS; if
    # git-lfs is not installed the clone may contain pointer files instead of
    # real parquet data — confirm git-lfs is available on the host.
    subprocess.run(["git", "clone", repo_url, repo_dir], check=True)
# Load the dataset from the cloned repository
dataset_path = os.path.join(repo_dir, "train.parquet")
# load_dataset with a single data file yields a DatasetDict with a 'train' split.
ds = load_dataset("parquet", data_files=dataset_path)
# Shared text-generation pipeline used by both the Gradio and Streamlit front ends.
# NOTE(review): judging by its name this checkpoint has ~70B parameters; confirm
# the host has enough (GPU) memory to load it.
model_name = "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"
pipe = pipeline(task="text-generation", model=model_name)
# Preprocess dataset (assuming it has a 'text' or 'content' column for feeding to the model)
# If the dataset is different, update the field names accordingly
def preprocess_data(dataset, n=5, column="content"):
    """Return the first ``n`` values of ``column`` from ``dataset``.

    Rows are sliced first (``dataset[:n]``) and only then is the column
    selected, so only ``n`` rows are read. The whole column is not
    materialized just to keep a few entries.

    Args:
        dataset: A row-sliceable table, such as a Hugging Face
            ``datasets.Dataset`` split or a pandas ``DataFrame``.
        n: Number of leading entries to return; defaults to 5 for a short preview.
        column: Name of the column holding the legal text. ``'content'`` is an
            assumption about the dataset schema. Adjust it (e.g. to ``'text'``
            or ``'paragraph'``) if the column is named differently.

    Returns:
        The first ``n`` values of ``column``, or fewer if the dataset is shorter.
    """
    return dataset[:n][column]
# Gradio Interface setup
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Generate an assistant reply for the Gradio ``ChatInterface``.

    The function builds a chat transcript and runs it through the module-level
    ``pipe`` text-generation pipeline. The transcript contains the system
    prompt, the prior turns and the new user message.

    Args:
        message: The new user message.
        history: Prior conversation turns. Tuple-style ``(user, assistant)``
            pairs are supported, as are messages-style
            ``{"role": ..., "content": ...}`` dicts.
        system_message: System prompt placed first in the transcript.
        max_tokens: Maximum number of *new* tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The assistant's reply text. The pipeline does not stream, so this
        yields once with the complete reply.
    """
    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        if isinstance(turn, dict):
            # Messages-style history is already in role/content form.
            messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        user_text, assistant_text = turn[0], turn[1]
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})

    # Pass the whole transcript so the system prompt and history reach the model.
    # max_new_tokens matches the "Max new tokens" slider; max_length would also
    # count prompt tokens. return_full_text=False returns only the new reply,
    # not the echoed prompt/transcript.
    outputs = pipe(
        messages,
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        return_full_text=False,
    )
    reply = outputs[0]["generated_text"]
    if isinstance(reply, list):
        # Defensive: a full chat transcript ends with the assistant's message.
        reply = reply[-1]["content"]
    yield reply
# Streamlit Interface setup
def streamlit_interface():
    """Render the Streamlit UI: a dataset preview plus a prompt-to-text generator.

    The module-level ``ds`` (its ``'train'`` split) supplies the preview rows.
    The module-level ``pipe`` generates text for the user's prompt.
    """
    st.title("Canadian Legal Text Generator")
    st.write("Enter a prompt related to Canadian legal data and generate text using Llama-3.1.")
    # Show dataset sample (first 5 entries)
    st.subheader("Sample Data from Canadian Legal Dataset:")
    sample_data = preprocess_data(ds['train'])  # Assuming 'train' split
    st.write(sample_data)  # Display the first 5 rows of the dataset
    # Prompt input
    prompt = st.text_area("Enter your prompt:", placeholder="Type something...")
    if st.button("Generate Response"):
        # Treat whitespace-only input the same as an empty prompt.
        if (prompt or "").strip():
            with st.spinner("Generating response..."):
                # max_new_tokens bounds only the generated continuation.
                # max_length would also count prompt tokens, leaving long
                # prompts with little or no room for output.
                generated_text = pipe(
                    prompt, max_new_tokens=100, do_sample=True, temperature=0.7
                )[0]["generated_text"]
            st.write("**Generated Text:**")
            st.write(generated_text)
        else:
            st.write("Please enter a prompt to generate a response.")
# Running Gradio and Streamlit interfaces
@st.cache_resource
def _launch_gradio_demo():
    """Build the Gradio chat UI and start its server once per process.

    ``st.cache_resource`` makes every later Streamlit rerun, and every other
    session, reuse the same running server. Without it each rerun would try
    to launch another server and clash over the port.
    ``prevent_thread_lock=True`` returns immediately instead of blocking the
    Streamlit script thread for the server's lifetime.

    Returns:
        The launched ``gr.ChatInterface``.
    """
    chat_demo = gr.ChatInterface(
        respond,
        additional_inputs=[
            gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
            gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
        ],
    )
    chat_demo.launch(prevent_thread_lock=True)
    return chat_demo


if __name__ == "__main__":
    st.sidebar.title("Choose an Interface")
    interface = st.sidebar.radio("Select", ("Streamlit", "Gradio"))
    if interface == "Streamlit":
        streamlit_interface()
    else:
        demo = _launch_gradio_demo()
        # NOTE(review): Gradio serves on its own port. Hosted environments that
        # expose a single port may not make this URL reachable — confirm.
        st.write(f"Gradio chat interface is running at {demo.local_url}")