Sourab Mangrulkar committed
Commit
6f4afc6
1 Parent(s): 12051f8
Files changed (2)
  1. agent.py +0 -60
  2. app.py +24 -15
agent.py DELETED
@@ -1,60 +0,0 @@
-import os
-from threading import Thread
-from typing import Iterator
-
-from transformers import AutoTokenizer, TextIteratorStreamer
-
-model_id = "meta-llama/Llama-2-7b-chat-hf"
-tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=os.environ["HUGGINGFACE_TOKEN"])
-
-
-def get_prompt(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> str:
-    texts = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
-    # The first user input is _not_ stripped
-    do_strip = False
-    for user_input, response in chat_history:
-        user_input = user_input.strip() if do_strip else user_input
-        do_strip = True
-        texts.append(f"{user_input} [/INST] {response.strip()} </s><s>[INST] ")
-    message = message.strip() if do_strip else message
-    texts.append(f"{message} [/INST]")
-    return "".join(texts)
-
-
-def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
-    prompt = get_prompt(message, chat_history, system_prompt)
-    input_ids = tokenizer([prompt], return_tensors="np", add_special_tokens=False)["input_ids"]
-    return input_ids.shape[-1]
-
-
-def run(
-    message: str,
-    chat_history: list[tuple[str, str]],
-    system_prompt: str,
-    max_new_tokens: int = 1024,
-    temperature: float = 0.8,
-    top_p: float = 0.95,
-    top_k: int = 50,
-) -> Iterator[str]:
-    prompt = get_prompt(message, chat_history, system_prompt)
-    inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).to("cuda")
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        inputs,
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        eos_token_id=tokenizer.eos_token_id,
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
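For reference while reading this diff, a small illustrative snippet (not part of the commit) of the Llama-2 chat prompt string that get_prompt assembles; the system prompt, history, and message below are made-up example values.

# Illustrative example values, not taken from the repository.
system_prompt = "You are a helpful assistant."
chat_history = [("Hi there", "Hello! How can I help?")]
message = "Summarize the docs."

print(get_prompt(message, chat_history, system_prompt))
# <s>[INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# Hi there [/INST] Hello! How can I help? </s><s>[INST] Summarize the docs. [/INST]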
app.py CHANGED
@@ -1,4 +1,3 @@
-import argparse
 import os
 import json
 import re
@@ -11,8 +10,7 @@ import pandas as pd
 import torch
 
 from easyllm.clients import huggingface
-
-from agent import get_input_token_length
+from transformers import AutoTokenizer
 
 huggingface.prompt_builder = "llama2"
 huggingface.api_key = os.environ["HUGGINGFACE_TOKEN"]
@@ -30,9 +28,12 @@ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 print("Running on device:", torch_device)
 print("CPU threads:", torch.get_num_threads())
 
+model_id = "meta-llama/Llama-2-70b-chat-hf"
 biencoder = SentenceTransformer("intfloat/e5-large-v2", device=torch_device)
 cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2", max_length=512, device=torch_device)
 
+tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=os.environ["HUGGINGFACE_TOKEN"])
+
 
 def create_qa_prompt(query, relevant_chunks):
     stuffed_context = " ".join(relevant_chunks)
@@ -60,11 +61,30 @@ Follow Up Input: {question}
 """
 
 
+def get_prompt(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> str:
+    texts = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
+    # The first user input is _not_ stripped
+    do_strip = False
+    for user_input, response in chat_history:
+        user_input = user_input.strip() if do_strip else user_input
+        do_strip = True
+        texts.append(f"{user_input} [/INST] {response.strip()} </s><s>[INST] ")
+    message = message.strip() if do_strip else message
+    texts.append(f"{message} [/INST]")
+    return "".join(texts)
+
+
+def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
+    prompt = get_prompt(message, chat_history, system_prompt)
+    input_ids = tokenizer([prompt], return_tensors="np", add_special_tokens=False)["input_ids"]
+    return input_ids.shape[-1]
+
+
 # https://www.philschmid.de/llama-2#how-to-prompt-llama-2-chat
 def get_completion(
     prompt,
     system_prompt=None,
-    model="meta-llama/Llama-2-70b-chat-hf",
+    model=model_id,
     max_new_tokens=1024,
     temperature=0.2,
     top_p=0.95,
@@ -429,14 +449,3 @@
     )
 
 demo.queue(max_size=20).launch(debug=True, share=True)
-
-
-# if __name__ == "__main__":
-#     parser = argparse.ArgumentParser(description="Script to create and use an HNSW index for similarity search.")
-#     parser.add_argument("--input_file", help="Input file containing text chunks in a Parquet format")
-#     parser.add_argument("--index_file", help="HNSW index file with .bin extension")
-#     args = parser.parse_args()
-
-#     data_df = pd.read_parquet(args.input_file).reset_index()
-#     search_index = load_hnsw_index(args.index_file)
-#     main()
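Since get_input_token_length now lives in app.py alongside the Llama-2-70b-chat tokenizer, a hypothetical guard like the one below could use it to keep prompts within a token budget before calling get_completion; the MAX_INPUT_TOKEN_LENGTH constant and trim_history helper are illustrative assumptions, not code from this commit.

# Hypothetical usage sketch, not part of the commit: drop the oldest chat turns
# until the assembled Llama-2 prompt fits an assumed token budget.
MAX_INPUT_TOKEN_LENGTH = 4000  # assumed limit; not defined anywhere in app.py

def trim_history(message, chat_history, system_prompt):
    history = list(chat_history)
    while history and get_input_token_length(message, history, system_prompt) > MAX_INPUT_TOKEN_LENGTH:
        history.pop(0)  # discard the oldest (user, assistant) pair first
    return history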