Kalen Michael glenn-jocher commited on
Commit
43b2817
1 Parent(s): 0dc725e

Feature/fix export on url (#4823)

Browse files

* added callbacks

* added back callback to main

* added save_dir to callback output

* merged in upstream

* removed ghost code

* added url check

* Add url2file()

* Update file-only

Co-authored-by: Glenn Jocher <[email protected]>

Files changed (2) hide show
  1. export.py +2 -2
  2. utils/general.py +7 -0
export.py CHANGED
@@ -41,7 +41,7 @@ from models.experimental import attempt_load
41
  from models.yolo import Detect
42
  from utils.activations import SiLU
43
  from utils.datasets import LoadImages
44
- from utils.general import colorstr, check_dataset, check_img_size, check_requirements, file_size, set_logging
45
  from utils.torch_utils import select_device
46
 
47
 
@@ -244,7 +244,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
244
  include = [x.lower() for x in include]
245
  tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'tfjs')) # TensorFlow exports
246
  imgsz *= 2 if len(imgsz) == 1 else 1 # expand
247
- file = Path(weights)
248
 
249
  # Load PyTorch model
250
  device = select_device(device)
 
41
  from models.yolo import Detect
42
  from utils.activations import SiLU
43
  from utils.datasets import LoadImages
44
+ from utils.general import colorstr, check_dataset, check_img_size, check_requirements, file_size, set_logging, url2file
45
  from utils.torch_utils import select_device
46
 
47
 
 
244
  include = [x.lower() for x in include]
245
  tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'tfjs')) # TensorFlow exports
246
  imgsz *= 2 if len(imgsz) == 1 else 1 # expand
247
+ file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)
248
 
249
  # Load PyTorch model
250
  device = select_device(device)
utils/general.py CHANGED
@@ -360,6 +360,13 @@ def check_dataset(data, autodownload=True):
360
  return data # dictionary
361
 
362
 
 
 
 
 
 
 
 
363
  def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
364
  # Multi-threaded file download and unzip function, used in data.yaml for autodownload
365
  def download_one(url, dir):
 
360
  return data # dictionary
361
 
362
 
363
+ def url2file(url):
364
+ # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
365
+ url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
366
+ file = Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
367
+ return file
368
+
369
+
370
  def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
371
  # Multi-threaded file download and unzip function, used in data.yaml for autodownload
372
  def download_one(url, dir):