Leonardo Parente committed
Commit 960c913
1 Parent(s): bc4906f
Files changed (2)
  1. app.py +50 -37
  2. orbgptlogo.png +0 -0
app.py CHANGED
@@ -1,16 +1,19 @@
+import base64
+from pathlib import Path
 import streamlit as st
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from langchain.memory import ConversationBufferMemory
 from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
-from langchain.chains import LLMChain
-from langchain.prompts import PromptTemplate
+from langchain.chains import ConversationalRetrievalChain
 from langchain.embeddings import VoyageEmbeddings
 from langchain.vectorstores import SupabaseVectorStore
 from langchain.llms.huggingface_pipeline import HuggingFacePipeline
 from st_supabase_connection import SupabaseConnection
 
 msgs = StreamlitChatMessageHistory()
-memory = ConversationBufferMemory(memory_key="history", chat_memory=msgs)
+memory = ConversationBufferMemory(
+    memory_key="history", chat_memory=msgs, return_messages=True
+)
 
 supabase_client = st.connection(
     name="orbgpt",
@@ -18,45 +21,51 @@ supabase_client = st.connection(
     ttl=None,
 )
 
-embeddings = VoyageEmbeddings(model="voyage-01")
-vector_store = SupabaseVectorStore(
-    embedding=embeddings,
-    client=supabase_client,
-    table_name="documents",
-    query_name="match_documents",
-)
 
+@st.cache_resource
+def load_retriever():
+    # load embeddings using VoyageAI and Supabase
+    embeddings = VoyageEmbeddings(model="voyage-01")
+    vector_store = SupabaseVectorStore(
+        embedding=embeddings,
+        client=supabase_client.client,
+        table_name="documents",
+        query_name="match_documents",
+    )
+    return vector_store.as_retriever()
 
-model_path = "01-ai/Yi-6B-Chat"
-tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
-model = AutoModelForCausalLM.from_pretrained(
-    model_path,
-    device_map="auto",
-    offload_folder="offload",
-    offload_state_dict=True,
-    torch_dtype="auto",
-).eval()
-pipe = pipeline(
-    "text-generation",
-    model=model,
-    tokenizer=tokenizer,
-    max_new_tokens=10,
-    use_fast=False,
-)
-hf = HuggingFacePipeline(pipeline=pipe)
 
-template = """Question: {question}
+@st.cache_resource
+def load_model():
+    model_path = "llmware/bling-falcon-1b-0.1"
+    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        device_map="auto",
+        offload_folder="offload",
+        offload_state_dict=True,
+        torch_dtype="auto",
+    ).eval()
+    pipe = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        use_fast=False,
+    )
+    return HuggingFacePipeline(pipeline=pipe)
 
-Answer: Let's think step by step."""
-prompt = PromptTemplate.from_template(template)
 
-chain = prompt | hf
+hf = load_model()
+retriever = load_retriever()
+chat = ConversationalRetrievalChain.from_llm(hf, retriever)
 
-question = "What is electroencephalography?"
-
-st.text(chain.invoke({"question": question}))
+st.markdown(
+    "<div style='display: flex;justify-content: center;'><img width='150' src='data:image/png;base64,{}' class='img-fluid'></div>".format(
+        base64.b64encode(Path("orbgptlogo.png").read_bytes()).decode()
+    ),
+    unsafe_allow_html=True,
+)
 
-st.title("🪩🤖")
 
 if len(msgs.messages) == 0:
     msgs.add_ai_message("Ask me anything about orb community projects!")
@@ -66,5 +75,9 @@ for msg in msgs.messages:
 
 if prompt := st.chat_input("Ask something"):
     st.chat_message("human").write(prompt)
-    # Run
-    st.chat_message("ai").write("hehe")
+    msgs.add_user_message(prompt)
+    with st.chat_message("ai"):
+        with st.spinner("Processing your question..."):
+            response = chat({"question": prompt, "chat_history": memory.buffer})
+            msgs.add_ai_message(response["answer"])
+            st.write(response["answer"])
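Two details of the rewrite are easy to miss. Wrapping load_model() and load_retriever() in @st.cache_resource makes Streamlit build the Hugging Face pipeline and the Supabase retriever once and reuse them across reruns, instead of reloading the model on every chat message. Also, the history is handed to the chain explicitly as memory.buffer rather than by attaching the ConversationBufferMemory to ConversationalRetrievalChain itself; since that memory is backed by the same StreamlitChatMessageHistory that renders the transcript, both views stay in sync.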
orbgptlogo.png ADDED
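The retriever in app.py only returns results if the documents table behind the match_documents query has already been populated; the commit itself contains no ingestion step. Below is a minimal one-off ingestion sketch, assuming the standard LangChain pgvector setup on Supabase and SUPABASE_URL, SUPABASE_KEY, and VOYAGE_API_KEY set in the environment; the sample texts are placeholders, not real project docs:

import os

from langchain.embeddings import VoyageEmbeddings
from langchain.vectorstores import SupabaseVectorStore
from supabase.client import create_client

# Connect with supabase-py directly; the app goes through
# st_supabase_connection instead, but both wrap the same client.
supabase = create_client(os.environ["SUPABASE_URL"], os.environ["SUPABASE_KEY"])

# Placeholder corpus: replace with the real orb community project docs.
texts = [
    "orb community project one: ...",
    "orb community project two: ...",
]

# Embeds each text with voyage-01 and inserts one row per text into the
# "documents" table that the app's match_documents query reads from.
SupabaseVectorStore.from_texts(
    texts,
    VoyageEmbeddings(model="voyage-01"),
    client=supabase,
    table_name="documents",
    query_name="match_documents",
)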