Commit fd5d1af by khulnasoft
Parent(s): 927d153

Update awesome_chat.py

Files changed (1):
  1. awesome_chat.py +283 -152
awesome_chat.py CHANGED
@@ -1,6 +1,5 @@
 import base64
 import copy
-import datetime
 from io import BytesIO
 import io
 import os
@@ -19,50 +18,43 @@ from diffusers.utils import load_image
 from pydub import AudioSegment
 import threading
 from queue import Queue
+import flask
+from flask import request, jsonify
+import waitress
+from flask_cors import CORS, cross_origin
 from get_token_ids import get_token_ids_for_task_parsing, get_token_ids_for_choose_model, count_tokens, get_max_context_length
 from huggingface_hub.inference_api import InferenceApi
 from huggingface_hub.inference_api import ALL_TASKS
-from models_server import models, status
-from functools import partial
-from huggingface_hub import Repository
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--config", type=str, default="config.yaml.dev")
+parser.add_argument("--config", type=str, default="configs/config.default.yaml")
 parser.add_argument("--mode", type=str, default="cli")
 args = parser.parse_args()
 
 if __name__ != "__main__":
-    args.config = "config.gradio.yaml"
+    args.config = "configs/config.gradio.yaml"
+    args.mode = "gradio"
 
 config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader)
 
-if not os.path.exists("logs"):
-    os.mkdir("logs")
+os.makedirs("logs", exist_ok=True)
+os.makedirs("public/images", exist_ok=True)
+os.makedirs("public/audios", exist_ok=True)
+os.makedirs("public/videos", exist_ok=True)
 
-now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
-
-DATASET_REPO_URL = "https://huggingface.co/datasets/deepcode-ai/HuggingSpace_logs"
-LOG_HF_TOKEN = os.environ.get("LOG_HF_TOKEN")
-if LOG_HF_TOKEN:
-    repo = Repository(
-        local_dir="logs", clone_from=DATASET_REPO_URL, use_auth_token=LOG_HF_TOKEN
-    )
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-logger.handlers = []
-logger.propagate = False
+logger.setLevel(logging.DEBUG)
 
 handler = logging.StreamHandler()
 formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 handler.setFormatter(formatter)
-if config["debug"]:
-    handler.setLevel(logging.DEBUG)
+if not config["debug"]:
+    handler.setLevel(logging.CRITICAL)
 logger.addHandler(handler)
 
 log_file = config["log_file"]
 if log_file:
-    log_file = log_file.replace("TIMESTAMP", now)
     filehandler = logging.FileHandler(log_file)
     filehandler.setLevel(logging.DEBUG)
    filehandler.setFormatter(formatter)
@@ -73,7 +65,7 @@ use_completion = config["use_completion"]
 
 # consistent: wrong msra model name
 LLM_encoding = LLM
-if LLM == "gpt-3.5-turbo":
+if config["dev"] and LLM == "gpt-3.5-turbo":
     LLM_encoding = "text-davinci-003"
 task_parsing_highlight_ids = get_token_ids_for_task_parsing(LLM_encoding)
 choose_model_highlight_ids = get_token_ids_for_choose_model(LLM_encoding)
@@ -87,20 +79,35 @@ if use_completion:
 else:
     api_name = "chat/completions"
 
-if not config["dev"]:
-    if not config["openai"]["key"].startswith("sk-") and not config["openai"]["key"]=="gradio":
-        raise ValueError("Incrorrect OpenAI key. Please check your config.yaml file.")
-    OPENAI_KEY = config["openai"]["key"]
-    endpoint = f"https://api.openai.com/v1/{api_name}"
-    if OPENAI_KEY.startswith("sk-"):
-        HEADER = {
-            "Authorization": f"Bearer {OPENAI_KEY}"
-        }
-    else:
-        HEADER = None
+API_TYPE = None
+# priority: local > azure > openai
+if "dev" in config and config["dev"]:
+    API_TYPE = "local"
+elif "azure" in config:
+    API_TYPE = "azure"
+elif "openai" in config:
+    API_TYPE = "openai"
 else:
-    endpoint = f"{config['local']['endpoint']}/v1/{api_name}"
-    HEADER = None
+    logger.warning(f"No endpoint specified in {args.config}. The endpoint will be set dynamically according to the client.")
+
+if args.mode in ["test", "cli"]:
+    assert API_TYPE, "Only server mode supports dynamic endpoint."
+
+API_KEY = None
+API_ENDPOINT = None
+if API_TYPE == "local":
+    API_ENDPOINT = f"{config['local']['endpoint']}/v1/{api_name}"
+elif API_TYPE == "azure":
+    API_ENDPOINT = f"{config['azure']['base_url']}/openai/deployments/{config['azure']['deployment_name']}/{api_name}?api-version={config['azure']['api_version']}"
+    API_KEY = config["azure"]["api_key"]
+elif API_TYPE == "openai":
+    API_ENDPOINT = f"https://api.openai.com/v1/{api_name}"
+    if config["openai"]["api_key"].startswith("sk-"): # Check for valid OpenAI key in config file
+        API_KEY = config["openai"]["api_key"]
+    elif "OPENAI_API_KEY" in os.environ and os.getenv("OPENAI_API_KEY").startswith("sk-"): # Check for environment variable OPENAI_API_KEY
+        API_KEY = os.getenv("OPENAI_API_KEY")
+    else:
+        raise ValueError(f"Incorrect OpenAI key. Please check your {args.config} file.")
 
 PROXY = None
 if config["proxy"]:
@@ -110,6 +117,19 @@ if config["proxy"]:
 
 inference_mode = config["inference_mode"]
 
+# check the local_inference_endpoint
+Model_Server = None
+if inference_mode!="huggingface":
+    Model_Server = "http://" + config["local_inference_endpoint"]["host"] + ":" + str(config["local_inference_endpoint"]["port"])
+    message = f"The server of local inference endpoints is not running, please start it first. (or using `inference_mode: huggingface` in {args.config} for a feature-limited experience)"
+    try:
+        r = requests.get(Model_Server + "/running")
+        if r.status_code != 200:
+            raise ValueError(message)
+    except:
+        raise ValueError(message)
+
+
 parse_task_demos_or_presteps = open(config["demos_or_presteps"]["parse_task"], "r").read()
 choose_model_demos_or_presteps = open(config["demos_or_presteps"]["choose_model"], "r").read()
 response_results_demos_or_presteps = open(config["demos_or_presteps"]["response_results"], "r").read()
@@ -133,6 +153,18 @@ METADATAS = {}
 for model in MODELS:
     METADATAS[model["id"]] = model
 
+HUGGINGFACE_HEADERS = {}
+if config["huggingface"]["token"] and config["huggingface"]["token"].startswith("hf_"): # Check for valid huggingface token in config file
+    HUGGINGFACE_HEADERS = {
+        "Authorization": f"Bearer {config['huggingface']['token']}",
+    }
+elif "HUGGINGFACE_ACCESS_TOKEN" in os.environ and os.getenv("HUGGINGFACE_ACCESS_TOKEN").startswith("hf_"): # Check for environment variable HUGGINGFACE_ACCESS_TOKEN
+    HUGGINGFACE_HEADERS = {
+        "Authorization": f"Bearer {os.getenv('HUGGINGFACE_ACCESS_TOKEN')}",
+    }
+else:
+    raise ValueError(f"Incorrect HuggingFace token. Please check your {args.config} file.")
+
 def convert_chat_to_completion(data):
     messages = data.pop('messages', [])
     tprompt = ""
@@ -155,19 +187,26 @@ def convert_chat_to_completion(data):
     return data
 
 def send_request(data):
-    global HEADER
-    openaikey = data.pop("openaikey")
+    api_key = data.pop("api_key")
+    api_type = data.pop("api_type")
+    api_endpoint = data.pop("api_endpoint")
     if use_completion:
         data = convert_chat_to_completion(data)
-    if openaikey and openaikey.startswith("sk-"):
+    if api_type == "openai":
         HEADER = {
-            "Authorization": f"Bearer {openaikey}"
+            "Authorization": f"Bearer {api_key}"
         }
-
-    response = requests.post(endpoint, json=data, headers=HEADER, proxies=PROXY)
-    logger.debug(response.text.strip())
-    if "choices" not in response.json():
+    elif api_type == "azure":
+        HEADER = {
+            "api-key": api_key,
+            "Content-Type": "application/json"
+        }
+    else:
+        HEADER = None
+    response = requests.post(api_endpoint, json=data, headers=HEADER, proxies=PROXY)
+    if "error" in response.json():
         return response.json()
+    logger.debug(response.text.strip())
     if use_completion:
         return response.json()["choices"][0]["text"].strip()
     else:
@@ -177,7 +216,7 @@ def replace_slot(text, entries):
     for key, value in entries.items():
         if not isinstance(value, str):
             value = str(value)
-        text = text.replace("{{" + key +"}}", value.replace('"', "'").replace('\n', "").replace('\\', '\\\\'))
+        text = text.replace("{{" + key +"}}", value.replace('"', "'").replace('\n', ""))
     return text
 
 def find_json(s):
@@ -204,14 +243,13 @@ def get_id_reason(choose_str):
     return id.strip(), reason.strip(), choose
 
 def record_case(success, **args):
-    if not success:
-        return
-    f = open(f"logs/log_success_{now}.jsonl", "a")
+    if success:
+        f = open("logs/log_success.jsonl", "a")
+    else:
+        f = open("logs/log_fail.jsonl", "a")
     log = args
     f.write(json.dumps(log) + "\n")
     f.close()
-    if LOG_HF_TOKEN:
-        commit_url = repo.push_to_hub(blocking=False)
 
 def image_to_bytes(img_url):
     img_byte = io.BytesIO()
@@ -266,20 +304,19 @@ def unfold(tasks):
 
     return tasks
 
-def chitchat(messages, openaikey=None):
+def chitchat(messages, api_key, api_type, api_endpoint):
     data = {
         "model": LLM,
         "messages": messages,
-        "openaikey": openaikey
+        "api_key": api_key,
+        "api_type": api_type,
+        "api_endpoint": api_endpoint
     }
     return send_request(data)
 
-def parse_task(context, input, openaikey=None):
+def parse_task(context, input, api_key, api_type, api_endpoint):
     demos_or_presteps = parse_task_demos_or_presteps
     messages = json.loads(demos_or_presteps)
-    for message in messages:
-        if not isinstance(message["content"], str):
-            message["content"] = json.dumps(message["content"], ensure_ascii=False)
     messages.insert(0, {"role": "system", "content": parse_task_tprompt})
 
     # cut chat logs
@@ -304,11 +341,13 @@ def parse_task(context, input, openaikey=None):
         "messages": messages,
         "temperature": 0,
         "logit_bias": {item: config["logit_bias"]["parse_task"] for item in task_parsing_highlight_ids},
-        "openaikey": openaikey
+        "api_key": api_key,
+        "api_type": api_type,
+        "api_endpoint": api_endpoint
     }
     return send_request(data)
 
-def choose_model(input, task, metas, openaikey = None):
+def choose_model(input, task, metas, api_key, api_type, api_endpoint):
     prompt = replace_slot(choose_model_prompt, {
         "input": input,
         "task": task,
@@ -328,12 +367,14 @@ def choose_model(input, task, metas, openaikey = None):
         "messages": messages,
         "temperature": 0,
         "logit_bias": {item: config["logit_bias"]["choose_model"] for item in choose_model_highlight_ids}, # 5
-        "openaikey": openaikey
+        "api_key": api_key,
+        "api_type": api_type,
+        "api_endpoint": api_endpoint
     }
     return send_request(data)
 
 
-def response_results(input, results, openaikey=None):
+def response_results(input, results, api_key, api_type, api_endpoint):
     results = [v for k, v in sorted(results.items(), key=lambda item: item[0])]
     prompt = replace_slot(response_results_prompt, {
         "input": input,
@@ -342,7 +383,7 @@ def response_results(input, results, openaikey=None):
         "input": input,
         "processes": results
     })
-    messages = json.loads(demos_or_presteps, strict=False)
+    messages = json.loads(demos_or_presteps)
     messages.insert(0, {"role": "system", "content": response_results_tprompt})
     messages.append({"role": "user", "content": prompt})
     logger.debug(messages)
@@ -350,19 +391,15 @@ def response_results(input, results, openaikey=None):
         "model": LLM,
         "messages": messages,
         "temperature": 0,
-        "openaikey": openaikey
+        "api_key": api_key,
+        "api_type": api_type,
+        "api_endpoint": api_endpoint
     }
     return send_request(data)
 
-def huggingface_model_inference(model_id, data, task, huggingfacetoken=None):
-    if huggingfacetoken is None:
-        HUGGINGFACE_HEADERS = {}
-    else:
-        HUGGINGFACE_HEADERS = {
-            "Authorization": f"Bearer {huggingfacetoken}",
-        }
+def huggingface_model_inference(model_id, data, task):
     task_url = f"https://api-inference.huggingface.co/models/{model_id}" # InferenceApi does not yet support some tasks
-    inference = InferenceApi(repo_id=model_id, token=huggingfacetoken)
+    inference = InferenceApi(repo_id=model_id, token=config["huggingface"]["token"])
 
     # NLP tasks
     if task == "question-answering":
@@ -426,7 +463,7 @@ def huggingface_model_inference(model_id, data, task, huggingfacetoken=None):
         name = str(uuid.uuid4())[:4]
         image.save(f"public/images/{name}.jpg")
         result = {}
-        result["generated image with segmentation mask"] = f"/images/{name}.jpg"
+        result["generated image"] = f"/images/{name}.jpg"
         result["predicted"] = predicted
 
     if task == "object-detection":
@@ -447,7 +484,7 @@ def huggingface_model_inference(model_id, data, task, huggingfacetoken=None):
         name = str(uuid.uuid4())[:4]
         image.save(f"public/images/{name}.jpg")
         result = {}
-        result["generated image with predicted box"] = f"/images/{name}.jpg"
+        result["generated image"] = f"/images/{name}.jpg"
         result["predicted"] = predicted
 
     if task in ["image-classification"]:
@@ -459,7 +496,7 @@ def huggingface_model_inference(model_id, data, task, huggingfacetoken=None):
         img_url = data["image"]
         img_data = image_to_bytes(img_url)
         HUGGINGFACE_HEADERS["Content-Length"] = str(len(img_data))
-        r = requests.post(task_url, headers=HUGGINGFACE_HEADERS, data=img_data)
+        r = requests.post(task_url, headers=HUGGINGFACE_HEADERS, data=img_data, proxies=PROXY)
         result = {}
         if "generated_text" in r.json()[0]:
             result["generated text"] = r.json()[0].pop("generated_text")
@@ -493,62 +530,71 @@ def huggingface_model_inference(model_id, data, task, huggingfacetoken=None):
     return result
 
 def local_model_inference(model_id, data, task):
-    inference = partial(models, model_id)
+    task_url = f"{Model_Server}/models/{model_id}"
+
     # contronlet
     if model_id.startswith("lllyasviel/sd-controlnet-"):
         img_url = data["image"]
         text = data["text"]
-        results = inference({"img_url": img_url, "text": text})
+        response = requests.post(task_url, json={"img_url": img_url, "text": text})
+        results = response.json()
         if "path" in results:
             results["generated image"] = results.pop("path")
         return results
     if model_id.endswith("-control"):
         img_url = data["image"]
-        results = inference({"img_url": img_url})
+        response = requests.post(task_url, json={"img_url": img_url})
+        results = response.json()
         if "path" in results:
             results["generated image"] = results.pop("path")
         return results
 
     if task == "text-to-video":
-        results = inference(data)
+        response = requests.post(task_url, json=data)
+        results = response.json()
         if "path" in results:
             results["generated video"] = results.pop("path")
         return results
 
     # NLP tasks
     if task == "question-answering" or task == "sentence-similarity":
-        results = inference(json=data)
-        return results
+        response = requests.post(task_url, json=data)
+        return response.json()
     if task in ["text-classification", "token-classification", "text2text-generation", "summarization", "translation", "conversational", "text-generation"]:
-        results = inference(json=data)
-        return results
+        response = requests.post(task_url, json=data)
+        return response.json()
 
     # CV tasks
     if task == "depth-estimation":
         img_url = data["image"]
-        results = inference({"img_url": img_url})
+        response = requests.post(task_url, json={"img_url": img_url})
+        results = response.json()
         if "path" in results:
-            results["generated depth image"] = results.pop("path")
+            results["generated image"] = results.pop("path")
         return results
     if task == "image-segmentation":
         img_url = data["image"]
-        results = inference({"img_url": img_url})
-        results["generated image with segmentation mask"] = results.pop("path")
+        response = requests.post(task_url, json={"img_url": img_url})
+        results = response.json()
+        results["generated image"] = results.pop("path")
         return results
     if task == "image-to-image":
         img_url = data["image"]
-        results = inference({"img_url": img_url})
+        response = requests.post(task_url, json={"img_url": img_url})
+        results = response.json()
         if "path" in results:
             results["generated image"] = results.pop("path")
         return results
     if task == "text-to-image":
-        results = inference(data)
+        response = requests.post(task_url, json=data)
+        results = response.json()
         if "path" in results:
             results["generated image"] = results.pop("path")
         return results
     if task == "object-detection":
         img_url = data["image"]
-        predicted = inference({"img_url": img_url})
+        response = requests.post(task_url, json={"img_url": img_url})
+        predicted = response.json()
         if "error" in predicted:
             return predicted
         image = load_image(img_url)
@@ -565,7 +611,7 @@ def local_model_inference(model_id, data, task):
         name = str(uuid.uuid4())[:4]
         image.save(f"public/images/{name}.jpg")
         results = {}
-        results["generated image with predicted box"] = f"/images/{name}.jpg"
+        results["generated image"] = f"/images/{name}.jpg"
         results["predicted"] = predicted
         return results
     if task in ["image-classification", "image-to-text", "document-question-answering", "visual-question-answering"]:
@@ -573,43 +619,40 @@ def local_model_inference(model_id, data, task):
         text = None
         if "text" in data:
             text = data["text"]
-        results = inference({"img_url": img_url, "text": text})
+        response = requests.post(task_url, json={"img_url": img_url, "text": text})
+        results = response.json()
         return results
     # AUDIO tasks
     if task == "text-to-speech":
-        results = inference(data)
+        response = requests.post(task_url, json=data)
+        results = response.json()
         if "path" in results:
             results["generated audio"] = results.pop("path")
         return results
     if task in ["automatic-speech-recognition", "audio-to-audio", "audio-classification"]:
         audio_url = data["audio"]
-        results = inference({"audio_url": audio_url})
-        return results
+        response = requests.post(task_url, json={"audio_url": audio_url})
+        return response.json()
 
 
-def model_inference(model_id, data, hosted_on, task, huggingfacetoken=None):
-    if huggingfacetoken:
-        HUGGINGFACE_HEADERS = {
-            "Authorization": f"Bearer {huggingfacetoken}",
-        }
-    else:
-        HUGGINGFACE_HEADERS = None
+def model_inference(model_id, data, hosted_on, task):
     if hosted_on == "unknown":
-        r = status(model_id)
-        logger.debug("Local Server Status: " + str(r))
-        if "loaded" in r and r["loaded"]:
+        localStatusUrl = f"{Model_Server}/status/{model_id}"
+        r = requests.get(localStatusUrl)
+        logger.debug("Local Server Status: " + str(r.json()))
+        if r.status_code == 200 and "loaded" in r.json() and r.json()["loaded"]:
             hosted_on = "local"
         else:
             huggingfaceStatusUrl = f"https://api-inference.huggingface.co/status/{model_id}"
            r = requests.get(huggingfaceStatusUrl, headers=HUGGINGFACE_HEADERS, proxies=PROXY)
            logger.debug("Huggingface Status: " + str(r.json()))
-            if "loaded" in r and r["loaded"]:
+            if r.status_code == 200 and "loaded" in r.json() and r.json()["loaded"]:
                hosted_on = "huggingface"
     try:
         if hosted_on == "local":
             inference_result = local_model_inference(model_id, data, task)
         elif hosted_on == "huggingface":
-            inference_result = huggingface_model_inference(model_id, data, task, huggingfacetoken)
+            inference_result = huggingface_model_inference(model_id, data, task)
     except Exception as e:
         print(e)
         traceback.print_exc()
@@ -622,8 +665,8 @@ def get_model_status(model_id, url, headers, queue = None):
     if "huggingface" in url:
         r = requests.get(url, headers=headers, proxies=PROXY)
     else:
-        r = status(model_id)
-    if "loaded" in r and r["loaded"]:
+        r = requests.get(url)
+    if r.status_code == 200 and "loaded" in r.json() and r.json()["loaded"]:
         if queue:
             queue.put((model_id, True, endpoint_type))
         return True
@@ -632,13 +675,11 @@ def get_model_status(model_id, url, headers, queue = None):
         queue.put((model_id, False, None))
     return False
 
-def get_avaliable_models(candidates, topk=10, huggingfacetoken = None):
+def get_avaliable_models(candidates, topk=5):
     all_available_models = {"local": [], "huggingface": []}
     threads = []
     result_queue = Queue()
-    HUGGINGFACE_HEADERS = {
-        "Authorization": f"Bearer {huggingfacetoken}",
-    }
+
     for candidate in candidates:
         model_id = candidate["id"]
 
@@ -649,7 +690,8 @@ def get_avaliable_models(candidates, topk=10, huggingfacetoken = None):
             thread.start()
 
         if inference_mode != "huggingface" and config["local_deployment"] != "minimal":
-            thread = threading.Thread(target=get_model_status, args=(model_id, "", {}, result_queue))
+            localStatusUrl = f"{Model_Server}/status/{model_id}"
+            thread = threading.Thread(target=get_model_status, args=(model_id, localStatusUrl, {}, result_queue))
             threads.append(thread)
             thread.start()
 
@@ -675,7 +717,7 @@ def collect_result(command, choose, inference_result):
     return result
 
 
-def run_task(input, command, results, openaikey = None, huggingfacetoken = None):
+def run_task(input, command, results, api_key, api_type, api_endpoint):
     id = command["id"]
     args = command["args"]
     task = command["task"]
@@ -776,9 +818,9 @@ def run_task(input, command, results, openaikey = None, huggingfacetoken = None)
         choose = {"id": best_model_id, "reason": reason}
         messages = [{
             "role": "user",
-            "content": f"[ {input} ] contains a task in JSON format {command}, 'task' indicates the task type and 'args' indicates the arguments required for the task. Don't explain the task to me, just help me do it and give me the result. The result must be in text form without any urls."
+            "content": f"[ {input} ] contains a task in JSON format {command}. Now you are a {command['task']} system, the arguments are {command['args']}. Just help me do {command['task']} and give me the result. The result must be in text form without any urls."
         }]
-        response = chitchat(messages, openaikey)
+        response = chitchat(messages, api_key, api_type, api_endpoint)
         results[id] = collect_result(command, choose, {"response": response})
         return True
     else:
@@ -789,8 +831,8 @@ def run_task(input, command, results, openaikey = None, huggingfacetoken = None)
             results[id] = collect_result(command, "", inference_result)
             return False
 
-        candidates = MODELS_MAP[task][:20]
-        all_avaliable_models = get_avaliable_models(candidates, config["num_candidate_models"], huggingfacetoken)
+        candidates = MODELS_MAP[task][:10]
+        all_avaliable_models = get_avaliable_models(candidates, config["num_candidate_models"])
         all_avaliable_model_ids = all_avaliable_models["local"] + all_avaliable_models["huggingface"]
         logger.debug(f"avaliable models on {command['task']}: {all_avaliable_models}")
 
@@ -800,8 +842,7 @@ def run_task(input, command, results, openaikey = None, huggingfacetoken = None)
             inference_result = {"error": f"no available models on {command['task']} task."}
             results[id] = collect_result(command, "", inference_result)
             return False
-
-        all_avaliable_model_ids = all_avaliable_model_ids[:1]
+
         if len(all_avaliable_model_ids) == 1:
             best_model_id = all_avaliable_model_ids[0]
             hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
@@ -817,14 +858,14 @@ def run_task(input, command, results, openaikey = None, huggingfacetoken = None)
                 ),
                 "likes": model.get("likes"),
                 "description": model.get("description", "")[:config["max_description_length"]],
-                "language": model.get("language"),
-                "tags": model.get("tags"),
+                # "language": model.get("meta").get("language") if model.get("meta") else None,
+                "tags": model.get("meta").get("tags") if model.get("meta") else None,
             }
             for model in candidates
             if model["id"] in all_avaliable_model_ids
         ]
 
-        choose_str = choose_model(input, command, cand_models_info, openaikey)
+        choose_str = choose_model(input, command, cand_models_info, api_key, api_type, api_endpoint)
         logger.debug(f"chosen model: {choose_str}")
         try:
             choose = json.loads(choose_str)
@@ -836,7 +877,7 @@ def run_task(input, command, results, openaikey = None, huggingfacetoken = None)
             choose_str = find_json(choose_str)
             best_model_id, reason, choose = get_id_reason(choose_str)
             hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
-    inference_result = model_inference(best_model_id, args, hosted_on, command['task'], huggingfacetoken)
+    inference_result = model_inference(best_model_id, args, hosted_on, command['task'])
 
     if "error" in inference_result:
         logger.warning(f"Inference error: {inference_result['error']}")
@@ -847,42 +888,39 @@ def run_task(input, command, results, openaikey = None, huggingfacetoken = None)
     results[id] = collect_result(command, choose, inference_result)
     return True
 
-def chat_huggingface(messages, openaikey = None, huggingfacetoken = None, return_planning = False, return_results = False):
+def chat_huggingface(messages, api_key, api_type, api_endpoint, return_planning = False, return_results = False):
     start = time.time()
     context = messages[:-1]
     input = messages[-1]["content"]
     logger.info("*"*80)
     logger.info(f"input: {input}")
 
-    task_str = parse_task(context, input, openaikey)
-    logger.info(task_str)
+    task_str = parse_task(context, input, api_key, api_type, api_endpoint)
 
     if "error" in task_str:
-        return str(task_str), {}
-    else:
-        task_str = task_str.strip()
+        record_case(success=False, **{"input": input, "task": task_str, "reason": f"task parsing error: {task_str['error']['message']}", "op":"report message"})
+        return {"message": task_str["error"]["message"]}
+
+    task_str = task_str.strip()
+    logger.info(task_str)
 
     try:
         tasks = json.loads(task_str)
     except Exception as e:
         logger.debug(e)
-        response = chitchat(messages, openaikey)
+        response = chitchat(messages, api_key, api_type, api_endpoint)
         record_case(success=False, **{"input": input, "task": task_str, "reason": "task parsing fail", "op":"chitchat"})
-        return response, {}
-
+        return {"message": response}
+
     if task_str == "[]": # using LLM response for empty task
         record_case(success=False, **{"input": input, "task": [], "reason": "task parsing fail: empty", "op": "chitchat"})
-        response = chitchat(messages, openaikey)
-        return response, {}
+        response = chitchat(messages, api_key, api_type, api_endpoint)
+        return {"message": response}
 
-    if len(tasks)==1 and tasks[0]["task"] in ["summarization", "translation", "conversational", "text-generation", "text2text-generation"]:
-        record_case(success=True, **{"input": input, "task": tasks, "reason": "task parsing fail: empty", "op": "chitchat"})
-        response = chitchat(messages, openaikey)
-        best_model_id = "ChatGPT"
-        reason = "ChatGPT performs well on some NLP tasks as well."
-        choose = {"id": best_model_id, "reason": reason}
-        return response, collect_result(tasks[0], choose, {"response": response})
-
+    if len(tasks) == 1 and tasks[0]["task"] in ["summarization", "translation", "conversational", "text-generation", "text2text-generation"]:
+        record_case(success=True, **{"input": input, "task": tasks, "reason": "chitchat tasks", "op": "chitchat"})
+        response = chitchat(messages, api_key, api_type, api_endpoint)
+        return {"message": response}
 
     tasks = unfold(tasks)
     tasks = fix_dep(tasks)
@@ -897,24 +935,23 @@ def chat_huggingface(messages, openaikey = None, huggingfacetoken = None, return
     d = dict()
     retry = 0
     while True:
-        num_threads = len(threads)
+        num_thread = len(threads)
         for task in tasks:
-            dep = task["dep"]
             # logger.debug(f"d.keys(): {d.keys()}, dep: {dep}")
-            for dep_id in dep:
+            for dep_id in task["dep"]:
                 if dep_id >= task["id"]:
                     task["dep"] = [-1]
-                    dep = [-1]
                     break
-            if len(list(set(dep).intersection(d.keys()))) == len(dep) or dep[0] == -1:
+            dep = task["dep"]
+            if dep[0] == -1 or len(list(set(dep).intersection(d.keys()))) == len(dep):
                 tasks.remove(task)
-                thread = threading.Thread(target=run_task, args=(input, task, d, openaikey, huggingfacetoken))
+                thread = threading.Thread(target=run_task, args=(input, task, d, api_key, api_type, api_endpoint))
                 thread.start()
                 threads.append(thread)
-        if num_threads == len(threads):
+        if num_thread == len(threads):
             time.sleep(0.5)
             retry += 1
-        if retry > 80:
+        if retry > 160:
             logger.debug("User has waited too long, Loop break.")
             break
         if len(tasks) == 0:
@@ -928,7 +965,7 @@ def chat_huggingface(messages, openaikey = None, huggingfacetoken = None, return
     if return_results:
         return results
 
-    response = response_results(input, results, openaikey).strip()
+    response = response_results(input, results, api_key, api_type, api_endpoint).strip()
 
     end = time.time()
     during = end - start
@@ -936,4 +973,98 @@ def chat_huggingface(messages, openaikey = None, huggingfacetoken = None, return
     answer = {"message": response}
     record_case(success=True, **{"input": input, "task": task_str, "results": results, "response": response, "during": during, "op":"response"})
     logger.info(f"response: {response}")
-    return response, results
+    return answer
+
+def test():
+    # single round examples
+    inputs = [
+        "Given a collection of image A: /examples/a.jpg, B: /examples/b.jpg, C: /examples/c.jpg, please tell me how many zebras in these picture?"
+        "Can you give me a picture of a small bird flying in the sky with trees and clouds. Generate a high definition image if possible.",
+        "Please answer all the named entities in the sentence: Iron Man is a superhero appearing in American comic books published by Marvel Comics. The character was co-created by writer and editor Stan Lee, developed by scripter Larry Lieber, and designed by artists Don Heck and Jack Kirby.",
+        "please dub for me: 'Iron Man is a superhero appearing in American comic books published by Marvel Comics. The character was co-created by writer and editor Stan Lee, developed by scripter Larry Lieber, and designed by artists Don Heck and Jack Kirby.'"
+        "Given an image: https://huggingface.co/datasets/mishig/sample_images/resolve/main/palace.jpg, please answer the question: What is on top of the building?",
+        "Please generate a canny image based on /examples/f.jpg"
+    ]
+
+    for input in inputs:
+        messages = [{"role": "user", "content": input}]
+        chat_huggingface(messages, API_KEY, API_TYPE, API_ENDPOINT, return_planning = False, return_results = False)
+
+    # multi rounds example
+    messages = [
+        {"role": "user", "content": "Please generate a canny image based on /examples/f.jpg"},
+        {"role": "assistant", "content": """Sure. I understand your request. Based on the inference results of the models, I have generated a canny image for you. The workflow I used is as follows: First, I used the image-to-text model (nlpconnect/vit-gpt2-image-captioning) to convert the image /examples/f.jpg to text. The generated text is "a herd of giraffes and zebras grazing in a field". Second, I used the canny-control model (canny-control) to generate a canny image from the text. Unfortunately, the model failed to generate the canny image. Finally, I used the canny-text-to-image model (lllyasviel/sd-controlnet-canny) to generate a canny image from the text. The generated image is located at /images/f16d.png. I hope this answers your request. Is there anything else I can help you with?"""},
+        {"role": "user", "content": """then based on the above canny image and a prompt "a photo of a zoo", generate a new image."""},
+    ]
+    chat_huggingface(messages, API_KEY, API_TYPE, API_ENDPOINT, return_planning = False, return_results = False)
+
+def cli():
+    messages = []
+    print("Welcome to Jarvis! A collaborative system that consists of an LLM as the controller and numerous expert models as collaborative executors. Jarvis can plan tasks, schedule Hugging Face models, generate friendly responses based on your requests, and help you with many things. Please enter your request (`exit` to exit).")
+    while True:
+        message = input("[ User ]: ")
+        if message == "exit":
+            break
+        messages.append({"role": "user", "content": message})
+        answer = chat_huggingface(messages, API_KEY, API_TYPE, API_ENDPOINT, return_planning=False, return_results=False)
+        print("[ Jarvis ]: ", answer["message"])
+        messages.append({"role": "assistant", "content": answer["message"]})
+
+
+def server():
+    http_listen = config["http_listen"]
+    host = http_listen["host"]
+    port = http_listen["port"]
+
+    app = flask.Flask(__name__, static_folder="public", static_url_path="/")
+    app.config['DEBUG'] = False
+    CORS(app)
+
+    @cross_origin()
+    @app.route('/tasks', methods=['POST'])
+    def tasks():
+        data = request.get_json()
+        messages = data["messages"]
+        api_key = data.get("api_key", API_KEY)
+        api_endpoint = data.get("api_endpoint", API_ENDPOINT)
+        api_type = data.get("api_type", API_TYPE)
+        if api_key is None or api_type is None or api_endpoint is None:
+            return jsonify({"error": "Please provide api_key, api_type and api_endpoint"})
+        response = chat_huggingface(messages, api_key, api_type, api_endpoint, return_planning=True)
+        return jsonify(response)
+
+    @cross_origin()
+    @app.route('/results', methods=['POST'])
+    def results():
+        data = request.get_json()
+        messages = data["messages"]
+        api_key = data.get("api_key", API_KEY)
+        api_endpoint = data.get("api_endpoint", API_ENDPOINT)
+        api_type = data.get("api_type", API_TYPE)
+        if api_key is None or api_type is None or api_endpoint is None:
+            return jsonify({"error": "Please provide api_key, api_type and api_endpoint"})
+        response = chat_huggingface(messages, api_key, api_type, api_endpoint, return_results=True)
+        return jsonify(response)
+
+    @cross_origin()
+    @app.route('/hugginggpt', methods=['POST'])
+    def chat():
+        data = request.get_json()
+        messages = data["messages"]
+        api_key = data.get("api_key", API_KEY)
+        api_endpoint = data.get("api_endpoint", API_ENDPOINT)
+        api_type = data.get("api_type", API_TYPE)
+        if api_key is None or api_type is None or api_endpoint is None:
+            return jsonify({"error": "Please provide api_key, api_type and api_endpoint"})
+        response = chat_huggingface(messages, api_key, api_type, api_endpoint)
+        return jsonify(response)
+    print("server running...")
+    waitress.serve(app, host=host, port=port)
+
+if __name__ == "__main__":
+    if args.mode == "test":
+        test()
+    elif args.mode == "server":
+        server()
+    elif args.mode == "cli":
+        cli()
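
For reference, the three routes this commit adds in server mode (/tasks, /results, /hugginggpt) all accept the same JSON body: a messages array plus optional api_key, api_type, and api_endpoint fields that override the server-side defaults. A minimal client sketch follows; it assumes the server was started with `python awesome_chat.py --mode server` and that the `http_listen` section of the config points at localhost:8004 (both the host and the port here are illustrative assumptions, not values taken from this diff):

import requests

# Hypothetical address; read host/port from the `http_listen` section of your config.
SERVER = "http://localhost:8004"

payload = {
    "messages": [
        {"role": "user", "content": "Please generate a canny image based on /examples/f.jpg"}
    ],
    # Optional overrides; required if the server config defines no endpoint:
    # "api_key": "sk-...",
    # "api_type": "openai",
    # "api_endpoint": "https://api.openai.com/v1/chat/completions",
}

# /hugginggpt runs the full pipeline and returns {"message": ...};
# /tasks returns only the planning stage, /results the per-task inference results.
r = requests.post(f"{SERVER}/hugginggpt", json=payload)
print(r.json()["message"])

If the server config specifies no endpoint (the logger.warning branch above), each route falls back to the client-supplied fields and returns an error JSON when they are missing.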