aquibmoin committed on
Commit
536372f
1 Parent(s): 9b233f9

Create app.py

Files changed (1)
  1. app.py +81 -0
app.py ADDED
@@ -0,0 +1,81 @@
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModel
+ from openai import OpenAI
+ import os
+ import numpy as np
+ from sklearn.metrics.pairwise import cosine_similarity
+
+ # Load the NASA-specific bi-encoder model and tokenizer
+ bi_encoder_model_name = "nasa-impact/nasa-smd-ibm-st-v2"
+ bi_tokenizer = AutoTokenizer.from_pretrained(bi_encoder_model_name)
+ bi_model = AutoModel.from_pretrained(bi_encoder_model_name)
+
+ # Set up the OpenAI client
+ api_key = os.getenv('OPENAI_API_KEY')
+ client = OpenAI(api_key=api_key)
+
+ # Define a system message to introduce Exos
+ system_message = "You are Exos, a helpful assistant specializing in Exoplanet research. Provide detailed and accurate responses related to Exoplanet research."
+
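+ # Embed a text string by mean-pooling the bi-encoder's token embeddings into a single vector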
+ def encode_text(text):
+     inputs = bi_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
+     outputs = bi_model(**inputs)
+     return outputs.last_hidden_state.mean(dim=1).detach().numpy().flatten()  # 1D embedding vector
+
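+ # Pick the context line whose embedding is most similar (by cosine similarity) to the user's question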
+ def retrieve_relevant_context(user_input, context_texts):
+     user_embedding = encode_text(user_input).reshape(1, -1)
+     context_embeddings = np.array([encode_text(text) for text in context_texts])
+     context_embeddings = context_embeddings.reshape(len(context_embeddings), -1)  # Flatten each embedding
+     similarities = cosine_similarity(user_embedding, context_embeddings).flatten()
+     most_relevant_idx = np.argmax(similarities)
+     return context_texts[most_relevant_idx]
+
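+ # Build the prompt (prepending retrieved context, if any) and query the OpenAI chat model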
+ def generate_response(user_input, relevant_context="", max_tokens=150, temperature=0.7, top_p=0.9, frequency_penalty=0.5, presence_penalty=0.0):
+     if relevant_context:
+         combined_input = f"Context: {relevant_context}\nQuestion: {user_input}\nAnswer:"
+     else:
+         combined_input = f"Question: {user_input}\nAnswer:"
+
+     response = client.chat.completions.create(
+         model="gpt-4",
+         messages=[
+             {"role": "system", "content": system_message},
+             {"role": "user", "content": combined_input}
+         ],
+         max_tokens=max_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         frequency_penalty=frequency_penalty,
+         presence_penalty=presence_penalty
+     )
+     return response.choices[0].message.content.strip()
+
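+ # Gradio handler: optionally retrieve the most relevant context line before generating a response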
+ def chatbot(user_input, context="", use_encoder=False, max_tokens=150, temperature=0.7, top_p=0.9, frequency_penalty=0.5, presence_penalty=0.0):
+     if use_encoder and context:
+         context_texts = context.split("\n")
+         relevant_context = retrieve_relevant_context(user_input, context_texts)
+     else:
+         relevant_context = ""
+     response = generate_response(user_input, relevant_context, max_tokens, temperature, top_p, frequency_penalty, presence_penalty)
+     return response
+
+ # Create the Gradio interface
+ iface = gr.Interface(
+     fn=chatbot,
+     inputs=[
+         gr.Textbox(lines=2, placeholder="Enter your message here...", label="Your Question"),
+         gr.Textbox(lines=5, placeholder="Enter context here, separated by new lines...", label="Context"),
+         gr.Checkbox(label="Use NASA SMD Bi-Encoder for Context"),
+         gr.Slider(50, 500, value=150, step=10, label="Max Tokens"),
+         gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="Temperature"),
+         gr.Slider(0.0, 1.0, value=0.9, step=0.1, label="Top-p"),
+         gr.Slider(0.0, 1.0, value=0.5, step=0.1, label="Frequency Penalty"),
+         gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Presence Penalty")
+     ],
+     outputs=gr.Textbox(label="Exos says..."),
+     title="Exos - Your Exoplanet Research Assistant",
+     description="Exos is a helpful assistant specializing in Exoplanet research. Provide context to get more refined and relevant responses.",
+ )
+
+ # Launch the interface
+ iface.launch(share=True)