juancho72h committed
Commit ba1509a
1 Parent(s): b837587

Upload app.py

Files changed (1):
  1. app.py +117 -94

app.py CHANGED
@@ -2,112 +2,124 @@ import os
 import pinecone
 import openai
 import gradio as gr
-import torch
 from dotenv import load_dotenv
-from pinecone import Pinecone
-from langchain_community.embeddings import HuggingFaceEmbeddings  # Updated import
-from rapidfuzz import fuzz  # Replaced fuzzywuzzy with rapidfuzz
-import logging
-import re  # To help with preprocessing
-
-# Set up logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-
-# Detect GPU availability and set device
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print(f"Running on device: {device}")
-
-# Suppress specific warning about clean_up_tokenization_spaces
-import warnings
-warnings.filterwarnings("ignore", category=FutureWarning, message="clean_up_tokenization_spaces was not set")
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain.text_splitter import CharacterTextSplitter
+from langchain.docstore.document import Document
+import boto3
 
 # Load environment variables
 load_dotenv()
 
-# Access Pinecone and OpenAI API keys from environment variables
-pinecone_api_key = os.getenv("PINECONE_API_KEY")
+# Access secrets from environment variables
 openai.api_key = os.getenv("OPENAI_API_KEY")
+pinecone_api_key = os.getenv("PINECONE_API_KEY")
+aws_access_key = os.getenv("AWS_ACCESS_KEY_ID")
+aws_secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
+bucket_name = 'amtrak-superliner-ai-poc'
+txt_file_name = 'combined_extracted_text.txt'
 index_name = "amtrak-acela-ai-demo"
 
-# Initialize Pinecone using a class-based method
-pc = Pinecone(api_key=pinecone_api_key)
+# Initialize Pinecone using the new class-based method
+pc = pinecone.Pinecone(api_key=pinecone_api_key)
+
+# Initialize AWS S3 client
+s3_client = boto3.client(
+    's3',
+    aws_access_key_id=aws_access_key,
+    aws_secret_access_key=aws_secret_key,
+    region_name='us-east-1'
+)
 
-# Check if the index exists, if not, create it
-def initialize_pinecone_index(index_name):
+# Initialize Pinecone index (check if it exists, otherwise create it)
+def initialize_pinecone_index(index_name, embedding_dim):
     available_indexes = pc.list_indexes().names()
     if index_name not in available_indexes:
-        print(f"Index '{index_name}' does not exist.")
-        # Create the index here if necessary for ZeroGPU usage
+        pc.create_index(
+            name=index_name,
+            dimension=embedding_dim,
+            metric="cosine",
+            spec=pinecone.ServerlessSpec(
+                cloud="aws",
+                region="us-east-1"
+            )
+        )
     return pc.Index(index_name)
 
-index = initialize_pinecone_index(index_name)
+embedding_dim = 768
+index = initialize_pinecone_index(index_name, embedding_dim)
 
 # Initialize HuggingFace embedding model
 embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/msmarco-distilbert-base-v4")
 
-# Initialize chat history manually
-chat_history = []
-
-# Helper function to preprocess text (removing unnecessary words)
-def preprocess_text(text):
-    # Convert text to lowercase and remove special characters
-    text = re.sub(r'[^\w\s]', '', text.lower())
-    return text.strip()
-
-# Helper function to recursively flatten any list to a string
-def flatten_to_string(data):
-    if isinstance(data, list):
-        return " ".join([flatten_to_string(item) for item in data])
-    if data is None:
-        return ""
-    return str(data)
-
-# Function to interact with Pinecone and OpenAI GPT-4
-def get_model_response(human_input):
+# Download and load text from S3
+def download_text_from_s3(s3_client, bucket_name, file_name):
+    local_txt_path = os.path.join(os.getcwd(), file_name)
+    s3_client.download_file(bucket_name, file_name, local_txt_path)
+    with open(local_txt_path, 'r', encoding='utf-8') as f:
+        return f.read()
+
+doc_text = download_text_from_s3(s3_client, bucket_name, txt_file_name)
+
+# Split and embed the document text
+def process_text_into_embeddings(doc_text):
+    text_splitter = CharacterTextSplitter(separator='\n', chunk_size=3000, chunk_overlap=500)
+    docs = text_splitter.split_documents([Document(page_content=doc_text)])
+    doc_embeddings = embedding_model.embed_documents([doc.page_content for doc in docs])
+    return docs, doc_embeddings
+
+# Check if embeddings already exist in Pinecone
+def check_embeddings_in_pinecone(index):
+    try:
+        stats = index.describe_index_stats()
+        return stats['total_vector_count'] > 0
+    except Exception as e:
+        print(f"Error checking Pinecone index: {e}")
+        return False
+
+# Only process embeddings if they don't already exist in Pinecone
+if not check_embeddings_in_pinecone(index):
+    split_docs, doc_embeddings = process_text_into_embeddings(doc_text)
+    for i, doc in enumerate(split_docs):
+        metadata = {'content': doc.page_content}
+        index.upsert(vectors=[(str(i), doc_embeddings[i], metadata)])
+else:
+    print("Embeddings already exist in Pinecone. Skipping embedding process.")
+
+# Query Pinecone and OpenAI GPT-4 to generate a response
+def get_model_response(human_input, chat_history=None):
     try:
-        # Preprocess the human input (cleaning up unnecessary words)
-        processed_input = preprocess_text(human_input)
+        # Embed the query using the embedding model
+        query_embedding = embedding_model.embed_query(human_input)
 
-        # Embed the query
-        query_embedding = torch.tensor(embedding_model.embed_query(human_input)).to(device)
-        query_embedding = query_embedding.cpu().numpy().tolist()
-
-        # Query Pinecone index with top_k=5 to get more potential matches
-        search_results = index.query(vector=query_embedding, top_k=5, include_metadata=True)
+        # Query Pinecone index to retrieve relevant content
+        search_results = index.query(vector=query_embedding, top_k=3, include_metadata=True)
 
-        context_list, images = [], []
+        # Prepare content and image data
+        context_list = []
+        images = []
+
+        # Extract the content from Pinecone's search results
         for ind, result in enumerate(search_results['matches']):
-            document_content = flatten_to_string(result.get('metadata', {}).get('content', 'No content found'))
-            image_url = flatten_to_string(result.get('metadata', {}).get('image_path', None))
-            figure_desc = flatten_to_string(result.get('metadata', {}).get('figure_description', ''))
+            document_content = result.get('metadata', {}).get('content', 'No content found')
+            image_url = result.get('metadata', {}).get('image_path', None)
+            figure_desc = result.get('metadata', {}).get('figure_description', '')
 
-            # Preprocess the figure description and match keywords
-            processed_figure_desc = preprocess_text(figure_desc)
-            similarity_score = fuzz.token_set_ratio(processed_input, processed_figure_desc)
-            logging.info(f"Matching '{processed_input}' with '{processed_figure_desc}', similarity score: {similarity_score}")
+            context_list.append(f"Document {ind+1}: {document_content}")
 
-            if similarity_score >= 80:  # Keep the threshold at 80 for now
-                context_list.append(f"Relevant information: {document_content}")
-                if image_url and figure_desc:
-                    images.append((figure_desc, image_url))
+            if image_url and figure_desc:  # Only append images that exist and have description
+                images.append((figure_desc, image_url))
 
+        # Combine context from the search results
         context_string = '\n\n'.join(context_list)
 
-        # Add user message to chat history
-        chat_history.append({"role": "user", "content": human_input})
-
-        # Create messages for OpenAI's API
-        messages = [{"role": "system", "content": "You are a helpful assistant."}] + chat_history + [
-            {"role": "system", "content": f"Here is some context:\n{context_string}"},
-            {"role": "user", "content": human_input}
+        # Build messages list for OpenAI
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},  # System prompt
+            {"role": "user", "content": f"Here is some context:\n{context_string}\n\nUser's question: {human_input}"}
         ]
 
-        # Validate messages before sending to OpenAI
-        for message in messages:
-            if not isinstance(message, dict) or "role" not in message or "content" not in message:
-                raise ValueError(f"Invalid message format: {message}")
-
-        # Send the conversation to OpenAI's API
+        # Send the conversation to OpenAI's API, using GPT-3.5 instead of GPT-4
         response = openai.ChatCompletion.create(
             model="gpt-3.5-turbo",
             messages=messages,
@@ -115,32 +127,43 @@ def get_model_response(human_input):
             temperature=0.5
         )
 
+        # Get the model's response
         output_text = response['choices'][0]['message']['content'].strip()
 
-        # Add assistant message to chat history
-        chat_history.append({"role": "assistant", "content": output_text})
-
+        # Return both the output and any images found
        return output_text, images
 
     except Exception as e:
         return f"Error invoking model: {str(e)}", []
+
+# Function to format text and images for display
+def get_model_response_with_history(human_input, chat_history=None):
+    if chat_history is None:
+        chat_history = []
+
+    output_text, chat_history = get_model_response(human_input, chat_history)
+
+    # Handle image display
+    def process_image(image_data):
+        if isinstance(image_data, list):
+            # If a list is passed, flatten it to a string
+            return " ".join(str(item) for item in image_data)
+        return str(image_data)
+
+    if chat_history:
+        # Ensure that any file/image alt_text is handled correctly
+        for message in chat_history:
+            if "alt_text" in message:
+                message["alt_text"] = process_image(message.get("alt_text", ""))
 
-# Function to format text and images for display and track conversation
-def get_model_response_with_images(human_input, history=None):
-    output_text, images = get_model_response(human_input)
-    if images:
-        # Append images in Markdown format for Gradio to render
-        image_output = "".join([f"\n\n**{figure_desc}**\n![{figure_desc}]({image_path})" for figure_desc, image_path in images])
-        return output_text + image_output
     return output_text
 
-# Set up Gradio interface
+# Set up Gradio interface without share=True to avoid the error for now
 gr_interface = gr.ChatInterface(
-    fn=get_model_response_with_images,
-    title="Maintenance Assistant",
-    description="Ask questions related to the RMMM documents."
+    fn=get_model_response_with_history,
+    title="Maintenance Assistant",
+    description="Ask questions related to the RMM documents."
 )
 
-# Ensure ZeroGPU or Hugging Face Spaces handles launching properly
-if __name__ == "__main__":
-    gr_interface.launch()
+# Launch the Gradio interface
+gr_interface.launch()
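
A note on configuration: the new version reads all of its secrets with os.getenv() after load_dotenv(), so a local run needs a .env file (or equivalent Space secrets). The key names below come from the code in this commit; the values are placeholders:

OPENAI_API_KEY=<your OpenAI key>
PINECONE_API_KEY=<your Pinecone key>
AWS_ACCESS_KEY_ID=<your AWS access key ID>
AWS_SECRET_ACCESS_KEY=<your AWS secret access key>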
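
The index is created with embedding_dim = 768, which matches the 768-dimensional vectors produced by sentence-transformers/msmarco-distilbert-base-v4. A minimal sanity check along these lines (a sketch, not part of the committed file) would catch a model/index mismatch before any upsert:

# Sketch: confirm the embedding width matches the Pinecone index dimension
probe = embedding_model.embed_query("dimension probe")
assert len(probe) == embedding_dim, (
    f"model emits {len(probe)}-d vectors but the index expects {embedding_dim}"
)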
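
Note also that openai.ChatCompletion.create is the legacy pre-1.0 interface, so this code assumes openai<1.0 is pinned. If the Space were ever migrated to the 1.x client, the equivalent call would look roughly like the sketch below (an assumption about a future migration, showing only the parameters visible in this diff):

# Sketch: openai>=1.0 equivalent of the ChatCompletion call above
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=messages,
    temperature=0.5,
)
output_text = response.choices[0].message.content.strip()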