versae commited on
Commit
09379f0
1 Parent(s): ad4cf83

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +7 -2
gradio_app.py CHANGED
@@ -41,7 +41,12 @@ DEVICE = os.environ.get("DEVICE", "cpu") # cuda:0
41
  if DEVICE != "cpu" and not torch.cuda.is_available():
42
  DEVICE = "cpu"
43
  logger.info(f"DEVICE {DEVICE}")
44
- DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
 
 
 
 
 
45
  MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
46
  MODEL_REVISION = os.environ.get("MODEL_REVISION", "main")
47
  MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
@@ -147,7 +152,7 @@ class TextGeneration:
147
  self.model_name_or_path, revision=MODEL_REVISION,
148
  use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
149
  pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id,
150
- torch_dtype=DTYPE, low_cpu_mem_usage=False if DEVICE == "cpu" else True
151
  ).to(device=DEVICE, non_blocking=False)
152
  _ = self.model.eval()
153
  device_number = -1 if DEVICE == "cpu" else int(DEVICE.split(":")[-1])
 
41
  if DEVICE != "cpu" and not torch.cuda.is_available():
42
  DEVICE = "cpu"
43
  logger.info(f"DEVICE {DEVICE}")
44
+ DTYPE = getattr(
45
+ torch,
46
+ os.environ.get("DTYPE", ""),
47
+ torch.float32 if DEVICE == "cpu" else torch.float16
48
+ )
49
+ LOW_CPU_MEM = bool(os.environ.get("LOW_CPU_MEM", False if DEVICE == "cpu" else True))
50
  MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
51
  MODEL_REVISION = os.environ.get("MODEL_REVISION", "main")
52
  MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
 
152
  self.model_name_or_path, revision=MODEL_REVISION,
153
  use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
154
  pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id,
155
+ torch_dtype=DTYPE, low_cpu_mem_usage=LOW_CPU_MEM,
156
  ).to(device=DEVICE, non_blocking=False)
157
  _ = self.model.eval()
158
  device_number = -1 if DEVICE == "cpu" else int(DEVICE.split(":")[-1])