bowdbeg commited on
Commit
b639f33
1 Parent(s): 8eea7aa

dtype can be selectable

Browse files
Files changed (1) hide show
  1. __main__.py +4 -2
__main__.py CHANGED
@@ -16,6 +16,7 @@ parser.add_argument("references", type=str, help="Path to the numpy array contai
16
  parser.add_argument("--output", type=str, help="Path to the output file")
17
  parser.add_argument("--batch_size", type=int, help="Batch size to use for the computation")
18
  parser.add_argument("--num_process", type=int, help="Batch size to use for the computation", default=1)
 
19
  parser.add_argument("--debug", action="store_true", help="Debug mode")
20
  args = parser.parse_args()
21
 
@@ -23,8 +24,8 @@ if not args.predictions or not args.references:
23
  raise ValueError("You must provide the path to the predictions and references numpy arrays")
24
 
25
 
26
- predictions = np.load(args.predictions).astype(np.float32)
27
- references = np.load(args.references).astype(np.float32)
28
 
29
  if args.debug:
30
  predictions = predictions[:1000]
@@ -45,6 +46,7 @@ results = metric.compute(
45
  num_process=args.num_process,
46
  return_each_features=True,
47
  return_coverages=True,
 
48
  )
49
  logger.info(f"Time taken: {time.time() - s}")
50
 
 
16
  parser.add_argument("--output", type=str, help="Path to the output file")
17
  parser.add_argument("--batch_size", type=int, help="Batch size to use for the computation")
18
  parser.add_argument("--num_process", type=int, help="Batch size to use for the computation", default=1)
19
+ parser.add_argument("--dtype", type=str, help="Data type to use for the computation", default="float32")
20
  parser.add_argument("--debug", action="store_true", help="Debug mode")
21
  args = parser.parse_args()
22
 
 
24
  raise ValueError("You must provide the path to the predictions and references numpy arrays")
25
 
26
 
27
+ predictions = np.load(args.predictions).astype(args.dtype)
28
+ references = np.load(args.references).astype(args.dtype)
29
 
30
  if args.debug:
31
  predictions = predictions[:1000]
 
46
  num_process=args.num_process,
47
  return_each_features=True,
48
  return_coverages=True,
49
+ dtype=args.dtype,
50
  )
51
  logger.info(f"Time taken: {time.time() - s}")
52