AppleSwing
committed on
Commit
•
dbe8db4
1
Parent(s):
6e99f9d
Apply GPU type verification in backend debug mode
Browse files- backend-cli.py +6 -1
backend-cli.py
CHANGED
@@ -448,7 +448,8 @@ def get_args():
|
|
448 |
parser.add_argument("--precision", type=str, default="float32,float16,8bit,4bit", help="Precision to debug")
|
449 |
parser.add_argument("--inference-framework", type=str, default="hf-chat", help="Inference framework to debug")
|
450 |
parser.add_argument("--limit", type=int, default=None, help="Limit for the number of samples")
|
451 |
-
parser.add_argument("--gpu-type", type=str, default="NVIDIA-A100-PCIe-80GB",
|
|
|
452 |
return parser.parse_args()
|
453 |
|
454 |
|
@@ -480,6 +481,10 @@ if __name__ == "__main__":
|
|
480 |
inference_framework=args.inference_framework, # Use inference framework from arguments
|
481 |
gpu_type=args.gpu_type
|
482 |
)
|
|
|
|
|
|
|
|
|
483 |
results = process_evaluation(task, eval_request, limit=args.limit)
|
484 |
except Exception as e:
|
485 |
print(f"debug running error: {e}")
|
|
|
448 |
parser.add_argument("--precision", type=str, default="float32,float16,8bit,4bit", help="Precision to debug")
|
449 |
parser.add_argument("--inference-framework", type=str, default="hf-chat", help="Inference framework to debug")
|
450 |
parser.add_argument("--limit", type=int, default=None, help="Limit for the number of samples")
|
451 |
+
parser.add_argument("--gpu-type", type=str, default="NVIDIA-A100-PCIe-80GB",
|
452 |
+
help="GPU type. NVIDIA-A100-PCIe-80GB; NVIDIA-RTX-A5000-24G; NVIDIA-H100-PCIe-80G")
|
453 |
return parser.parse_args()
|
454 |
|
455 |
|
|
|
481 |
inference_framework=args.inference_framework, # Use inference framework from arguments
|
482 |
gpu_type=args.gpu_type
|
483 |
)
|
484 |
+
curr_gpu_type = get_gpu_details()
|
485 |
+
if eval_request.gpu_type != curr_gpu_type:
|
486 |
+
print(f"GPU type mismatch: {eval_request.gpu_type} vs {curr_gpu_type}")
|
487 |
+
raise Exception("GPU type mismatch")
|
488 |
results = process_evaluation(task, eval_request, limit=args.limit)
|
489 |
except Exception as e:
|
490 |
print(f"debug running error: {e}")
|