File size: 3,832 Bytes
35d97c8
 
 
552fb1e
35d97c8
 
 
 
552fb1e
 
35d97c8
 
9202468
 
 
35d97c8
552fb1e
 
 
 
 
 
 
 
 
 
35d97c8
 
 
 
 
 
 
 
 
552fb1e
 
 
 
35d97c8
 
 
552fb1e
 
 
 
 
 
 
35d97c8
 
 
 
 
 
 
 
4385b66
35d97c8
 
 
 
 
 
 
 
 
 
 
 
552fb1e
35d97c8
 
552fb1e
35d97c8
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import openai

# from huggingface_hub.inference_api import InferenceApi

class ChatService:
    """Stateful multi-turn chat client backed by OpenAI or a local Hugging Face model.

    Two back-ends are supported, selected by ``api``:

    * ``"openai"`` -- sends the accumulated ``self._messages`` list to
      ``openai.ChatCompletion.create`` (the legacy, pre-1.0 ``openai`` SDK API).
      The API key is read from the ``OPENAI_API_KEY`` environment variable.
    * ``"huggingface"`` -- loads ``model_id`` with ``transformers`` and samples a
      continuation of a plain-text transcript (``self._full_history``), e.g.
      ``ChatService(api="huggingface",
      model_id="OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5")``.

    Args:
        api: Back-end name, ``"openai"`` or ``"huggingface"``.
        model_id: OpenAI model name, or Hugging Face model repo id / local path.

    Raises:
        ValueError: If ``api`` is not a supported back-end.
    """

    # Preamble that starts the plain-text transcript for the Hugging Face back-end.
    _HF_SYSTEM_PROMPT = "Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the assistant is practical and really does its best, and doesn't let caution get too much in the way of being useful.\n-----\n"

    def __init__(self, api="openai", model_id="gpt-3.5-turbo"):
        self._api = api
        self._device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self._system_prompt = None
        self._user_name = None
        self._agent_name = None

        if self._api == "openai":
            openai.api_key = os.getenv("OPENAI_API_KEY")
            self._model_id = model_id
        elif self._api == "huggingface":
            self._system_prompt = self._HF_SYSTEM_PROMPT
            # Turn markers written into the transcript in front of each turn.
            self._user_name = "<|prompter|>"
            self._agent_name = "<|assistant|>"
            self._model_id = model_id
            self._tokenizer = AutoTokenizer.from_pretrained(model_id)
            # Half precision only on GPU: many PyTorch CPU kernels lack (or are
            # very slow at) float16, so fall back to float32 on CPU.
            dtype = torch.float16 if self._device.startswith("cuda") else torch.float32
            self._model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype)
            self._model.eval().to(self._device)
        else:
            raise ValueError(f"Unknown API: {self._api}")

        self.reset()

    def reset(self):
        """Clear the conversation, keeping only the system prompt (if any)."""
        self._user_history = []
        self._agent_history = []
        self._full_history = self._system_prompt or ""
        self._messages = []
        if self._system_prompt:
            self._messages.append({"role": "system", "content": self._system_prompt})

    def _chat(self, prompt):
        """Return the agent's reply for the current turn.

        Args:
            prompt: Plain-text transcript to continue. Only the Hugging Face
                back-end uses it; the OpenAI back-end sends ``self._messages``.

        Raises:
            ValueError: If the configured back-end is not implemented.
        """
        if self._api == "openai":
            response = openai.ChatCompletion.create(
                model=self._model_id,
                messages=self._messages,
            )
            return response["choices"][0]["message"]["content"]
        if self._api == "huggingface":
            return self._generate_huggingface(prompt)
        raise ValueError(f"API not implemented: {self._api}")

    def _generate_huggingface(self, prompt):
        """Sample a continuation of ``prompt`` with the local model; return only the new text."""
        encoded = self._tokenizer(prompt, return_tensors="pt")
        input_ids = encoded["input_ids"].to(self._device)
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self._device)
        outputs = self._model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=200,
            do_sample=True,
            top_k=40,
            # 1.0 leaves the logits unscaled; output is still random because
            # do_sample=True (set do_sample=False for deterministic greedy output).
            temperature=1.0,
            pad_token_id=self._tokenizer.eos_token_id,
        )
        # For decoder-only models generate() returns prompt + continuation;
        # decode only the continuation so the transcript is not echoed back.
        new_tokens = outputs[0][input_ids.shape[-1]:]
        text = self._tokenizer.decode(new_tokens, skip_special_tokens=False)
        return self._truncate_response(text, (self._user_name, self._tokenizer.eos_token))

    @staticmethod
    def _truncate_response(text, stop_markers):
        """Cut ``text`` at the earliest occurrence of any stop marker and strip whitespace.

        Sampling may run past the agent's turn (e.g. start a new user turn or emit
        the end-of-text token); everything from the first marker on is dropped.
        ``None`` or empty markers are ignored.
        """
        cut = len(text)
        for marker in stop_markers:
            if not marker:
                continue
            idx = text.find(marker)
            if idx != -1:
                cut = min(cut, idx)
        return text[:cut].strip()

    def chat(self, prompt):
        """Send one user turn and return the agent's reply.

        The conversation state (``_messages``, ``_full_history``,
        ``_user_history``, ``_agent_history``) is updated only once a reply has
        been obtained. If the back-end raises, the state is left as it was before
        the call and the exception propagates.
        """
        user_turn = f"{self._user_name}: {prompt}\n" if self._user_name else f"{prompt}\n"
        # Cue the text model to answer as the agent; not stored in the history,
        # which receives the full "<agent>: reply" line after generation.
        agent_cue = f"{self._agent_name}:" if self._agent_name else ""

        self._messages.append({"role": "user", "content": prompt})
        try:
            agent_response = self._chat(self._full_history + user_turn + agent_cue)
        except BaseException:
            # Roll back so a failed request does not leave a dangling user turn.
            self._messages.pop()
            raise

        self._full_history += user_turn
        self._user_history.append(prompt)
        self._messages.append({"role": "assistant", "content": agent_response})
        if self._agent_name:
            self._full_history += f"{self._agent_name}: {agent_response}\n"
        else:
            self._full_history += f"{agent_response}\n"
        self._agent_history.append(agent_response)
        return agent_response