Christoph Holthaus committed on
Commit 3c5e66e
Parent: 878d5c0

dev - magic

Files changed (1): app.py (+5, -30)
app.py CHANGED
@@ -64,14 +64,6 @@ MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
-
-
-
-if torch.cuda.is_available():
-    model_id = "mistralai/Mistral-7B-Instruct-v0.1"
-    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-
 # we need to make sure we only run one thread or we probably run out of ram
 def generate(
     message: str,
@@ -87,34 +79,17 @@ def generate(
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
     conversation.append({"role": "user", "content": message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    input_ids = input_ids.to(model.device)
-
-    llm.generate('test')
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
+    # Use LLaMa to create chat completion
+    llm.create_chat_completion(conversation, stream=True)
+
+    # Initialize a TextIteratorStreamer
+    streamer = TextIteratorStreamer(llm, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
     outputs = []
     for text in streamer:
         outputs.append(text)
         yield "".join(outputs)
 
-
 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
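
Note on the added lines: as committed, the return value of `llm.create_chat_completion(conversation, stream=True)` is discarded, and the `Llama` object is passed to `TextIteratorStreamer`, which is a transformers helper that expects a tokenizer, so the output loop would not receive tokens as written. Below is a minimal sketch of how the llama-cpp-python streaming path could look instead; the model path, the `Llama(...)` loading parameters, and the simplified `generate()` signature are assumptions for illustration, not part of this commit.

```python
# Sketch only (not part of the commit): streaming with llama-cpp-python directly.
# Assumptions: llama-cpp-python is installed and a local GGUF model file exists at
# MODEL_PATH; the real app builds `llm` elsewhere and has a richer generate() signature.
from llama_cpp import Llama

MODEL_PATH = "./mistral-7b-instruct-v0.1.Q4_K_M.gguf"  # hypothetical path

# Rough stand-in for the removed transformers-based model loading.
llm = Llama(model_path=MODEL_PATH, n_ctx=4096)

def generate(message, chat_history=()):
    # Rebuild the conversation the same way the diffed generate() does.
    conversation = []
    for user, assistant in chat_history:
        conversation.extend(
            [{"role": "user", "content": user}, {"role": "assistant", "content": assistant}]
        )
    conversation.append({"role": "user", "content": message})

    # With stream=True, create_chat_completion returns an iterator of OpenAI-style
    # chunks; keep the iterator instead of discarding it and read each delta.
    stream = llm.create_chat_completion(messages=conversation, stream=True)

    outputs = []
    for chunk in stream:
        delta = chunk["choices"][0]["delta"]
        if "content" in delta:
            outputs.append(delta["content"])
            yield "".join(outputs)
```

Iterating the streamed chunks directly replaces the `TextIteratorStreamer`/`Thread` pattern used in the transformers version, so no extra generation thread is needed.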