aquibmoin commited on
Commit
c7af9e1
1 Parent(s): 569a931

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -14
app.py CHANGED
@@ -4,6 +4,10 @@ from openai import OpenAI
4
  import os
5
  import numpy as np
6
  from sklearn.metrics.pairwise import cosine_similarity
 
 
 
 
7
 
8
  # Load the NASA-specific bi-encoder model and tokenizer
9
  bi_encoder_model_name = "nasa-impact/nasa-smd-ibm-st-v2"
@@ -14,27 +18,55 @@ bi_model = AutoModel.from_pretrained(bi_encoder_model_name)
14
  api_key = os.getenv('OPENAI_API_KEY')
15
  client = OpenAI(api_key=api_key)
16
 
 
 
 
17
  # Define a system message to introduce Exos
18
  system_message = "You are Exos, a helpful assistant specializing in Exoplanet research. Provide detailed and accurate responses related to Exoplanet research."
19
 
20
  def encode_text(text):
21
  inputs = bi_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
22
  outputs = bi_model(**inputs)
23
- return outputs.last_hidden_state.mean(dim=1).detach().numpy().flatten() # Ensure the output is 2D
24
 
25
  def retrieve_relevant_context(user_input, context_texts):
26
  user_embedding = encode_text(user_input).reshape(1, -1)
27
  context_embeddings = np.array([encode_text(text) for text in context_texts])
28
- context_embeddings = context_embeddings.reshape(len(context_embeddings), -1) # Flatten each embedding
29
  similarities = cosine_similarity(user_embedding, context_embeddings).flatten()
30
  most_relevant_idx = np.argmax(similarities)
31
  return context_texts[most_relevant_idx]
32
 
33
- def generate_response(user_input, relevant_context="", max_tokens=150, temperature=0.7, top_p=0.9, frequency_penalty=0.5, presence_penalty=0.0):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  if relevant_context:
35
- combined_input = f"Context: {relevant_context}\nQuestion: {user_input}\nAnswer:"
36
  else:
37
- combined_input = f"Question: {user_input}\nAnswer:"
38
 
39
  response = client.chat.completions.create(
40
  model="gpt-4-turbo",
@@ -48,23 +80,84 @@ def generate_response(user_input, relevant_context="", max_tokens=150, temperatu
48
  frequency_penalty=frequency_penalty,
49
  presence_penalty=presence_penalty
50
  )
 
 
 
 
 
 
 
 
 
51
  return response.choices[0].message.content.strip()
52
 
 
 
 
 
 
 
 
 
 
 
 
53
  def chatbot(user_input, context="", use_encoder=False, max_tokens=150, temperature=0.7, top_p=0.9, frequency_penalty=0.5, presence_penalty=0.0):
54
  if use_encoder and context:
55
  context_texts = context.split("\n")
56
  relevant_context = retrieve_relevant_context(user_input, context_texts)
57
  else:
58
  relevant_context = ""
59
- response = generate_response(user_input, relevant_context, max_tokens, temperature, top_p, frequency_penalty, presence_penalty)
60
- return response
61
 
62
- # Create the Gradio interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  iface = gr.Interface(
64
  fn=chatbot,
65
  inputs=[
66
- gr.Textbox(lines=2, placeholder="Enter your message here...", label="Your Question"),
67
- gr.Textbox(lines=5, placeholder="Enter context here, separated by new lines...", label="Context"),
68
  gr.Checkbox(label="Use NASA SMD Bi-Encoder for Context"),
69
  gr.Slider(50, 1000, value=150, step=10, label="Max Tokens"),
70
  gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="Temperature"),
@@ -72,10 +165,14 @@ iface = gr.Interface(
72
  gr.Slider(0.0, 1.0, value=0.5, step=0.1, label="Frequency Penalty"),
73
  gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Presence Penalty")
74
  ],
75
- outputs=gr.Textbox(label="Exos says..."),
76
- title="Exos - Your Exoplanet Research Assistant",
77
- description="Exos is a helpful assistant specializing in Exoplanet research. Provide context to get more refined and relevant responses.",
 
 
 
 
 
78
  )
