glenn-jocher commited on
Commit
509dd51
1 Parent(s): dd03b20

check_git_status() improvements (#1916)

Browse files

* check_online()

* Update general.py

* update check_git_status()

* reverse rev-parse order

* fetch

* improved responsiveness

* comment

* comment

* remove hyp['giou'] compat warning

Files changed (2) hide show
  1. train.py +1 -6
  2. utils/general.py +27 -6
train.py CHANGED
@@ -6,7 +6,6 @@ import random
6
  import time
7
  from pathlib import Path
8
  from threading import Thread
9
- from warnings import warn
10
 
11
  import numpy as np
12
  import torch.distributed as dist
@@ -38,7 +37,7 @@ logger = logging.getLogger(__name__)
38
 
39
 
40
  def train(hyp, opt, device, tb_writer=None, wandb=None):
41
- logger.info(colorstr('Hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
42
  save_dir, epochs, batch_size, total_batch_size, weights, rank = \
43
  Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
44
 
@@ -502,10 +501,6 @@ if __name__ == '__main__':
502
  # Hyperparameters
503
  with open(opt.hyp) as f:
504
  hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
505
- if 'box' not in hyp:
506
- warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' %
507
- (opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120'))
508
- hyp['box'] = hyp.pop('giou')
509
 
510
  # Train
511
  logger.info(opt)
 
6
  import time
7
  from pathlib import Path
8
  from threading import Thread
 
9
 
10
  import numpy as np
11
  import torch.distributed as dist
 
37
 
38
 
39
  def train(hyp, opt, device, tb_writer=None, wandb=None):
40
+ logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
41
  save_dir, epochs, batch_size, total_batch_size, weights, rank = \
42
  Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
43
 
 
501
  # Hyperparameters
502
  with open(opt.hyp) as f:
503
  hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
 
 
 
 
504
 
505
  # Train
506
  logger.info(opt)
utils/general.py CHANGED
@@ -4,7 +4,6 @@ import glob
4
  import logging
5
  import math
6
  import os
7
- import platform
8
  import random
9
  import re
10
  import subprocess
@@ -35,6 +34,7 @@ def set_logging(rank=-1):
35
 
36
 
37
  def init_seeds(seed=0):
 
38
  random.seed(seed)
39
  np.random.seed(seed)
40
  init_torch_seeds(seed)
@@ -46,12 +46,33 @@ def get_latest_run(search_dir='.'):
46
  return max(last_list, key=os.path.getctime) if last_list else ''
47
 
48
 
 
 
 
 
 
 
 
 
 
 
49
  def check_git_status():
50
- # Suggest 'git pull' if repo is out of date
51
- if Path('.git').exists() and platform.system() in ['Linux', 'Darwin'] and not Path('/.dockerenv').is_file():
52
- s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
53
- if 'Your branch is behind' in s:
54
- print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')
 
 
 
 
 
 
 
 
 
 
 
55
 
56
 
57
  def check_requirements(file='requirements.txt'):
 
4
  import logging
5
  import math
6
  import os
 
7
  import random
8
  import re
9
  import subprocess
 
34
 
35
 
36
  def init_seeds(seed=0):
37
+ # Initialize random number generator (RNG) seeds
38
  random.seed(seed)
39
  np.random.seed(seed)
40
  init_torch_seeds(seed)
 
46
  return max(last_list, key=os.path.getctime) if last_list else ''
47
 
48
 
49
+ def check_online():
50
+ # Check internet connectivity
51
+ import socket
52
+ try:
53
+ socket.create_connection(("1.1.1.1", 53)) # check host accesability
54
+ return True
55
+ except OSError:
56
+ return False
57
+
58
+
59
  def check_git_status():
60
+ # Suggest 'git pull' if YOLOv5 is out of date
61
+ print(colorstr('github: '), end='')
62
+ try:
63
+ if Path('.git').exists() and check_online():
64
+ url = subprocess.check_output(
65
+ 'git fetch && git config --get remote.origin.url', shell=True).decode('utf-8')[:-1]
66
+ n = int(subprocess.check_output(
67
+ 'git rev-list $(git rev-parse --abbrev-ref HEAD)..origin/master --count', shell=True)) # commits behind
68
+ if n > 0:
69
+ s = f"⚠️ WARNING: code is out of date by {n} {'commits' if n > 1 else 'commmit'}. " \
70
+ f"Use 'git pull' to update or 'git clone {url}' to download latest."
71
+ else:
72
+ s = f'up to date with {url} ✅'
73
+ except Exception as e:
74
+ s = str(e)
75
+ print(s)
76
 
77
 
78
  def check_requirements(file='requirements.txt'):