Spaces:
Sleeping
Sleeping
import os | |
import openai | |
import json | |
import graphviz | |
import streamlit as st | |
class MindMap:
    """Build and render a concept mind map from chunks of text.

    Each text chunk is inserted into a prompt template and sent to the OpenAI
    completion endpoint. The model's reply is parsed as a JSON list of
    relations, and those relations are drawn as edges of a Graphviz digraph
    displayed with Streamlit.
    """

    # Prompt template read by get_connections() when no template is passed in.
    PROMPT_PATH = "./prompts/mindmap.prompt"
    # Literal token in the template that is replaced by each text chunk.
    PROMPT_PLACEHOLDER = "$prompt"

    def __init__(self):
        # Sets the API key on the openai module itself, so it applies to every
        # openai call made through that module, not just this instance.
        openai.api_key = os.getenv("OPENAI_API_KEY")

    @staticmethod
    def _load_template(path):
        """Return the text of the prompt template stored at *path*."""
        # The context manager closes the file even if read() raises.
        with open(path, encoding="utf-8") as template_file:
            return template_file.read()

    @classmethod
    def _build_prompt(cls, template, text_chunk):
        """Return *template* with every placeholder replaced by *text_chunk*."""
        return template.replace(cls.PROMPT_PLACEHOLDER, text_chunk)

    def _complete(self, prompt):
        """Send *prompt* to the completion endpoint and return the generated text."""
        # NOTE(review): openai.Completion / "text-davinci-003" belong to the
        # pre-1.0 openai client API; confirm the pinned openai version.
        response = openai.Completion.create(
            engine="text-davinci-003",
            prompt=prompt,
            temperature=0.5,
            max_tokens=2048,
            top_p=1,
            frequency_penalty=0.0,
            presence_penalty=0.0,
        )
        return response.choices[0].text

    @staticmethod
    def _parse_relations(completion_text):
        """Decode a model completion as a JSON list of relations.

        Raises:
            ValueError: If the text is not valid JSON (json.JSONDecodeError is a
                ValueError subclass) or does not decode to a list.
        """
        relations = json.loads(completion_text)
        if not isinstance(relations, list):
            # Extending with a dict/str would silently add keys/characters.
            raise ValueError(
                f"expected a JSON list of relations, got {type(relations).__name__}"
            )
        return relations

    def get_connections(self, text_chunks_libs: dict, *, prompt_template=None) -> list:
        """Ask the model for relations in every text chunk and collect them.

        Args:
            text_chunks_libs: Mapping whose values are iterables of text-chunk
                strings; the keys are not used.
            prompt_template: Template text containing ``$prompt``. When omitted,
                the template is read from ``PROMPT_PATH``.

        Returns:
            The relations from all chunks, concatenated in iteration order.

        Raises:
            OSError: If the prompt file cannot be read.
            ValueError: If a completion is not a JSON list of relations.
        """
        template = (
            self._load_template(self.PROMPT_PATH)
            if prompt_template is None
            else prompt_template
        )
        final_connections = []
        for text_chunks in text_chunks_libs.values():
            for text_chunk in text_chunks:
                # Fill a fresh copy of the template for each chunk. Overwriting
                # the template would consume the placeholder on the first chunk
                # and re-send that chunk's prompt for every later one.
                prompt = self._build_prompt(template, text_chunk)
                final_connections.extend(self._parse_relations(self._complete(prompt)))
        return final_connections

    def generate_graph(self, text_chunks_libs: dict):
        """Extract relations from *text_chunks_libs* and render them in Streamlit.

        Each relation adds one directed edge from its element 0 to its element 2.
        Presumably relations are [source, label, target] triples, and the label
        at index 1 is not drawn (TODO: confirm against prompts/mindmap.prompt).
        """
        graph = graphviz.Digraph()
        for connection in self.get_connections(text_chunks_libs):
            # Node names are coerced to str because the model may emit numbers
            # or other JSON scalars, and graphviz node names must be strings.
            graph.edge(str(connection[0]), str(connection[2]))
        st.graphviz_chart(graph)