threadshare committed on
Commit ecf21af
1 Parent(s): a81f7bb

Revert to single core

Files changed (1): handler.py +3 -12
handler.py CHANGED

@@ -4,20 +4,13 @@ import os
 from threading import Thread
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
-from torch.nn.parallel import DistributedDataParallel as DDP
-import torch.distributed as dist
 
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 512
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
 
-def setup_distributed():
-    dist.init_process_group(backend='nccl')
-    torch.cuda.set_device(dist.get_rank())
-
 class EndpointHandler:
     def __init__(self, path=""):
-        setup_distributed()
         local_config_path = "./config.json"
         remote_model_name = "threadshare/Peach-9B-8k-Roleplay"
 
@@ -28,11 +21,9 @@ class EndpointHandler:
         self.model_name_or_path = remote_model_name
 
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True, flash_atten=True)
-        model = AutoModelForCausalLM.from_pretrained(
+        self.model = AutoModelForCausalLM.from_pretrained(
             self.model_name_or_path, torch_dtype=torch.bfloat16,
             trust_remote_code=True, device_map="auto")
-
-        self.model = DDP(model.to(dist.get_rank()), device_ids=[dist.get_rank()])
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         # print json data
@@ -75,7 +66,7 @@ class EndpointHandler:
         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
 
-        input_ids = input_ids.to(dist.get_rank())
+        input_ids = input_ids.to("cuda")
         streamer = TextIteratorStreamer(self.tokenizer, timeout=50.0, skip_prompt=True, skip_special_tokens=True)
         generate_kwargs = dict(
             input_ids=input_ids,
@@ -89,7 +80,7 @@
             no_repeat_ngram_size=8,
             repetition_penalty=repetition_penalty
         )
-        t = Thread(target=self.model.module.generate, kwargs=generate_kwargs)
+        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
         t.start()
         outputs = []
         for text in streamer:
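
One caveat with the reverted code path: input_ids.to("cuda") assumes the default CUDA device, while device_map="auto" can shard the model and place its first layers elsewhere, and the call fails outright on CPU-only hosts. A device-agnostic sketch (a suggestion, not part of this commit; the helper name is hypothetical) would route inputs to wherever the model's first parameters actually landed:

    import torch
    from torch import nn

    def move_to_model_device(model: nn.Module, input_ids: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper, not in the repo: with device_map="auto", the
        # model's first parameters sit on the device that must receive
        # input_ids. This also works on CPU-only hosts, where a hard-coded
        # .to("cuda") would raise.
        device = next(model.parameters()).device
        return input_ids.to(device)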
 
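A minimal smoke test for the reverted single-device handler might look like the sketch below. The payload keys ("inputs", "parameters") are assumptions, since most of __call__ is elided from this diff; match them to the handler's real schema before relying on this.

    from handler import EndpointHandler

    # Local smoke test (sketch). The payload shape below is assumed, not
    # taken from the repo: __call__ only promises Dict[str, Any] in and
    # List[Dict[str, Any]] out.
    handler = EndpointHandler(path=".")
    result = handler({
        "inputs": "Hello!",
        "parameters": {"max_new_tokens": 128},  # hypothetical override
    })
    print(result)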