glenn-jocher
commited on
Commit
•
509dd51
1
Parent(s):
dd03b20
check_git_status() improvements (#1916)
Browse files* check_online()
* Update general.py
* update check_git_status()
* reverse rev-parse order
* fetch
* improved responsiveness
* comment
* comment
* remove hyp['giou'] compat warning
- train.py +1 -6
- utils/general.py +27 -6
train.py
CHANGED
@@ -6,7 +6,6 @@ import random
|
|
6 |
import time
|
7 |
from pathlib import Path
|
8 |
from threading import Thread
|
9 |
-
from warnings import warn
|
10 |
|
11 |
import numpy as np
|
12 |
import torch.distributed as dist
|
@@ -38,7 +37,7 @@ logger = logging.getLogger(__name__)
|
|
38 |
|
39 |
|
40 |
def train(hyp, opt, device, tb_writer=None, wandb=None):
|
41 |
-
logger.info(colorstr('
|
42 |
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
|
43 |
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
|
44 |
|
@@ -502,10 +501,6 @@ if __name__ == '__main__':
|
|
502 |
# Hyperparameters
|
503 |
with open(opt.hyp) as f:
|
504 |
hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
|
505 |
-
if 'box' not in hyp:
|
506 |
-
warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' %
|
507 |
-
(opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120'))
|
508 |
-
hyp['box'] = hyp.pop('giou')
|
509 |
|
510 |
# Train
|
511 |
logger.info(opt)
|
|
|
6 |
import time
|
7 |
from pathlib import Path
|
8 |
from threading import Thread
|
|
|
9 |
|
10 |
import numpy as np
|
11 |
import torch.distributed as dist
|
|
|
37 |
|
38 |
|
39 |
def train(hyp, opt, device, tb_writer=None, wandb=None):
|
40 |
+
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
|
41 |
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
|
42 |
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
|
43 |
|
|
|
501 |
# Hyperparameters
|
502 |
with open(opt.hyp) as f:
|
503 |
hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
|
|
|
|
|
|
|
|
|
504 |
|
505 |
# Train
|
506 |
logger.info(opt)
|
utils/general.py
CHANGED
@@ -4,7 +4,6 @@ import glob
|
|
4 |
import logging
|
5 |
import math
|
6 |
import os
|
7 |
-
import platform
|
8 |
import random
|
9 |
import re
|
10 |
import subprocess
|
@@ -35,6 +34,7 @@ def set_logging(rank=-1):
|
|
35 |
|
36 |
|
37 |
def init_seeds(seed=0):
|
|
|
38 |
random.seed(seed)
|
39 |
np.random.seed(seed)
|
40 |
init_torch_seeds(seed)
|
@@ -46,12 +46,33 @@ def get_latest_run(search_dir='.'):
|
|
46 |
return max(last_list, key=os.path.getctime) if last_list else ''
|
47 |
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
def check_git_status():
|
50 |
-
# Suggest 'git pull' if
|
51 |
-
|
52 |
-
|
53 |
-
if '
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
|
57 |
def check_requirements(file='requirements.txt'):
|
|
|
4 |
import logging
|
5 |
import math
|
6 |
import os
|
|
|
7 |
import random
|
8 |
import re
|
9 |
import subprocess
|
|
|
34 |
|
35 |
|
36 |
def init_seeds(seed=0):
|
37 |
+
# Initialize random number generator (RNG) seeds
|
38 |
random.seed(seed)
|
39 |
np.random.seed(seed)
|
40 |
init_torch_seeds(seed)
|
|
|
46 |
return max(last_list, key=os.path.getctime) if last_list else ''
|
47 |
|
48 |
|
49 |
+
def check_online():
|
50 |
+
# Check internet connectivity
|
51 |
+
import socket
|
52 |
+
try:
|
53 |
+
socket.create_connection(("1.1.1.1", 53)) # check host accesability
|
54 |
+
return True
|
55 |
+
except OSError:
|
56 |
+
return False
|
57 |
+
|
58 |
+
|
59 |
def check_git_status():
|
60 |
+
# Suggest 'git pull' if YOLOv5 is out of date
|
61 |
+
print(colorstr('github: '), end='')
|
62 |
+
try:
|
63 |
+
if Path('.git').exists() and check_online():
|
64 |
+
url = subprocess.check_output(
|
65 |
+
'git fetch && git config --get remote.origin.url', shell=True).decode('utf-8')[:-1]
|
66 |
+
n = int(subprocess.check_output(
|
67 |
+
'git rev-list $(git rev-parse --abbrev-ref HEAD)..origin/master --count', shell=True)) # commits behind
|
68 |
+
if n > 0:
|
69 |
+
s = f"⚠️ WARNING: code is out of date by {n} {'commits' if n > 1 else 'commmit'}. " \
|
70 |
+
f"Use 'git pull' to update or 'git clone {url}' to download latest."
|
71 |
+
else:
|
72 |
+
s = f'up to date with {url} ✅'
|
73 |
+
except Exception as e:
|
74 |
+
s = str(e)
|
75 |
+
print(s)
|
76 |
|
77 |
|
78 |
def check_requirements(file='requirements.txt'):
|