79
 
80
- # Launch the interface
81
  iface.launch(share=True)
 
4
  import os
5
  import numpy as np
6
  from sklearn.metrics.pairwise import cosine_similarity
7
+ from docx import Document
8
+ import io
9
+ import tempfile
10
+ from astroquery.nasa_ads import ADS
11
 
12
  # Load the NASA-specific bi-encoder model and tokenizer
13
  bi_encoder_model_name = "nasa-impact/nasa-smd-ibm-st-v2"
 
18
  api_key = os.getenv('OPENAI_API_KEY')
19
  client = OpenAI(api_key=api_key)
20
 
21
+ # Set up NASA ADS token
22
+ ADS.TOKEN = os.getenv('ADS_API_KEY') # Ensure your ADS API key is stored in environment variables
23
+
24
  # Define a system message to introduce Exos
25
  system_message = "You are Exos, a helpful assistant specializing in Exoplanet research. Provide detailed and accurate responses related to Exoplanet research."
26
 
27
  def encode_text(text):
28
  inputs = bi_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
29
  outputs = bi_model(**inputs)
30
+ return outputs.last_hidden_state.mean(dim=1).detach().numpy().flatten()
31
 
32
  def retrieve_relevant_context(user_input, context_texts):
33
  user_embedding = encode_text(user_input).reshape(1, -1)
34
  context_embeddings = np.array([encode_text(text) for text in context_texts])
35
+ context_embeddings = context_embeddings.reshape(len(context_embeddings), -1)
36
  similarities = cosine_similarity(user_embedding, context_embeddings).flatten()
37
  most_relevant_idx = np.argmax(similarities)
38
  return context_texts[most_relevant_idx]
39
 
40
+ def fetch_nasa_ads_references(prompt):
41
+ try:
42
+ # Use the entire prompt for the query
43
+ simplified_query = prompt
44
+
45
+ # Query NASA ADS for relevant papers
46
+ papers = ADS.query_simple(simplified_query)
47
+
48
+ if not papers or len(papers) == 0:
49
+ return [("No results found", "N/A", "N/A")]
50
+
51
+ # Include authors in the references
52
+ references = [
53
+ (
54
+ paper['title'][0],
55
+ ", ".join(paper['author'][:3]) + (" et al." if len(paper['author']) > 3 else ""),
56
+ paper['bibcode']
57
+ )
58
+ for paper in papers[:5] # Limit to 5 references
59
+ ]
60
+ return references
61
+
62
+ except Exception as e:
63
+ return [("Error fetching references", str(e), "N/A")]
64
+
65
+ def generate_response(user_input, relevant_context="", references=[], max_tokens=150, temperature=0.7, top_p=0.9, frequency_penalty=0.5, presence_penalty=0.0):
66
  if relevant_context:
67
+ combined_input = f"Context: {relevant_context}\nQuestion: {user_input}\nAnswer (please organize the answer in a structured format with topics and subtopics):"
68
  else:
69
+ combined_input = f"Question: {user_input}\nAnswer (please organize the answer in a structured format with topics and subtopics):"
70
 
71
  response = client.chat.completions.create(
72
  model="gpt-4-turbo",
 
80
  frequency_penalty=frequency_penalty,
81
  presence_penalty=presence_penalty
82
  )
83
+
84
+ # Append references to the response
85
+ if references:
86
+ response_content = response.choices[0].message.content.strip()
87
+ references_text = "\n\nADS References:\n" + "\n".join(
88
+ [f"- {title} by {authors} (Bibcode: {bibcode})" for title, authors, bibcode in references]
89
+ )
90
+ return f"{response_content}\n{references_text}"
91
+
92
  return response.choices[0].message.content.strip()
93
 
