Spaces:
Sleeping
Sleeping
dtype can be selectable
Browse files- __main__.py +4 -2
__main__.py
CHANGED
@@ -16,6 +16,7 @@ parser.add_argument("references", type=str, help="Path to the numpy array contai
|
|
16 |
parser.add_argument("--output", type=str, help="Path to the output file")
|
17 |
parser.add_argument("--batch_size", type=int, help="Batch size to use for the computation")
|
18 |
parser.add_argument("--num_process", type=int, help="Batch size to use for the computation", default=1)
|
|
|
19 |
parser.add_argument("--debug", action="store_true", help="Debug mode")
|
20 |
args = parser.parse_args()
|
21 |
|
@@ -23,8 +24,8 @@ if not args.predictions or not args.references:
|
|
23 |
raise ValueError("You must provide the path to the predictions and references numpy arrays")
|
24 |
|
25 |
|
26 |
-
predictions = np.load(args.predictions).astype(
|
27 |
-
references = np.load(args.references).astype(
|
28 |
|
29 |
if args.debug:
|
30 |
predictions = predictions[:1000]
|
@@ -45,6 +46,7 @@ results = metric.compute(
|
|
45 |
num_process=args.num_process,
|
46 |
return_each_features=True,
|
47 |
return_coverages=True,
|
|
|
48 |
)
|
49 |
logger.info(f"Time taken: {time.time() - s}")
|
50 |
|
|
|
16 |
parser.add_argument("--output", type=str, help="Path to the output file")
|
17 |
parser.add_argument("--batch_size", type=int, help="Batch size to use for the computation")
|
18 |
parser.add_argument("--num_process", type=int, help="Batch size to use for the computation", default=1)
|
19 |
+
parser.add_argument("--dtype", type=str, help="Data type to use for the computation", default="float32")
|
20 |
parser.add_argument("--debug", action="store_true", help="Debug mode")
|
21 |
args = parser.parse_args()
|
22 |
|
|
|
24 |
raise ValueError("You must provide the path to the predictions and references numpy arrays")
|
25 |
|
26 |
|
27 |
+
predictions = np.load(args.predictions).astype(args.dtype)
|
28 |
+
references = np.load(args.references).astype(args.dtype)
|
29 |
|
30 |
if args.debug:
|
31 |
predictions = predictions[:1000]
|
|
|
46 |
num_process=args.num_process,
|
47 |
return_each_features=True,
|
48 |
return_coverages=True,
|
49 |
+
dtype=args.dtype,
|
50 |
)
|
51 |
logger.info(f"Time taken: {time.time() - s}")
|
52 |
|