Update app.py
app.py (CHANGED)
@@ -1,5 +1,5 @@
 import gradio as gr
-from langchain.document_loaders import
+from langchain.document_loaders import PyPDFLoader
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.llms import HuggingFaceHub
 from langchain.embeddings import HuggingFaceHubEmbeddings
@@ -36,8 +36,9 @@ def get_chain(llm, retriever):
     )
     return qa_chain
 
-def
-
+def load_pdf_to_langchain(pdf_path, repo_id):
+    # Load the PDF using PyPDFLoader
+    loader = PyPDFLoader(pdf_path)
     documents = loader.load()
     text_splitter = CharacterTextSplitter(chunk_size=2096, chunk_overlap=0)
     texts = text_splitter.split_documents(documents)
@@ -46,11 +47,10 @@ def pdf_changes(pdf_doc, repo_id):
     retriever = db.as_retriever()
     llm = HuggingFaceHub(
         repo_id=repo_id,
-        model_kwargs={'temperature': 0.
+        model_kwargs={'temperature': 0.3, 'max_new_tokens': 2096}
     )
-
-
-    return "Ready"
+    qa_chain = get_chain(llm, retriever)
+    return qa_chain
 
 def generate_guideline(infrastructure, location, cyclone_predicted_coordinates, cyclone_speed):
     if infrastructure and location and cyclone_predicted_coordinates and cyclone_speed:
@@ -72,9 +72,9 @@ with gr.Blocks(css=css, theme='Taithrah/Minimal') as demo:
     with gr.Column(elem_id='col-container'):
         gr.HTML(title)
 
-
+        # LLM selection
         repo_id = gr.Dropdown(
-            label='LLM',
+            label='Select Language Model (LLM)',
             choices=[
                 'mistralai/Mistral-7B-Instruct-v0.1',
                 'HuggingFaceH4/zephyr-7b-beta',
@@ -84,27 +84,46 @@ with gr.Blocks(css=css, theme='Taithrah/Minimal') as demo:
             ],
             value='mistralai/Mistral-7B-Instruct-v0.1'
         )
-
-
-
-
-
+
+        # Status display
+        langchain_status = gr.Textbox(
+            label='Status', placeholder='', interactive=False, value="Loading guideline1.pdf..."
+        )
 
         # Input fields for user information
         infrastructure = gr.Textbox(label='Infrastructure')
-        location = gr.Textbox(label='Location
+        location = gr.Textbox(label='Location Coordinates (lat,lon)')
         cyclone_predicted_coordinates = gr.Textbox(label='Predicted Cyclone Coordinates (lat,lon)')
         cyclone_speed = gr.Textbox(label='Cyclone Speed in Knots')
 
        submit_btn = gr.Button('Generate Guideline')
        output = gr.Textbox(label='Personalized Guideline', lines=10)
 
+        # Global variable to store the QA chain
+        qa = None
+
+        # Function to initialize the QA chain
+        def initialize_qa(repo_id_value):
+            global qa
+            pdf_path = 'guideline1.pdf'  # Ensure this PDF is in the same directory
+            qa = load_pdf_to_langchain(pdf_path, repo_id_value)
+            return f"Loaded guideline1.pdf with LLM: {repo_id_value}"
+
+        # Initialize QA chain with default LLM
+        initial_status = initialize_qa(repo_id.value)
+        langchain_status.value = initial_status
+
+        # Update QA chain when LLM selection changes
+        def on_repo_id_change(repo_id_value):
+            status = initialize_qa(repo_id_value)
+            return status
+
         repo_id.change(
-
-
-
-            pdf_changes, inputs=[pdf_doc, repo_id], outputs=[langchain_status], queue=False
+            on_repo_id_change,
+            inputs=repo_id,
+            outputs=langchain_status
        )
+
         submit_btn.click(
             generate_guideline,
             inputs=[infrastructure, location, cyclone_predicted_coordinates, cyclone_speed],
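
For context: the hunks above reference two pieces defined outside the shown lines — the vector store `db` and the `get_chain` helper. Below is a minimal, self-contained sketch of how the rebuilt `load_pdf_to_langchain` presumably fits together end to end after this commit. The Chroma vector store and the stuff-type RetrievalQA chain are assumptions (neither appears in the diff), the sample question is invented, and running it requires a valid HUGGINGFACEHUB_API_TOKEN in the environment.

# Sketch of the post-commit pipeline (legacy LangChain API, matching the
# imports in the diff). `Chroma` and `RetrievalQA` are assumptions: the
# actual `db` and `get_chain` definitions are outside the shown hunks.
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import HuggingFaceHubEmbeddings
from langchain.vectorstores import Chroma      # assumed vector store
from langchain.llms import HuggingFaceHub
from langchain.chains import RetrievalQA       # assumed chain type

def load_pdf_to_langchain(pdf_path, repo_id):
    # Load the PDF and split it into chunks the model can consume
    documents = PyPDFLoader(pdf_path).load()
    splitter = CharacterTextSplitter(chunk_size=2096, chunk_overlap=0)
    texts = splitter.split_documents(documents)

    # Embed the chunks and expose them as a retriever
    embeddings = HuggingFaceHubEmbeddings()
    db = Chroma.from_documents(texts, embeddings)
    retriever = db.as_retriever()

    # Wire the selected Hub model to the retriever
    llm = HuggingFaceHub(
        repo_id=repo_id,
        model_kwargs={'temperature': 0.3, 'max_new_tokens': 2096},
    )
    return RetrievalQA.from_chain_type(llm=llm, chain_type='stuff', retriever=retriever)

# Usage: build the chain once at startup, then query it (question is illustrative)
qa = load_pdf_to_langchain('guideline1.pdf', 'mistralai/Mistral-7B-Instruct-v0.1')
print(qa.run('What should a coastal household do at 60 knots?'))

Note the design choice the commit makes: instead of letting the user upload a PDF (the removed `pdf_changes(pdf_doc, ...)` path), the app now loads a fixed guideline1.pdf at startup and simply rebuilds the chain whenever the LLM dropdown changes.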