94
+ def export_to_word(response_content):
95
+ doc = Document()
96
+ doc.add_heading('AI Generated SCDD', 0)
97
+ for line in response_content.split('\n'):
98
+ doc.add_paragraph(line)
99
+
100
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".docx")
101
+ doc.save(temp_file.name)
102
+
103
+ return temp_file.name
104
+
105
  def chatbot(user_input, context="", use_encoder=False, max_tokens=150, temperature=0.7, top_p=0.9, frequency_penalty=0.5, presence_penalty=0.0):
106
  if use_encoder and context:
107
  context_texts = context.split("\n")
108
  relevant_context = retrieve_relevant_context(user_input, context_texts)
109
  else:
110
  relevant_context = ""
 
 
111
 
112
+ # Fetch NASA ADS references using the full prompt
113
+ references = fetch_nasa_ads_references(user_input)
114
+
115
+ # Generate response from GPT-4
116
+ response = generate_response(user_input, relevant_context, references, max_tokens, temperature, top_p, frequency_penalty, presence_penalty)
117
+
118
+ # Export the response to a Word document
119
+ word_doc_path = export_to_word(response)
120
+
121
+ # Embed Miro iframe
122
+ iframe_html = """
123
+ <iframe width="768" height="432" src="https://miro.com/app/live-embed/uXjVKuVTcF8=/?moveToViewport=-331,-462,5434,3063&embedId=710273023721" frameborder="0" scrolling="no" allow="fullscreen; clipboard-read; clipboard-write" allowfullscreen></iframe>
124
+ """
125
+
126
+ mapify_button_html = """
127
+ <style>
128
+ .mapify-button {
129
+ background: linear-gradient(135deg, #1E90FF 0%, #87CEFA 100%);
130
+ border: none;
131
+ color: white;
132
+ padding: 15px 35px;
133
+ text-align: center;
134
+ text-decoration: none;
135
+ display: inline-block;
136
+ font-size: 18px;
137
+ font-weight: bold;
138
+ margin: 20px 2px;
139
+ cursor: pointer;
140
+ border-radius: 25px;
141
+ transition: all 0.3s ease;
142
+ box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
143
+ }
144
+ .mapify-button:hover {
145
+ background: linear-gradient(135deg, #4682B4 0%, #1E90FF 100%);
146
+ box-shadow: 0 6px 20px rgba(0, 0, 0, 0.3);
147
+ transform: scale(1.05);
148
+ }
149
+ </style>
150
+ <a href="https://mapify.so/app/new" target="_blank">
151
+ <button class="mapify-button">Create Mind Map on Mapify</button>
152
+ </a>
153
+ """
154
+ return response, iframe_html, mapify_button_html, word_doc_path
155
+
156
  iface = gr.Interface(
157
  fn=chatbot,
158
  inputs=[
159
+ gr.Textbox(lines=2, placeholder="Formulate your science goal...", label="Prompt"),
160
+ gr.Textbox(lines=5, placeholder="Enter some context here...", label="Context"),
161
  gr.Checkbox(label="Use NASA SMD Bi-Encoder for Context"),
162
  gr.Slider(50, 1000, value=150, step=10, label="Max Tokens"),
163
  gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="Temperature"),
 
165
  gr.Slider(0.0, 1.0, value=0.5, step=0.1, label="Frequency Penalty"),
166
  gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Presence Penalty")
167
  ],
168
+ outputs=[
169
+ gr.Textbox(label="Model Response..."),
170
+ gr.HTML(label="Miro"),
171
+ gr.HTML(label="Generate Mind Map on Mapify"),
172
+ gr.File(label="Download SCDD", type="filepath"),
173
+ ],
174
+ title="SCDDBot - NASA SMD SCDD AI Assistant [version-0.2a]",
175
+ description="SCDDBot is an AI-powered assistant for generating and visualising HWO Science Cases",
176
  )
177
 
 
178
  iface.launch(share=True)