'''
CONFIG AND IMPORTS
'''
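# A small Gradio Space that:
#   1. loads Markdown documents from safety_docs/ and splits them by heading,
#   2. embeds the chunks with OpenAI embeddings to match a user's argument
#      against a taxonomy of common arguments about AI safety, and
#   3. asks an OpenAI chat model for a response, served through a Gradio UI.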
from config import default_config
from types import SimpleNamespace
import gradio as gr
import os, random
from pathlib import Path
import tiktoken
from getpass import getpass
from openai import OpenAI
from langchain.text_splitter import MarkdownHeaderTextSplitter
import numpy as np
from langchain.embeddings import OpenAIEmbeddings
# from langchain.vectorstores import Chroma
from typing import Iterable
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
import time
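# `default_config` comes from a local config.py that is not shown here. Based on
# how it is used below, it is expected to expose at least system_prompt,
# model_name, user_prompt_1 and user_prompt_2. A hypothetical sketch only; the
# real values live in config.py:
#
# default_config = SimpleNamespace(
#     model_name="gpt-3.5-turbo",                    # assumed model name
#     system_prompt="You are a helpful assistant.",  # assumed system message
#     user_prompt_1="A user says: ",                 # assumed text placed before the argument
#     user_prompt_2=" How would you respond?",       # assumed text placed after the argument
# )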
if os.getenv("OPENAI_API_KEY") is None:
    if any(['VSCODE' in x for x in os.environ.keys()]):
        print('Please enter password in the VS Code prompt at the top of your VS Code window!')
    os.environ["OPENAI_API_KEY"] = getpass("Paste your OpenAI key from: https://platform.openai.com/account/api-keys\n")
assert os.getenv("OPENAI_API_KEY", "").startswith("sk-"), "This doesn't look like a valid OpenAI API key"
print("OpenAI API key configured")

# Create the client only after the key is known to be set, so an interactively
# entered key is actually picked up.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
embeddings_model = OpenAIEmbeddings()

# Read every Markdown file in safety_docs/ and concatenate the contents into a
# single document string.
md = ""
directory_path = "safety_docs"
for filename in os.listdir(directory_path):
    if filename.endswith(".md"):
        with open(os.path.join(directory_path, filename), 'r') as file:
            content = file.read()
            md = md + content
markdown_document = md
headers_to_split_on = [
    ("#", "Header 1"),
    ("##", "Header 2"),
    ("###", "Header 3"),
]
markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
md_header_splits = markdown_splitter.split_text(markdown_document)
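# md_header_splits is a list of Document chunks; each chunk's metadata records
# the headers it sits under (e.g. {"Header 1": ..., "Header 2": ...}).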
def find_nearest_neighbor(argument="", max_args_in_output=2):
    '''
    INPUT:
        argument (string)
        max_args_in_output (int): how many neighbors to include
    RETURN the nearest neighbor(s) to argument among the document chunks, as a string
    '''
    embeddings = embeddings_model
    # Embed every chunk (recomputed on each call) and the incoming argument,
    # then rank the chunks by cosine similarity.
    embedding_matrix = np.array([embeddings.embed_query(text.page_content) for text in md_header_splits])
    argument_embedding = embeddings.embed_query(argument)
    dot_products = np.dot(embedding_matrix, argument_embedding)
    norms = np.linalg.norm(embedding_matrix, axis=1) * np.linalg.norm(argument_embedding)
    cosine_similarities = dot_products / norms
    nearest_indices = np.argsort(cosine_similarities)[-max_args_in_output:][::-1]
    arr = [md_header_splits[index].metadata for index in nearest_indices]
    output = ""
    for thing in arr:
        # .get() in case a chunk has no top-level header in its metadata
        output = output + thing.get('Header 1', '') + "\n"
    return output
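# Hypothetical usage (the output depends on the documents in safety_docs/):
# find_nearest_neighbor("AGI is decades away", max_args_in_output=2)
# -> "<Header 1 of closest chunk>\n<Header 1 of second-closest chunk>\n"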
def get_gpt_response(user_prompt, system_prompt=default_config.system_prompt, model=default_config.model_name, n=1, max_tokens=200):
    '''
    INPUT:
        user_prompt
        system_prompt
        model
        n, max_tokens
    RETURN the model's completion as a string
    '''
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    response = client.chat.completions.create(model=model,
                                              messages=messages,
                                              n=n,
                                              max_tokens=max_tokens)
    # n defaults to 1; with n > 1 only the last choice is kept.
    for choice in response.choices:
        generation = choice.message.content
    return generation
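# Example: get_gpt_response("Why worry about AI safety?") sends the system prompt
# from config.py plus this user message and returns one completion of at most
# max_tokens tokens.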
# return the gpt generated response
def greet1(argument):
    user_prompt = default_config.user_prompt_1 + argument + default_config.user_prompt_2
    response = get_gpt_response(user_prompt=user_prompt)
    return response
# return the nearest neighbor arguments
def greet2(argument):
    nearest_neighbor = find_nearest_neighbor(argument)
    return "Your argument may fall under the common arguments against AI safety. \nIs it one of these? \n" + nearest_neighbor + "\nSee the taxonomy of arguments below"
# theme = gr.themes.Monochrome()
theme = gr.themes.Monochrome(
    # neutral_hue=gr.themes.colors.red,
    # n, boxes, text, nothing bottom text most text
    neutral_hue=gr.themes.Color("red", "#636363", "#636363", "lightgrey", "lightgrey", "lightgrey", "lightgrey", "grey", "red", "black", "red"),
    primary_hue=gr.themes.Color("#8c0010", "#8c0010", "#8c0010", "#8c0010", "#8c0010", "#8c0010", "#8c0010", "#8c0010", "#8c0010", "#8c0010", "#8c0010"),
    secondary_hue=gr.themes.Color("white", "white", "white", "white", "white", "white", "white", "white", "white", "white", "white"),
)
theme = theme.set(
    body_background_fill="black",
    block_title_background_fill="black",
    block_background_fill="black",
    body_text_color="white",
    link_text_color='*primary_50',
    link_text_color_dark='*primary_50',
    link_text_color_active='*primary_50',
    link_text_color_active_dark='*primary_50',
    link_text_color_hover='*primary_50',
    link_text_color_hover_dark='*primary_50',
    link_text_color_visited='*primary_50',
    link_text_color_visited_dark='*primary_50'
)
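# The gr.themes.Color() calls above supply all 11 shades of a Gradio color scale
# explicitly; theme.set() then overrides individual theme variables so the app
# renders white text and red links on a black background.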
css_string = """ | |
@import url('https://fonts.googleapis.com/css2?family=Gabarito&family=Gothic+A1:wght@100;200;300;400;500;600;700;800;900&display=swap'); | |
force_black_bg { | |
background-color: blue !important; | |
color: white !important; | |
font-family: 'Gabarito', cursive !important; | |
} | |
force_black_bg *{ | |
background-color: blue !important; | |
color: white !important; | |
font-family: 'Gabarito', cursive !important; | |
} | |
footer{ | |
display:none !important | |
} | |
""" | |
css_string2 = ""

# with gr.Blocks(theme=theme, css=css_string2) as demo:
#     with gr.Row(elem_id="force_black_bg"):
#         with gr.Column(elem_id="force_black_bg"):
#             seed = gr.Text(label="AI Safety Skepticism: What's Your Take?", placeholder="Enter an argument or something you'd like to say!")
#             btn = gr.Button("Generate >")
#             english = gr.Text(elem_id="themed_question_box", label="Common Argument Classifier")
#         with gr.Column():
#             german = gr.Text(label="Safetybot Response")
#     btn.click(greet2, inputs=[seed], outputs=english)
#     btn.click(greet1, inputs=[seed], outputs=german)
#     gr.Examples(["AGI is far away, I'm not worried", "AI is confined to a computer and cannot interact with the physical world", "AI isn't conscious", "If we don't develop AGI, China will!", "If we don't develop AGI, the Americans will!"], inputs=[seed])
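# The live demo below is a plain chatbot; respond() currently returns a canned
# placeholder message rather than calling greet1/greet2.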
with gr.Blocks(css=css_string) as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])

    def respond(message, chat_history):
        bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"])
        chat_history.append((message, bot_message))
        time.sleep(2)
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])

demo.queue()
demo.launch()