xiziwang committed
Commit 2e36228
Parent(s): a9c14c0
push files

This view is limited to 50 files because it contains too many changes. See raw diff.
- README.md +61 -0
- __pycache__/dataLoader_multiperson.cpython-37.pyc +0 -0
- __pycache__/loconet.cpython-37.pyc +0 -0
- __pycache__/loss_multi.cpython-37.pyc +0 -0
- __pycache__/talkNet_config_multi.cpython-37.pyc +0 -0
- builder.py +95 -0
- configs/multi.yaml +51 -0
- dataLoaderTalkSet.py +182 -0
- dataLoader_multiperson.py +402 -0
- dlhammer/.gitignore +3 -0
- dlhammer/LICENSE +201 -0
- dlhammer/README.md +2 -0
- dlhammer/dlhammer/.ipynb_checkpoints/argparser-checkpoint.py +110 -0
- dlhammer/dlhammer/.ipynb_checkpoints/bootstrap-checkpoint.py +33 -0
- dlhammer/dlhammer/__init__.py +1 -0
- dlhammer/dlhammer/argparser.py +109 -0
- dlhammer/dlhammer/bootstrap.py +33 -0
- dlhammer/dlhammer/logger.py +66 -0
- dlhammer/dlhammer/test/config.yml +32 -0
- dlhammer/dlhammer/test/test_args.py +20 -0
- dlhammer/dlhammer/test/test_logger.py +22 -0
- dlhammer/dlhammer/utils/__init__.py +0 -0
- dlhammer/dlhammer/utils/misc.py +125 -0
- dlhammer/dlhammer/utils/system.py +25 -0
- environment.yml +298 -0
- legacy/talkNet_multi_multicard.py +124 -0
- legacy/talkNet_multicard.py +146 -0
- legacy/talkNet_orig.py +102 -0
- legacy/trainTalkNet_multicard.py +171 -0
- legacy/train_multi.py +156 -0
- loconet.py +182 -0
- loss_multi.py +72 -0
- metrics/AverageMeter.py +18 -0
- metrics/__pycache__/.nfs000000035f4a8257000000eb +0 -0
- metrics/__pycache__/AverageMeter.cpython-36.pyc +0 -0
- metrics/__pycache__/AverageMeter.cpython-38.pyc +0 -0
- metrics/__pycache__/accuracy.cpython-36.pyc +0 -0
- metrics/__pycache__/accuracy.cpython-38.pyc +0 -0
- metrics/accuracy.py +20 -0
- model/.DS_Store +0 -0
- model/__init__.py +5 -0
- model/__pycache__/__init__.cpython-36.pyc +0 -0
- model/__pycache__/__init__.cpython-37.pyc +0 -0
- model/__pycache__/attentionLayer.cpython-37.pyc +0 -0
- model/__pycache__/convLayer.cpython-37.pyc +0 -0
- model/__pycache__/loconet_encoder.cpython-37.pyc +0 -0
- model/__pycache__/position_encoding.cpython-36.pyc +0 -0
- model/__pycache__/talkNetModel.cpython-37.pyc +0 -0
- model/__pycache__/transformer.cpython-36.pyc +0 -0
- model/__pycache__/utils.cpython-36.pyc +0 -0
README.md
ADDED
@@ -0,0 +1,61 @@

## LoCoNet: Long-Short Context Network for Active Speaker Detection

### Dependencies

Start by building the environment:
```
conda env create -f requirements.yml
conda activate loconet
```
Then export the dlhammer path, replacing **project_dir** with your code base location:

export PYTHONPATH=**project_dir**/dlhammer:$PYTHONPATH

### Data preparation

We follow TalkNet's data preparation script to download and prepare the AVA dataset.

```
python train.py --dataPathAVA AVADataPath --download
```

`AVADataPath` is the folder where you want to save the AVA dataset and its preprocessing outputs; the details can be found [here](https://github.com/TaoRuijie/TalkNet_ASD/blob/main/utils/tools.py#L34). Please read them carefully.

After the AVA dataset is downloaded, please change the `DATA.dataPathAVA` entry in the config file.

#### Training script
```
python -W ignore::UserWarning train.py --cfg configs/multi.yaml OUTPUT_DIR <output directory>
```

#### Pretrained model

Please download the LoCoNet weights trained on the AVA dataset [here](https://drive.google.com/file/d/1EX-V464jCD6S-wg68yGuAa-UcsMrw8mK/view?usp=sharing).

```
python -W ignore::UserWarning test_multicard.py --cfg configs/multi.yaml RESUME_PATH {model download path}
```

### Citation

Please cite the following if our paper or code is helpful to your research.
```
@article{wang2023loconet,
  title={LoCoNet: Long-Short Context Network for Active Speaker Detection},
  author={Wang, Xizi and Cheng, Feng and Bertasius, Gedas and Crandall, David},
  journal={arXiv preprint arXiv:2301.08237},
  year={2023}
}
```

### Acknowledgement

The code base of this project is adapted from [TalkNet](https://github.com/TaoRuijie/TalkNet-ASD), a very easy-to-use ASD pipeline.
__pycache__/dataLoader_multiperson.cpython-37.pyc
ADDED
Binary file (10.8 kB).
__pycache__/loconet.cpython-37.pyc
ADDED
Binary file (6.26 kB).
__pycache__/loss_multi.cpython-37.pyc
ADDED
Binary file (2.61 kB).
__pycache__/talkNet_config_multi.cpython-37.pyc
ADDED
Binary file (6.59 kB).
builder.py
ADDED
@@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
#================================================================
#   Don't go gently into that good night.
#
#   author: klaus
#   description:
#
#================================================================

import warnings

from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry

from mmaction.utils import import_module_error_func

MODELS = Registry('models', parent=MMCV_MODELS)
BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
RECOGNIZERS = MODELS
LOSSES = MODELS
LOCALIZERS = MODELS

try:
    from mmdet.models.builder import DETECTORS, build_detector
except (ImportError, ModuleNotFoundError):
    # Define an empty registry and building func, so that can import
    DETECTORS = MODELS

    @import_module_error_func('mmdet')
    def build_detector(cfg, train_cfg, test_cfg):
        pass


def build_backbone(cfg):
    """Build backbone."""
    return BACKBONES.build(cfg)


def build_head(cfg):
    """Build head."""
    return HEADS.build(cfg)


def build_recognizer(cfg, train_cfg=None, test_cfg=None):
    """Build recognizer."""
    if train_cfg is not None or test_cfg is not None:
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model. Details see this '
            'PR: https://github.com/open-mmlab/mmaction2/pull/629', UserWarning)
    assert cfg.get(
        'train_cfg'
    ) is None or train_cfg is None, 'train_cfg specified in both outer field and model field'  # noqa: E501
    assert cfg.get(
        'test_cfg'
    ) is None or test_cfg is None, 'test_cfg specified in both outer field and model field '  # noqa: E501
    return RECOGNIZERS.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))


def build_loss(cfg):
    """Build loss."""
    return LOSSES.build(cfg)


def build_localizer(cfg):
    """Build localizer."""
    return LOCALIZERS.build(cfg)


def build_model(cfg, train_cfg=None, test_cfg=None):
    """Build model."""
    args = cfg.copy()
    obj_type = args.pop('type')
    if obj_type in LOCALIZERS:
        return build_localizer(cfg)
    if obj_type in RECOGNIZERS:
        return build_recognizer(cfg, train_cfg, test_cfg)
    if obj_type in DETECTORS:
        if train_cfg is not None or test_cfg is not None:
            warnings.warn(
                'train_cfg and test_cfg is deprecated, '
                'please specify them in model. Details see this '
                'PR: https://github.com/open-mmlab/mmaction2/pull/629', UserWarning)
        return build_detector(cfg, train_cfg, test_cfg)
    model_in_mmdet = ['FastRCNN']
    if obj_type in model_in_mmdet:
        raise ImportError('Please install mmdet for spatial temporal detection tasks.')
    raise ValueError(f'{obj_type} is not registered in ' 'LOCALIZERS, RECOGNIZERS or DETECTORS')


def build_neck(cfg):
    """Build neck."""
    return NECKS.build(cfg)
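As a point of reference, here is a minimal sketch of how these registry helpers are typically used; `SimpleBackbone` and its arguments are made-up placeholders for illustration, not classes shipped in this repo or in mmaction2.

```python
# Hypothetical illustration of the registry pattern in builder.py.
# SimpleBackbone is a placeholder module, not part of this repository.
import torch.nn as nn

from builder import BACKBONES, build_backbone


@BACKBONES.register_module()
class SimpleBackbone(nn.Module):

    def __init__(self, in_channels=3, out_channels=64):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


# The 'type' key selects the registered class; the remaining keys are
# forwarded to its constructor.
backbone = build_backbone(dict(type='SimpleBackbone', in_channels=3, out_channels=64))
```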
configs/multi.yaml
ADDED
@@ -0,0 +1,51 @@
SEED: "20210617"
NUM_GPUS: 4
NUM_WORKERS: 6
LOG_NAME: 'config.txt'
OUTPUT_DIR: '/nfs/joltik/data/ssd/xiziwang/TalkNet_models/' # savePath
evalDataType: "val"
downloadAVA: False
evaluation: False
RESUME: False
RESUME_PATH: ""
RESUME_EPOCH: 0

DATA:
  dataPathAVA: '/nfs/jolteon/data/ssd/xiziwang/AVA_dataset/'

DATALOADER:
  nDataLoaderThread: 4


SOLVER:
  OPTIMIZER: "adam"
  BASE_LR: 5e-5
  SCHEDULER:
    NAME: "multistep"
    GAMMA: 0.95

MODEL:
  NUM_SPEAKERS: 3
  CLIP_LENGTH: 200
  AV: "speaker_temporal"
  AV_layers: 3
  ADJUST_ATTENTION: 0

TRAIN:
  BATCH_SIZE: 1
  MAX_EPOCH: 25
  AUDIO_AUG: 1
  TEST_INTERVAL: 1
  TRAINER_GPU: 4


VAL:
  BATCH_SIZE: 1

TEST:
  BATCH_SIZE: 1
  DATASET: 'seen'
  MODEL: 'unseen'
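For orientation, a minimal sketch of how this YAML is consumed through the dlhammer bootstrap included later in this commit; the exact override syntax for nested keys depends on dlhammer's `merge_opts`, which is not shown here, so treat that part as an assumption.

```python
# Minimal sketch, assuming dlhammer is on PYTHONPATH as the README describes.
# Invoked e.g. as: python train.py --cfg configs/multi.yaml OUTPUT_DIR ./exps
from dlhammer import bootstrap

cfg = bootstrap(print_cfg=True)        # parses --cfg plus trailing KEY VALUE overrides
print(cfg.MODEL.NUM_SPEAKERS)          # nested YAML sections become EasyDict attributes
print(cfg.TRAIN.BATCH_SIZE, cfg.SOLVER.BASE_LR)
```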
dataLoaderTalkSet.py
ADDED
@@ -0,0 +1,182 @@
import os, torch, numpy, cv2, imageio, random, python_speech_features
import matplotlib.pyplot as plt
from scipy.io import wavfile
from glob import glob
from torchvision.transforms import RandomCrop
from scipy import signal

def get_noise_list(musanPath, rirPath):
    augment_files = glob(os.path.join(musanPath, '*/*/*/*.wav'))
    noiselist = {}
    rir = numpy.load(rirPath)
    for file in augment_files:
        if not file.split('/')[-4] in noiselist:
            noiselist[file.split('/')[-4]] = []
        noiselist[file.split('/')[-4]].append(file)
    return rir, noiselist

def augment_wav(audio, aug_type, rir, noiselist):
    if aug_type == 'rir':
        rir_gains = numpy.random.uniform(-7, 3, 1)
        rir_filts = random.choice(rir)
        rir = numpy.multiply(rir_filts, pow(10, 0.1 * rir_gains))
        audio = signal.convolve(audio, rir, mode='full')[:len(audio)]
    else:
        noisecat = aug_type
        noisefile = random.choice(noiselist[noisecat].copy())
        snr = [random.uniform({'noise':[0,15],'music':[5,15]}[noisecat][0], {'noise':[0,15],'music':[5,15]}[noisecat][1])]
        _, noiseaudio = wavfile.read(noisefile)
        if len(noiseaudio) < len(audio):
            shortage = len(audio) - len(noiseaudio)
            noiseaudio = numpy.pad(noiseaudio, (0, shortage), 'wrap')
        else:
            noiseaudio = noiseaudio[:len(audio)]

        noise_db = 10 * numpy.log10(numpy.mean(abs(noiseaudio ** 2)) + 1e-4)
        clean_db = 10 * numpy.log10(numpy.mean(abs(audio ** 2)) + 1e-4)
        noise = numpy.sqrt(10 ** ((clean_db - noise_db - snr) / 10)) * noiseaudio
        audio = audio + noise
    return audio.astype(numpy.int16)

def load_audio(data, data_path, length, start, end, audio_aug, rirlist = None, noiselist = None):
    # Find the path of the audio data
    data_type = data[0]
    id_name = data[1][:8]
    file_name = data[1].split('/')[0] + '_' + data[1].split('/')[1] + '_' + data[1].split('/')[2] + \
        '_' + data[2].split('/')[0] + '_' + data[2].split('/')[1] + '_' + data[2].split('/')[2] + '.wav'
    audio_file_path = os.path.join(data_path, data_type, id_name, file_name)
    # Load audio, compute MFCC, cut it to the required length
    _, audio = wavfile.read(audio_file_path)

    if audio_aug == True:
        augtype = random.randint(0, 3)
        if augtype == 1: # rir
            audio = augment_wav(audio, 'rir', rirlist, noiselist)
        elif augtype == 2:
            audio = augment_wav(audio, 'noise', rirlist, noiselist)
        elif augtype == 3:
            audio = augment_wav(audio, 'music', rirlist, noiselist)
        else:
            audio = audio

    feature = python_speech_features.mfcc(audio, 16000, numcep = 13, winlen = 0.025, winstep = 0.010)
    length_audio = int(round(length * 100))
    if feature.shape[0] < length_audio:
        shortage = length_audio - feature.shape[0]
        feature = numpy.pad(feature, ((0, shortage), (0, 0)), 'wrap')
    feature = feature[int(round(start * 100)):int(round(end * 100)), :]
    return feature

def load_video(data, data_path, length, start, end, visual_aug):
    # Find the path of the visual data
    data_type = data[0]
    id_name = data[1][:8]
    file_name = data[1].split('/')[0] + '_' + data[1].split('/')[1] + '_' + data[1].split('/')[2] + \
        '_' + data[2].split('/')[0] + '_' + data[2].split('/')[1] + '_' + data[2].split('/')[2] + '.mp4'
    video_file_path = os.path.join(data_path, data_type, id_name, file_name)
    # Load visual frame-by-frame, cut it to the required length
    length_video = int(round((end - start) * 25))
    video = cv2.VideoCapture(video_file_path)
    faces = []
    augtype = 'orig'

    if visual_aug == True:
        new = int(112 * random.uniform(0.7, 1))
        x, y = numpy.random.randint(0, 112 - new), numpy.random.randint(0, 112 - new)
        M = cv2.getRotationMatrix2D((112/2, 112/2), random.uniform(-15, 15), 1)
        augtype = random.choice(['orig', 'flip', 'crop', 'rotate'])

    num_frame = 0
    while video.isOpened():
        ret, frames = video.read()
        if ret == True:
            num_frame += 1
            if num_frame >= int(round(start * 25)) and num_frame < int(round(end * 25)):
                face = cv2.cvtColor(frames, cv2.COLOR_BGR2GRAY)
                face = cv2.resize(face, (224, 224))
                face = face[int(112-(112/2)):int(112+(112/2)), int(112-(112/2)):int(112+(112/2))]
                if augtype == 'orig':
                    faces.append(face)
                elif augtype == 'flip':
                    faces.append(cv2.flip(face, 1))
                elif augtype == 'crop':
                    faces.append(cv2.resize(face[y:y+new, x:x+new], (112, 112)))
                elif augtype == 'rotate':
                    faces.append(cv2.warpAffine(face, M, (112, 112)))
        else:
            break
    video.release()
    faces = numpy.array(faces)
    if faces.shape[0] < length_video:
        shortage = length_video - faces.shape[0]
        faces = numpy.pad(faces, ((0, shortage), (0, 0), (0, 0)), 'wrap')
    # faces = numpy.array(faces)[int(round(start * 25)):int(round(end * 25)),:,:]
    return faces

def load_label(data, length, start, end):
    labels_all = []
    labels = []
    data_type = data[0]
    start_T, end_T, start_F, end_F = float(data[4]), float(data[5]), float(data[6]), float(data[7])
    for i in range(int(round(length * 100))):
        if data_type == 'TAudio':
            labels_all.append(1)
        elif data_type == 'FAudio' or data_type == 'FSilence':
            labels_all.append(0)
        else:
            if i >= int(round(start_T * 100)) and i <= int(round(end_T * 100)):
                labels_all.append(1)
            else:
                labels_all.append(0)
    for i in range(int(round(length * 25))):
        labels.append(int(round(sum(labels_all[i*4: (i+1)*4]) / 4)))
    return labels[round(start*25): round(end*25)]

class loader_TalkSet(object):
    def __init__(self, trial_file_name, data_path, audio_aug, visual_aug, musanPath, rirPath, **kwargs):
        self.data_path = data_path
        self.audio_aug = audio_aug
        self.visual_aug = visual_aug
        self.minibatch = []
        self.rir, self.noiselist = get_noise_list(musanPath, rirPath)
        mix_lst = open(trial_file_name).read().splitlines()
        mix_lst = list(filter(lambda x: float(x.split()[3]) >= 1, mix_lst)) # filter the video less than 1s
        # mix_lst = list(filter(lambda x: x.split()[0] == 'TSilence', mix_lst))
        sorted_mix_lst = sorted(mix_lst, key=lambda data: (float(data.split()[3]), int(data.split()[-1])), reverse=True)
        start = 0
        while True:
            length_total = float(sorted_mix_lst[start].split()[3])
            batch_size = int(250 / length_total)
            end = min(len(sorted_mix_lst), start + batch_size)
            self.minibatch.append(sorted_mix_lst[start:end])
            if end == len(sorted_mix_lst):
                break
            start = end
        # self.minibatch = self.minibatch[0:5]

    def __getitem__(self, index):
        batch_lst = self.minibatch[index]
        length_total = float(batch_lst[-1].split()[3])
        length_total = (int(round(length_total * 100)) - int(round(length_total * 100)) % 4) / 100
        audio_feature, video_feature, labels = [], [], []
        duration = random.choice([1, 2, 4, 6])
        # duration = 6
        length = min(length_total, duration)
        if length == duration:
            start = int(round(random.randint(0, round(length_total * 25) - round(length * 25)) * 0.04 * 100)) / 100
            end = int(round((start + length) * 100)) / 100
        else:
            start, end = 0, length

        for line in batch_lst:
            data = line.split()
            audio_feature.append(load_audio(data, self.data_path, length_total, start, end, audio_aug = self.audio_aug, rirlist = self.rir, noiselist = self.noiselist))
            video_feature.append(load_video(data, self.data_path, length_total, start, end, visual_aug = self.visual_aug))
            labels.append(load_label(data, length_total, start, end))

        return torch.FloatTensor(numpy.array(audio_feature)), \
               torch.FloatTensor(numpy.array(video_feature)), \
               torch.LongTensor(numpy.array(labels))

    def __len__(self):
        return len(self.minibatch)
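A minimal usage sketch for this loader follows; every path below is a placeholder for wherever TalkSet, MUSAN, and the RIR filter file live on your machine, not a location defined by this repo.

```python
# Minimal sketch; all paths are placeholders, not repo-defined locations.
import torch
from dataLoaderTalkSet import loader_TalkSet

dataset = loader_TalkSet(trial_file_name='/path/to/TalkSet_trials.txt',  # placeholder
                         data_path='/path/to/TalkSet',                   # placeholder
                         audio_aug=True, visual_aug=True,
                         musanPath='/path/to/musan',                     # placeholder
                         rirPath='/path/to/rir.npy')                     # placeholder

# __getitem__ already returns a pre-collated minibatch, so batch_size stays 1;
# the DataLoader only contributes shuffling and worker prefetch.
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)

for audio, video, labels in loader:
    audio, video, labels = audio[0], video[0], labels[0]  # drop the dummy batch dim
    break
```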
dataLoader_multiperson.py
ADDED
@@ -0,0 +1,402 @@
import os, torch, numpy, cv2, random, glob, python_speech_features, json, math
from scipy.io import wavfile
from torchvision.transforms import RandomCrop
from operator import itemgetter
from torchvggish import vggish_input, vggish_params, mel_features


def overlap(audio, noiseAudio):
    snr = [random.uniform(-5, 5)]
    if len(noiseAudio) < len(audio):
        shortage = len(audio) - len(noiseAudio)
        noiseAudio = numpy.pad(noiseAudio, (0, shortage), 'wrap')
    else:
        noiseAudio = noiseAudio[:len(audio)]
    noiseDB = 10 * numpy.log10(numpy.mean(abs(noiseAudio**2)) + 1e-4)
    cleanDB = 10 * numpy.log10(numpy.mean(abs(audio**2)) + 1e-4)
    noiseAudio = numpy.sqrt(10**((cleanDB - noiseDB - snr) / 10)) * noiseAudio
    audio = audio + noiseAudio
    return audio.astype(numpy.int16)


def load_audio(data, dataPath, numFrames, audioAug, audioSet=None):
    dataName = data[0]
    fps = float(data[2])
    audio = audioSet[dataName]
    if audioAug == True:
        augType = random.randint(0, 1)
        if augType == 1:
            audio = overlap(dataName, audio, audioSet)
        else:
            audio = audio
    # fps is not always 25, in order to align the visual, we modify the window and step in MFCC extraction process based on fps
    audio = python_speech_features.mfcc(audio,
                                        16000,
                                        numcep=13,
                                        winlen=0.025 * 25 / fps,
                                        winstep=0.010 * 25 / fps)
    maxAudio = int(numFrames * 4)
    if audio.shape[0] < maxAudio:
        shortage = maxAudio - audio.shape[0]
        audio = numpy.pad(audio, ((0, shortage), (0, 0)), 'wrap')
    audio = audio[:int(round(numFrames * 4)), :]
    return audio


def load_single_audio(audio, fps, numFrames, audioAug=False):
    audio = python_speech_features.mfcc(audio,
                                        16000,
                                        numcep=13,
                                        winlen=0.025 * 25 / fps,
                                        winstep=0.010 * 25 / fps)
    maxAudio = int(numFrames * 4)
    if audio.shape[0] < maxAudio:
        shortage = maxAudio - audio.shape[0]
        audio = numpy.pad(audio, ((0, shortage), (0, 0)), 'wrap')
    audio = audio[:int(round(numFrames * 4)), :]
    return audio


def load_visual(data, dataPath, numFrames, visualAug):
    dataName = data[0]
    videoName = data[0][:11]
    faceFolderPath = os.path.join(dataPath, videoName, dataName)
    faceFiles = glob.glob("%s/*.jpg" % faceFolderPath)
    sortedFaceFiles = sorted(faceFiles,
                             key=lambda data: (float(data.split('/')[-1][:-4])),
                             reverse=False)
    faces = []
    H = 112
    if visualAug == True:
        new = int(H * random.uniform(0.7, 1))
        x, y = numpy.random.randint(0, H - new), numpy.random.randint(0, H - new)
        M = cv2.getRotationMatrix2D((H / 2, H / 2), random.uniform(-15, 15), 1)
        augType = random.choice(['orig', 'flip', 'crop', 'rotate'])
    else:
        augType = 'orig'
    for faceFile in sortedFaceFiles[:numFrames]:
        face = cv2.imread(faceFile)

        face = cv2.cvtColor(face, cv2.COLOR_BGR2GRAY)
        face = cv2.resize(face, (H, H))
        if augType == 'orig':
            faces.append(face)
        elif augType == 'flip':
            faces.append(cv2.flip(face, 1))
        elif augType == 'crop':
            faces.append(cv2.resize(face[y:y + new, x:x + new], (H, H)))
        elif augType == 'rotate':
            faces.append(cv2.warpAffine(face, M, (H, H)))
    faces = numpy.array(faces)
    return faces


def load_label(data, numFrames):
    res = []
    labels = data[3].replace('[', '').replace(']', '')
    labels = labels.split(',')
    for label in labels:
        res.append(int(label))
    res = numpy.array(res[:numFrames])
    return res


class train_loader(object):

    def __init__(self, cfg, trialFileName, audioPath, visualPath, num_speakers):
        self.cfg = cfg
        self.audioPath = audioPath
        self.visualPath = visualPath
        self.candidate_speakers = num_speakers
        self.path = os.path.join(cfg.DATA.dataPathAVA, "csv")
        self.entity_data = json.load(open(os.path.join(self.path, 'train_entity.json')))
        self.ts_to_entity = json.load(open(os.path.join(self.path, 'train_ts.json')))
        self.mixLst = open(trialFileName).read().splitlines()
        self.list_length = len(self.mixLst)
        random.shuffle(self.mixLst)

    def load_single_audio(self, audio, fps, numFrames, audioAug=False, aug_audio=None):
        if audioAug:
            augType = random.randint(0, 1)
            if augType == 1:
                audio = overlap(audio, aug_audio)
            else:
                audio = audio

        res = vggish_input.waveform_to_examples(audio, 16000, numFrames, fps, return_tensor=False)
        return res

    def load_visual_label_mask(self, videoName, entityName, target_ts, context_ts, visualAug=True):

        faceFolderPath = os.path.join(self.visualPath, videoName, entityName)

        faces = []
        H = 112
        if visualAug == True:
            new = int(H * random.uniform(0.7, 1))
            x, y = numpy.random.randint(0, H - new), numpy.random.randint(0, H - new)
            M = cv2.getRotationMatrix2D((H / 2, H / 2), random.uniform(-15, 15), 1)
            augType = random.choice(['orig', 'flip', 'crop', 'rotate'])
        else:
            augType = 'orig'
        labels_dict = self.entity_data[videoName][entityName]
        labels = numpy.zeros(len(target_ts))
        mask = numpy.zeros(len(target_ts))

        for i, time in enumerate(target_ts):
            if time not in context_ts:
                faces.append(numpy.zeros((H, H)))
            else:
                labels[i] = labels_dict[time]
                mask[i] = 1
                time = "%.2f" % float(time)
                faceFile = os.path.join(faceFolderPath, str(time) + '.jpg')

                face = cv2.imread(faceFile)

                face = cv2.cvtColor(face, cv2.COLOR_BGR2GRAY)
                face = cv2.resize(face, (H, H))
                if augType == 'orig':
                    faces.append(face)
                elif augType == 'flip':
                    faces.append(cv2.flip(face, 1))
                elif augType == 'crop':
                    faces.append(cv2.resize(face[y:y + new, x:x + new], (H, H)))
                elif augType == 'rotate':
                    faces.append(cv2.warpAffine(face, M, (H, H)))
        faces = numpy.array(faces)
        return faces, labels, mask

    def get_speaker_context(self, videoName, target_entity, all_ts, center_ts):

        context_speakers = list(self.ts_to_entity[videoName][center_ts])
        context = {}
        chosen_speakers = []
        context[target_entity] = all_ts
        context_speakers.remove(target_entity)
        num_frames = len(all_ts)
        for candidate in context_speakers:
            candidate_ts = self.entity_data[videoName][candidate]
            shared_ts = set(all_ts).intersection(set(candidate_ts))
            if (len(shared_ts) > (num_frames / 2)):
                context[candidate] = shared_ts
                chosen_speakers.append(candidate)
        context_speakers = chosen_speakers
        random.shuffle(context_speakers)
        if not context_speakers:
            context_speakers.insert(0, target_entity)    # make sure is at 0
            while len(context_speakers) < self.candidate_speakers:
                context_speakers.append(random.choice(context_speakers))
        elif len(context_speakers) < self.candidate_speakers:
            context_speakers.insert(0, target_entity)    # make sure is at 0
            while len(context_speakers) < self.candidate_speakers:
                context_speakers.append(random.choice(context_speakers[1:]))
        else:
            context_speakers.insert(0, target_entity)    # make sure is at 0
            context_speakers = context_speakers[:self.candidate_speakers]

        assert set(context_speakers).issubset(set(list(context.keys()))), target_entity
        assert target_entity in context_speakers, target_entity

        return context_speakers, context

    def __getitem__(self, index):

        target_video = self.mixLst[index]
        data = target_video.split('\t')
        fps = float(data[2])
        videoName = data[0][:11]
        target_entity = data[0]
        all_ts = list(self.entity_data[videoName][target_entity].keys())
        numFrames = int(data[1])
        assert numFrames == len(all_ts)

        center_ts = all_ts[math.floor(numFrames / 2)]

        # get context speakers which have more than half time overlapped with target speaker
        context_speakers, context = self.get_speaker_context(videoName, target_entity, all_ts,
                                                             center_ts)

        if self.cfg.TRAIN.AUDIO_AUG:
            other_indices = list(range(0, index)) + list(range(index + 1, self.list_length))
            augment_entity = self.mixLst[random.choice(other_indices)]
            augment_data = augment_entity.split('\t')
            augment_entity = augment_data[0]
            augment_videoname = augment_data[0][:11]
            aug_sr, aug_audio = wavfile.read(
                os.path.join(self.audioPath, augment_videoname, augment_entity + '.wav'))
        else:
            aug_audio = None

        audio_path = os.path.join(self.audioPath, videoName, target_entity + '.wav')
        sr, audio = wavfile.read(os.path.join(self.audioPath, videoName, target_entity + '.wav'))
        audio = self.load_single_audio(audio,
                                       fps,
                                       numFrames,
                                       audioAug=self.cfg.TRAIN.AUDIO_AUG,
                                       aug_audio=aug_audio)

        visualFeatures, labels, masks = [], [], []

        # target_label = list(self.entity_data[videoName][target_entity].values())
        visual, target_labels, target_masks = self.load_visual_label_mask(
            videoName, target_entity, all_ts, all_ts)

        for idx, context_entity in enumerate(context_speakers):
            if context_entity == target_entity:
                label = target_labels
                visualfeat = visual
                mask = target_masks
            else:
                visualfeat, label, mask = self.load_visual_label_mask(videoName, context_entity,
                                                                      all_ts,
                                                                      context[context_entity])
            visualFeatures.append(visualfeat)
            labels.append(label)
            masks.append(mask)

        audio = torch.FloatTensor(audio)[None, :, :]
        visualFeatures = torch.FloatTensor(numpy.array(visualFeatures))
        audio_t = audio.shape[1]
        video_t = visualFeatures.shape[1]
        if audio_t != video_t * 4:
            print(visualFeatures.shape, audio.shape, videoName, target_entity, numFrames)
        labels = torch.LongTensor(numpy.array(labels))
        masks = torch.LongTensor(numpy.array(masks))
        print(audio.shape)
        return audio, visualFeatures, labels, masks

    def __len__(self):
        return len(self.mixLst)


class val_loader(object):

    def __init__(self, cfg, trialFileName, audioPath, visualPath, num_speakers):
        self.cfg = cfg
        self.audioPath = audioPath
        self.visualPath = visualPath
        self.candidate_speakers = num_speakers
        self.path = os.path.join(cfg.DATA.dataPathAVA, "csv")
        self.entity_data = json.load(open(os.path.join(self.path, 'val_entity.json')))
        self.ts_to_entity = json.load(open(os.path.join(self.path, 'val_ts.json')))
        self.mixLst = open(trialFileName).read().splitlines()

    def load_single_audio(self, audio, fps, numFrames, audioAug=False, aug_audio=None):

        res = vggish_input.waveform_to_examples(audio, 16000, numFrames, fps, return_tensor=False)
        return res

    def load_visual_label_mask(self, videoName, entityName, target_ts, context_ts):

        faceFolderPath = os.path.join(self.visualPath, videoName, entityName)

        faces = []
        H = 112
        labels_dict = self.entity_data[videoName][entityName]
        labels = numpy.zeros(len(target_ts))
        mask = numpy.zeros(len(target_ts))

        for i, time in enumerate(target_ts):
            if time not in context_ts:
                faces.append(numpy.zeros((H, H)))
            else:
                labels[i] = labels_dict[time]
                mask[i] = 1
                time = "%.2f" % float(time)
                faceFile = os.path.join(faceFolderPath, str(time) + '.jpg')

                face = cv2.imread(faceFile)
                face = cv2.cvtColor(face, cv2.COLOR_BGR2GRAY)
                face = cv2.resize(face, (H, H))
                faces.append(face)
        faces = numpy.array(faces)
        return faces, labels, mask

    def get_speaker_context(self, videoName, target_entity, all_ts, center_ts):

        context_speakers = list(self.ts_to_entity[videoName][center_ts])
        context = {}
        chosen_speakers = []
        context[target_entity] = all_ts
        context_speakers.remove(target_entity)
        num_frames = len(all_ts)
        for candidate in context_speakers:
            candidate_ts = self.entity_data[videoName][candidate]
            shared_ts = set(all_ts).intersection(set(candidate_ts))
            context[candidate] = shared_ts
            chosen_speakers.append(candidate)
            # if (len(shared_ts) > (num_frames / 2)):
            #     context[candidate] = shared_ts
            #     chosen_speakers.append(candidate)
        context_speakers = chosen_speakers
        random.shuffle(context_speakers)
        if not context_speakers:
            context_speakers.insert(0, target_entity)    # make sure is at 0
            while len(context_speakers) < self.candidate_speakers:
                context_speakers.append(random.choice(context_speakers))
        elif len(context_speakers) < self.candidate_speakers:
            context_speakers.insert(0, target_entity)    # make sure is at 0
            while len(context_speakers) < self.candidate_speakers:
                context_speakers.append(random.choice(context_speakers[1:]))
        else:
            context_speakers.insert(0, target_entity)    # make sure is at 0
            context_speakers = context_speakers[:self.candidate_speakers]

        assert set(context_speakers).issubset(set(list(context.keys()))), target_entity

        return context_speakers, context

    def __getitem__(self, index):

        target_video = self.mixLst[index]
        data = target_video.split('\t')
        fps = float(data[2])
        videoName = data[0][:11]
        target_entity = data[0]
        all_ts = list(self.entity_data[videoName][target_entity].keys())
        numFrames = int(data[1])
        # print(numFrames, len(all_ts))
        assert numFrames == len(all_ts)

        center_ts = all_ts[math.floor(numFrames / 2)]

        # get context speakers which have more than half time overlapped with target speaker
        context_speakers, context = self.get_speaker_context(videoName, target_entity, all_ts,
                                                             center_ts)

        sr, audio = wavfile.read(os.path.join(self.audioPath, videoName, target_entity + '.wav'))
        audio = self.load_single_audio(audio, fps, numFrames, audioAug=False)

        visualFeatures, labels, masks = [], [], []

        # target_label = list(self.entity_data[videoName][target_entity].values())
        target_visual, target_labels, target_masks = self.load_visual_label_mask(
            videoName, target_entity, all_ts, all_ts)

        for idx, context_entity in enumerate(context_speakers):
            if context_entity == target_entity:
                label = target_labels
                visualfeat = target_visual
                mask = target_masks
            else:
                visualfeat, label, mask = self.load_visual_label_mask(videoName, context_entity,
                                                                      all_ts,
                                                                      context[context_entity])
            visualFeatures.append(visualfeat)
            labels.append(label)
            masks.append(mask)

        audio = torch.FloatTensor(audio)[None, :, :]
        visualFeatures = torch.FloatTensor(numpy.array(visualFeatures))
        audio_t = audio.shape[1]
        video_t = visualFeatures.shape[1]
        if audio_t != video_t * 4:
            print(visualFeatures.shape, audio.shape, videoName, target_entity, numFrames)
        labels = torch.LongTensor(numpy.array(labels))
        masks = torch.LongTensor(numpy.array(masks))

        return audio, visualFeatures, labels, masks

    def __len__(self):
        return len(self.mixLst)
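A minimal sketch of wiring this loader into training follows, assuming a `cfg` produced by the dlhammer bootstrap; the trial CSV name and the `clips_audios`/`clips_videos` folder layout are assumptions borrowed from TalkNet's AVA preprocessing, not verified against this repo's train script.

```python
# Minimal sketch; names marked as assumptions are not verified against train.py.
import os
import torch
from dataLoader_multiperson import train_loader

ava_path = cfg.DATA.dataPathAVA
dataset = train_loader(cfg,
                       trialFileName=os.path.join(ava_path, 'csv', 'train_loader.csv'),  # assumption
                       audioPath=os.path.join(ava_path, 'clips_audios', 'train'),        # assumption
                       visualPath=os.path.join(ava_path, 'clips_videos', 'train'),       # assumption
                       num_speakers=cfg.MODEL.NUM_SPEAKERS)

# Each item already stacks NUM_SPEAKERS face tracks plus the shared audio for
# one target entity, which is why configs/multi.yaml keeps TRAIN.BATCH_SIZE at 1.
loader = torch.utils.data.DataLoader(dataset,
                                     batch_size=cfg.TRAIN.BATCH_SIZE,
                                     shuffle=True,
                                     num_workers=cfg.DATALOADER.nDataLoaderThread)

audio, video, labels, masks = next(iter(loader))
```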
dlhammer/.gitignore
ADDED
@@ -0,0 +1,3 @@
*.log
.vim-arsync
__pycache__/
dlhammer/LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
dlhammer/README.md
ADDED
@@ -0,0 +1,2 @@
# dl-hammer
tools for deep learning coding.
dlhammer/dlhammer/.ipynb_checkpoints/argparser-checkpoint.py
ADDED
@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-
#================================================================
#   Don't go gently into that good night.
#
#   author: klaus
#   description:
#
#================================================================

import os
import argparse
import datetime
from functools import partial
import yaml
from easydict import EasyDict

# from .utils import get_vacant_gpu
from .logger import bootstrap_logger, logger
from .utils.system import get_available_gpuids
from .utils.misc import merge_dict, merge_opts, to_string, eval_dict_leaf

CONFIG = EasyDict()

BASE_CONFIG = {
    'OUTPUT_DIR': './workspace',
    'SESSION': 'base',
    'NUM_GPUS': 1,
    'LOG_NAME': 'log.txt'
}


def bootstrap_args(default_params=None):
    """Get the params from the yaml file and args. The args override arguments in the yaml file.

    Returns: EasyDict instance.
    """
    parser = define_default_arg_parser()
    cfg = update_config(parser, default_params)
    create_workspace(cfg)  # create workspace

    CONFIG.update(cfg)
    bootstrap_logger(get_logfile(CONFIG))  # setup logger
    setup_gpu(CONFIG.NUM_GPUS)  # setup gpu

    return cfg


def setup_gpu(ngpu):
    gpuids = get_available_gpuids()
    # os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(i) for i in gpuids[:ngpu]])


def get_logfile(config):
    return os.path.join(config.WORKSPACE, config.LOG_NAME)


def define_default_arg_parser():
    """Define a default arg_parser.

    Returns:
        An argparse.ArgumentParser. More arguments can be added.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', help='load configs from yaml file', default='', type=str)
    parser.add_argument('opts',
                        default=None,
                        nargs='*',
                        help='modify config options using the command-line')

    return parser


def update_config(arg_parser, default_config=None):
    """Update the config from the argparser and the yaml file.

    Args:
        arg_parser: argparse.ArgumentParser.
    """

    parsed, unknown = arg_parser.parse_known_args()
    if default_config and parsed.cfg == "" and "cfg" in default_config:
        parsed.cfg = default_config["cfg"]

    config = EasyDict(BASE_CONFIG.copy())
    config['cfg'] = parsed.cfg
    # update default config
    if default_config is not None:
        config.update(default_config)

    # merge config from yaml
    if os.path.isfile(config.cfg):
        with open(config.cfg, 'r') as f:
            yml_config = yaml.full_load(f)
        config = merge_dict(config, yml_config)

    # merge opts
    config = merge_opts(config, parsed.opts)

    # eval values
    config = eval_dict_leaf(config)

    return config


def create_workspace(cfg):
    cfg_name, ext = os.path.splitext(os.path.basename(cfg.cfg))
    workspace = os.path.join(cfg.OUTPUT_DIR, cfg_name, cfg.SESSION)
    os.makedirs(workspace, exist_ok=True)
    cfg.WORKSPACE = workspace
dlhammer/dlhammer/.ipynb_checkpoints/bootstrap-checkpoint.py
ADDED
@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
#================================================================
#   Don't go gently into that good night.
#
#   author: klaus
#   description:
#
#================================================================

import sys
import logging

from .logger import bootstrap_logger, logger
from .argparser import bootstrap_args, CONFIG
from .utils.misc import to_string

__all__ = ['bootstrap', 'logger', 'CONFIG']


def bootstrap(default_cfg=None, print_cfg=True):
    """TODO: Docstring for bootstrap.

    Kwargs:
        use_argparser (TODO): TODO
        use_logger (TODO): TODO

    Returns: TODO

    """
    config = bootstrap_args(default_cfg)
    if print_cfg:
        logger.info(to_string(config))
    return config
dlhammer/dlhammer/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .bootstrap import *
|
dlhammer/dlhammer/argparser.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
#================================================================
|
3 |
+
# Don't go gently into that good night.
|
4 |
+
#
|
5 |
+
# author: klaus
|
6 |
+
# description:
|
7 |
+
#
|
8 |
+
#================================================================
|
9 |
+
|
10 |
+
import os
|
11 |
+
import argparse
|
12 |
+
import datetime
|
13 |
+
from functools import partial
|
14 |
+
import yaml
|
15 |
+
from easydict import EasyDict
|
16 |
+
|
17 |
+
# from .utils import get_vacant_gpu
|
18 |
+
from .logger import bootstrap_logger, logger
|
19 |
+
from .utils.system import get_available_gpuids
|
20 |
+
from .utils.misc import merge_dict, merge_opts, to_string, eval_dict_leaf
|
21 |
+
|
22 |
+
CONFIG = EasyDict()
|
23 |
+
|
24 |
+
BASE_CONFIG = {
|
25 |
+
'OUTPUT_DIR': './workspace',
|
26 |
+
'NUM_GPUS': 1,
|
27 |
+
'LOG_NAME': 'log.txt'
|
28 |
+
}
|
29 |
+
|
30 |
+
|
31 |
+
def bootstrap_args(default_params=None):
|
32 |
+
"""get the params from yaml file and args. The args will override arguemnts in the yaml file.
|
33 |
+
Returns: EasyDict instance.
|
34 |
+
|
35 |
+
"""
|
36 |
+
parser = define_default_arg_parser()
|
37 |
+
cfg = update_config(parser, default_params)
|
38 |
+
create_workspace(cfg) #create workspace
|
39 |
+
|
40 |
+
CONFIG.update(cfg)
|
41 |
+
bootstrap_logger(get_logfile(CONFIG)) # setup logger
|
42 |
+
setup_gpu(CONFIG.NUM_GPUS) #setup gpu
|
43 |
+
|
44 |
+
return cfg
|
45 |
+
|
46 |
+
|
47 |
+
def setup_gpu(ngpu):
|
48 |
+
gpuids = get_available_gpuids()
|
49 |
+
# os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(i) for i in gpuids[:ngpu]])
|
50 |
+
|
51 |
+
|
52 |
+
def get_logfile(config):
|
53 |
+
return os.path.join(config.WORKSPACE, config.LOG_NAME)
|
54 |
+
|
55 |
+
|
56 |
+
def define_default_arg_parser():
|
57 |
+
"""Define a default arg_parser.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
A argparse.ArgumentParser. More arguments can be added.
|
61 |
+
|
62 |
+
"""
|
63 |
+
parser = argparse.ArgumentParser()
|
64 |
+
parser.add_argument('--cfg', help='load configs from yaml file', default='', type=str)
|
65 |
+
parser.add_argument('opts',
|
66 |
+
default=None,
|
67 |
+
nargs='*',
|
68 |
+
help='modify config options using the command-line')
|
69 |
+
|
70 |
+
return parser
|
71 |
+
|
72 |
+
|
73 |
+
def update_config(arg_parser, default_config=None):
|
74 |
+
""" update argparser to args.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
arg_parser: argparse.ArgumentParser.
|
78 |
+
"""
|
79 |
+
|
80 |
+
parsed, unknown = arg_parser.parse_known_args()
|
81 |
+
if default_config and parsed.cfg == "" and "cfg" in default_config:
|
82 |
+
parsed.cfg = default_config["cfg"]
|
83 |
+
|
84 |
+
config = EasyDict(BASE_CONFIG.copy())
|
85 |
+
config['cfg'] = parsed.cfg
|
86 |
+
# update default config
|
87 |
+
if default_config is not None:
|
88 |
+
config.update(default_config)
|
89 |
+
|
90 |
+
# merge config from yaml
|
91 |
+
if os.path.isfile(config.cfg):
|
92 |
+
with open(config.cfg, 'r') as f:
|
93 |
+
yml_config = yaml.full_load(f)
|
94 |
+
config = merge_dict(config, yml_config)
|
95 |
+
|
96 |
+
# merge opts
|
97 |
+
config = merge_opts(config, parsed.opts)
|
98 |
+
|
99 |
+
# eval values
|
100 |
+
config = eval_dict_leaf(config)
|
101 |
+
|
102 |
+
return config
|
103 |
+
|
104 |
+
|
105 |
+
def create_workspace(cfg):
|
106 |
+
cfg_name, ext = os.path.splitext(os.path.basename(cfg.cfg))
|
107 |
+
workspace = os.path.join(cfg.OUTPUT_DIR)
|
108 |
+
os.makedirs(workspace, exist_ok=True)
|
109 |
+
cfg.WORKSPACE = workspace
|
dlhammer/dlhammer/bootstrap.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
#================================================================
|
3 |
+
# Don't go gently into that good night.
|
4 |
+
#
|
5 |
+
# author: klaus
|
6 |
+
# description:
|
7 |
+
#
|
8 |
+
#================================================================
|
9 |
+
|
10 |
+
import sys
|
11 |
+
import logging
|
12 |
+
|
13 |
+
from .logger import bootstrap_logger, logger
|
14 |
+
from .argparser import bootstrap_args, CONFIG
|
15 |
+
from .utils.misc import to_string
|
16 |
+
|
17 |
+
__all__ = ['bootstrap', 'logger', 'CONFIG']
|
18 |
+
|
19 |
+
|
20 |
+
def bootstrap(default_cfg=None, print_cfg=True):
|
21 |
+
"""TODO: Docstring for bootstrap.
|
22 |
+
|
23 |
+
Kwargs:
|
24 |
+
use_argparser (TODO): TODO
|
25 |
+
use_logger (TODO): TODO
|
26 |
+
|
27 |
+
Returns: TODO
|
28 |
+
|
29 |
+
"""
|
30 |
+
config = bootstrap_args(default_cfg)
|
31 |
+
if print_cfg:
|
32 |
+
logger.info(to_string(config))
|
33 |
+
return config
|
dlhammer/dlhammer/logger.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
#================================================================
|
3 |
+
# Don't go gently into that good night.
|
4 |
+
#
|
5 |
+
# author: klaus
|
6 |
+
# description:
|
7 |
+
#
|
8 |
+
#================================================================
|
9 |
+
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import logging
|
13 |
+
|
14 |
+
logger = logging.getLogger('DLHammer')
|
15 |
+
|
16 |
+
|
17 |
+
def bootstrap_logger(logfile=None, fmt=None):
|
18 |
+
"""TODO: Docstring for bootstrap_logger.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
logfile (str): file path logging to.
|
22 |
+
|
23 |
+
Kwargs:
|
24 |
+
fmt (TODO): TODO
|
25 |
+
|
26 |
+
Returns: TODO
|
27 |
+
|
28 |
+
"""
|
29 |
+
if fmt is None:
|
30 |
+
# fmt = '%(asctime)s - %(levelname)-5s - [%(filename)s:%(lineno)d] %(message)s'
|
31 |
+
fmt = '%(message)s'
|
32 |
+
logging.basicConfig(level=logging.DEBUG, format=fmt)
|
33 |
+
|
34 |
+
#log to file
|
35 |
+
if logfile is not None:
|
36 |
+
formatter = logging.Formatter(fmt)
|
37 |
+
fh = logging.FileHandler(logfile)
|
38 |
+
fh.setLevel(logging.DEBUG)
|
39 |
+
fh.setFormatter(formatter)
|
40 |
+
logger.addHandler(fh)
|
41 |
+
|
42 |
+
# sys.stdout = LoggerWriter(sys.stdout, logger.info)
|
43 |
+
# sys.stderr = LoggerWriter(sys.stderr, logger.error)
|
44 |
+
return
|
45 |
+
|
46 |
+
|
47 |
+
class LoggerWriter(object):
|
48 |
+
|
49 |
+
def __init__(self, stream, logfct):
|
50 |
+
self.terminal = stream
|
51 |
+
self.logfct = logfct
|
52 |
+
self.buf = []
|
53 |
+
|
54 |
+
def write(self, msg):
|
55 |
+
if msg.endswith('\n'):
|
56 |
+
self.buf.append(msg.rstrip('\n'))
|
57 |
+
|
58 |
+
message = ''.join(self.buf)
|
59 |
+
self.logfct(message)
|
60 |
+
|
61 |
+
self.buf = []
|
62 |
+
else:
|
63 |
+
self.buf.append(msg)
|
64 |
+
|
65 |
+
def flush(self):
|
66 |
+
pass
|
dlhammer/dlhammer/test/config.yml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
a_int: 12
|
2 |
+
a_float: 1e-2
|
3 |
+
a_list: [0,1,2]
|
4 |
+
eval_list: eval(list(range(10)))
|
5 |
+
DATA:
|
6 |
+
PATH_TO_DATA_DIR: /home/ubuntu/data/kinetics/Mini-Kinetics-200
|
7 |
+
PATH_PREFIX: /home/ubuntu/data/kinetics/k400_ver3
|
8 |
+
NUM_FRAMES: 16
|
9 |
+
SAMPLING_RATE: 8
|
10 |
+
TARGET_FPS: 25
|
11 |
+
TRAIN_JITTER_SCALES: [256, 320]
|
12 |
+
TRAIN_CROP_SIZE: 224
|
13 |
+
TEST_CROP_SIZE: 224
|
14 |
+
INPUT_CHANNEL_NUM: [3]
|
15 |
+
SOLVER:
|
16 |
+
BACKBONE:
|
17 |
+
OPTIMIZER: sgd
|
18 |
+
MOMENTUM: 0.9
|
19 |
+
BASE_LR: 1e-3
|
20 |
+
SCHEDULER:
|
21 |
+
NAME: warmup_multistep
|
22 |
+
MILESTONES: [13, 24]
|
23 |
+
WARMUP_EPOCHS: 0.5
|
24 |
+
GAMMA: 0.1
|
25 |
+
TEMPORAL_MODEL:
|
26 |
+
OPTIMIZER: sgd
|
27 |
+
MOMENTUM: 0.9
|
28 |
+
BASE_LR: 1e-3
|
29 |
+
SCHEDULER:
|
30 |
+
NAME: multistep
|
31 |
+
MILESTONES: [13, 24]
|
32 |
+
GAMMA: 0.1
|
dlhammer/dlhammer/test/test_args.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
#================================================================
|
3 |
+
# Don't go gently into that good night.
|
4 |
+
#
|
5 |
+
# author: klaus
|
6 |
+
# description:
|
7 |
+
#
|
8 |
+
#================================================================
|
9 |
+
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
|
13 |
+
CURRENT_FILE_DIRECTORY = os.path.abspath(os.path.dirname(__file__))
|
14 |
+
sys.path.append(os.path.join(CURRENT_FILE_DIRECTORY, '../..'))
|
15 |
+
sys.path.append(os.path.join(CURRENT_FILE_DIRECTORY, '.'))
|
16 |
+
|
17 |
+
from dlhammer import bootstrap, CONFIG
|
18 |
+
from dlhammer import logger
|
19 |
+
|
20 |
+
config = bootstrap(print_cfg=True)
|
dlhammer/dlhammer/test/test_logger.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
#================================================================
|
3 |
+
# Don't go gently into that good night.
|
4 |
+
#
|
5 |
+
# author: klaus
|
6 |
+
# description:
|
7 |
+
#
|
8 |
+
#================================================================
|
9 |
+
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
|
13 |
+
CURRENT_FILE_DIRECTORY = os.path.abspath(os.path.dirname(__file__))
|
14 |
+
sys.path.append(os.path.join(CURRENT_FILE_DIRECTORY, '../..'))
|
15 |
+
sys.path.append(os.path.join(CURRENT_FILE_DIRECTORY, '.'))
|
16 |
+
|
17 |
+
from dlhammer import bootstrap, logger
|
18 |
+
bootstrap()
|
19 |
+
|
20 |
+
logger.info('dummy output')
|
21 |
+
|
22 |
+
raise Exception('dummy error')
|
dlhammer/dlhammer/utils/__init__.py
ADDED
File without changes
|
dlhammer/dlhammer/utils/misc.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
#================================================================
|
3 |
+
# Don't go gently into that good night.
|
4 |
+
#
|
5 |
+
# author: klaus
|
6 |
+
# description:
|
7 |
+
#
|
8 |
+
#================================================================
|
9 |
+
|
10 |
+
import ast
|
11 |
+
|
12 |
+
|
13 |
+
def merge_dict(a, b, path=None):
|
14 |
+
"""merge b into a. The values in b will override values in a.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
a (dict): dict to merge to.
|
18 |
+
b (dict): dict to merge from.
|
19 |
+
|
20 |
+
Returns: dict1 with values merged from b.
|
21 |
+
|
22 |
+
"""
|
23 |
+
if path is None: path = []
|
24 |
+
for key in b:
|
25 |
+
if key in a:
|
26 |
+
if isinstance(a[key], dict) and isinstance(b[key], dict):
|
27 |
+
merge_dict(a[key], b[key], path + [str(key)])
|
28 |
+
else:
|
29 |
+
a[key] = b[key]
|
30 |
+
else:
|
31 |
+
a[key] = b[key]
|
32 |
+
return a
|
33 |
+
|
34 |
+
|
35 |
+
def merge_opts(d, opts):
|
36 |
+
"""merge opts
|
37 |
+
Args:
|
38 |
+
d (dict): The dict.
|
39 |
+
opts (list): The opts to merge. format: [key1, name1, key2, name2,...]
|
40 |
+
Returns: d. the input dict `d` with merged opts.
|
41 |
+
|
42 |
+
"""
|
43 |
+
assert len(opts) % 2 == 0, f'length of opts must be even. Got: {opts}'
|
44 |
+
for i in range(0, len(opts), 2):
|
45 |
+
full_k, v = opts[i], opts[i + 1]
|
46 |
+
keys = full_k.split('.')
|
47 |
+
sub_d = d
|
48 |
+
for i, k in enumerate(keys):
|
49 |
+
if not hasattr(sub_d, k):
|
50 |
+
raise ValueError(f'The key {k} not exist in the dict. Full key:{full_k}')
|
51 |
+
if i != len(keys) - 1:
|
52 |
+
sub_d = sub_d[k]
|
53 |
+
else:
|
54 |
+
sub_d[k] = v
|
55 |
+
return d
|
56 |
+
|
57 |
+
|
58 |
+
def to_string(params, indent=2):
|
59 |
+
"""format params to a string
|
60 |
+
|
61 |
+
Args:
|
62 |
+
params (EasyDict): the params.
|
63 |
+
|
64 |
+
Returns: The string to display.
|
65 |
+
|
66 |
+
"""
|
67 |
+
msg = '{\n'
|
68 |
+
for i, (k, v) in enumerate(params.items()):
|
69 |
+
if isinstance(v, dict):
|
70 |
+
v = to_string(v, indent + 4)
|
71 |
+
spaces = ' ' * indent
|
72 |
+
msg += spaces + '{}: {}'.format(k, v)
|
73 |
+
if i == len(params) - 1:
|
74 |
+
msg += ' }'
|
75 |
+
else:
|
76 |
+
msg += '\n'
|
77 |
+
return msg
|
78 |
+
|
79 |
+
|
80 |
+
def eval_dict_leaf(d):
|
81 |
+
"""eval values of dict leaf.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
d (dict): The dict to eval.
|
85 |
+
|
86 |
+
Returns: dict.
|
87 |
+
|
88 |
+
"""
|
89 |
+
for k, v in d.items():
|
90 |
+
if not isinstance(v, dict):
|
91 |
+
d[k] = eval_string(v)
|
92 |
+
else:
|
93 |
+
eval_dict_leaf(v)
|
94 |
+
return d
|
95 |
+
|
96 |
+
|
97 |
+
def eval_string(string):
|
98 |
+
"""automatically evaluate string to corresponding types.
|
99 |
+
|
100 |
+
For example:
|
101 |
+
not a string -> return the original input
|
102 |
+
'0' -> 0
|
103 |
+
'0.2' -> 0.2
|
104 |
+
'[0, 1, 2]' -> [0,1,2]
|
105 |
+
'eval(1+2)' -> 3
|
106 |
+
'eval(range(5))' -> [0,1,2,3,4]
|
107 |
+
|
108 |
+
|
109 |
+
Args:
|
110 |
+
value : string.
|
111 |
+
|
112 |
+
Returns: the corresponding type
|
113 |
+
|
114 |
+
"""
|
115 |
+
if not isinstance(string, str):
|
116 |
+
return string
|
117 |
+
if len(string) > 1 and string[0] == '[' and string[-1] == ']':
|
118 |
+
return eval(string)
|
119 |
+
if string[0:5] == 'eval(':
|
120 |
+
return eval(string[5:-1])
|
121 |
+
try:
|
122 |
+
v = ast.literal_eval(string)
|
123 |
+
except:
|
124 |
+
v = string
|
125 |
+
return v
|
dlhammer/dlhammer/utils/system.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
#================================================================
|
3 |
+
# Don't go gently into that good night.
|
4 |
+
#
|
5 |
+
# author: klaus
|
6 |
+
# description:
|
7 |
+
#
|
8 |
+
#================================================================
|
9 |
+
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import subprocess
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
|
16 |
+
def get_available_gpuids():
|
17 |
+
"""
|
18 |
+
Returns: the gpu ids sorted in descending order w.r.t occupied memory.
|
19 |
+
"""
|
20 |
+
com = "nvidia-smi|sed -n '/%/p'|sed 's/|/\\n/g'|sed -n '/MiB/p'|sed 's/ //g'|sed 's/MiB/\\n/'|sed '/\\//d'"
|
21 |
+
gpum = subprocess.check_output(com, shell=True)
|
22 |
+
gpum = gpum.decode('utf-8').split('\n')
|
23 |
+
gpum = gpum[:-1]
|
24 |
+
sorted_gpuid = np.argsort(gpum)
|
25 |
+
return sorted_gpuid
|
environment.yml
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: loconet
|
2 |
+
channels:
|
3 |
+
- conda-forge
|
4 |
+
- defaults
|
5 |
+
dependencies:
|
6 |
+
- _libgcc_mutex=0.1=conda_forge
|
7 |
+
- _openmp_mutex=4.5=1_gnu
|
8 |
+
- alsa-lib=1.2.3=h516909a_0
|
9 |
+
- anyio=3.5.0=py37h89c1867_0
|
10 |
+
- argon2-cffi=21.3.0=pyhd8ed1ab_0
|
11 |
+
- argon2-cffi-bindings=21.2.0=py37h5e8e339_1
|
12 |
+
- aria2=1.36.0=h319415d_2
|
13 |
+
- attrs=21.4.0=pyhd8ed1ab_0
|
14 |
+
- babel=2.9.1=pyh44b312d_0
|
15 |
+
- backcall=0.2.0=pyh9f0ad1d_0
|
16 |
+
- backports=1.0=py_2
|
17 |
+
- backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
|
18 |
+
- bleach=4.1.0=pyhd8ed1ab_0
|
19 |
+
- bottleneck=1.3.4=py37h6c7ee08_0
|
20 |
+
- brotli=1.0.9=h7f98852_6
|
21 |
+
- brotli-bin=1.0.9=h7f98852_6
|
22 |
+
- brotlipy=0.7.0=py37h5e8e339_1003
|
23 |
+
- c-ares=1.18.1=h7f98852_0
|
24 |
+
- ca-certificates=2022.5.18.1=ha878542_0
|
25 |
+
- cffi=1.14.6=py37hc58025e_0
|
26 |
+
- configparser=5.2.0=pyhd8ed1ab_0
|
27 |
+
- cryptography=36.0.1=py37hf1a17b8_0
|
28 |
+
- cycler=0.11.0=pyhd8ed1ab_0
|
29 |
+
- cython=0.29.27=py37hcd2ae1e_0
|
30 |
+
- dbus=1.13.6=h48d8840_2
|
31 |
+
- debugpy=1.5.1=py37hcd2ae1e_0
|
32 |
+
- defusedxml=0.7.1=pyhd8ed1ab_0
|
33 |
+
- easydict=1.9=py_0
|
34 |
+
- entrypoints=0.4=pyhd8ed1ab_0
|
35 |
+
- expat=2.4.6=h27087fc_0
|
36 |
+
- flit-core=3.7.0=pyhd8ed1ab_0
|
37 |
+
- fontconfig=2.13.96=ha180cfb_0
|
38 |
+
- fonttools=4.29.1=py37h5e8e339_0
|
39 |
+
- freetype=2.10.4=h0708190_1
|
40 |
+
- gettext=0.19.8.1=h0b5b191_1005
|
41 |
+
- giflib=5.2.1=h36c2ea0_2
|
42 |
+
- glib=2.68.4=h9c3ff4c_0
|
43 |
+
- glib-tools=2.68.4=h9c3ff4c_0
|
44 |
+
- gst-plugins-base=1.18.5=hf529b03_0
|
45 |
+
- gstreamer=1.18.5=h76c114f_0
|
46 |
+
- icu=68.2=h9c3ff4c_0
|
47 |
+
- idna=3.3=pyhd8ed1ab_0
|
48 |
+
- importlib_resources=5.4.0=pyhd8ed1ab_0
|
49 |
+
- ipykernel=6.9.1=py37h6531663_0
|
50 |
+
- ipython=7.31.1=py37h89c1867_0
|
51 |
+
- ipython_genutils=0.2.0=py_1
|
52 |
+
- jbig=2.1=h7f98852_2003
|
53 |
+
- jedi=0.18.1=py37h89c1867_0
|
54 |
+
- jinja2=3.0.3=pyhd8ed1ab_0
|
55 |
+
- jpeg=9e=h7f98852_0
|
56 |
+
- json5=0.9.5=pyh9f0ad1d_0
|
57 |
+
- jsonschema=4.4.0=pyhd8ed1ab_0
|
58 |
+
- jupyter_client=7.1.2=pyhd8ed1ab_0
|
59 |
+
- jupyter_core=4.9.2=py37h89c1867_0
|
60 |
+
- jupyter_server=1.13.5=pyhd8ed1ab_1
|
61 |
+
- jupyterlab=3.2.9=pyhd8ed1ab_0
|
62 |
+
- jupyterlab_pygments=0.1.2=pyh9f0ad1d_0
|
63 |
+
- jupyterlab_server=2.10.3=pyhd8ed1ab_0
|
64 |
+
- kiwisolver=1.3.2=py37h2527ec5_1
|
65 |
+
- krb5=1.19.2=hcc1bbae_3
|
66 |
+
- lcms2=2.12=hddcbb42_0
|
67 |
+
- ld_impl_linux-64=2.36.1=hea4e1c9_2
|
68 |
+
- lerc=3.0=h9c3ff4c_0
|
69 |
+
- libblas=3.9.0=13_linux64_openblas
|
70 |
+
- libbrotlicommon=1.0.9=h7f98852_6
|
71 |
+
- libbrotlidec=1.0.9=h7f98852_6
|
72 |
+
- libbrotlienc=1.0.9=h7f98852_6
|
73 |
+
- libcblas=3.9.0=13_linux64_openblas
|
74 |
+
- libclang=11.1.0=default_ha53f305_1
|
75 |
+
- libdeflate=1.10=h7f98852_0
|
76 |
+
- libedit=3.1.20191231=he28a2e2_2
|
77 |
+
- libevent=2.1.10=h9b69904_4
|
78 |
+
- libffi=3.3=h58526e2_2
|
79 |
+
- libgcc-ng=11.2.0=h1d223b6_12
|
80 |
+
- libgfortran-ng=11.2.0=h69a702a_12
|
81 |
+
- libgfortran5=11.2.0=h5c6108e_12
|
82 |
+
- libglib=2.68.4=h3e27bee_0
|
83 |
+
- libgomp=11.2.0=h1d223b6_12
|
84 |
+
- libiconv=1.16=h516909a_0
|
85 |
+
- liblapack=3.9.0=13_linux64_openblas
|
86 |
+
- libllvm11=11.1.0=hf817b99_3
|
87 |
+
- libogg=1.3.4=h7f98852_1
|
88 |
+
- libopenblas=0.3.18=pthreads_h8fe5266_0
|
89 |
+
- libopus=1.3.1=h7f98852_1
|
90 |
+
- libpng=1.6.37=h21135ba_2
|
91 |
+
- libpq=13.5=hd57d9b9_1
|
92 |
+
- libsodium=1.0.18=h36c2ea0_1
|
93 |
+
- libssh2=1.10.0=ha56f1ee_2
|
94 |
+
- libstdcxx-ng=11.2.0=he4da1e4_12
|
95 |
+
- libtiff=4.3.0=h542a066_3
|
96 |
+
- libuuid=2.32.1=h7f98852_1000
|
97 |
+
- libvorbis=1.3.7=h9c3ff4c_0
|
98 |
+
- libwebp=1.2.2=h3452ae3_0
|
99 |
+
- libwebp-base=1.2.2=h7f98852_1
|
100 |
+
- libxcb=1.13=h7f98852_1004
|
101 |
+
- libxkbcommon=1.0.3=he3ba5ed_0
|
102 |
+
- libxml2=2.9.12=h72842e0_0
|
103 |
+
- libzlib=1.2.11=h36c2ea0_1013
|
104 |
+
- llvmlite=0.38.0=py37h0761922_1
|
105 |
+
- lz4-c=1.9.3=h9c3ff4c_1
|
106 |
+
- markupsafe=2.1.0=py37h540881e_0
|
107 |
+
- matplotlib=3.5.1=py37h89c1867_0
|
108 |
+
- matplotlib-base=3.5.1=py37h1058ff1_0
|
109 |
+
- matplotlib-inline=0.1.3=pyhd8ed1ab_0
|
110 |
+
- mistune=0.8.4=py37h5e8e339_1005
|
111 |
+
- munkres=1.1.4=pyh9f0ad1d_0
|
112 |
+
- mysql-common=8.0.28=ha770c72_0
|
113 |
+
- mysql-libs=8.0.28=hfa10184_0
|
114 |
+
- nbclassic=0.3.5=pyhd8ed1ab_0
|
115 |
+
- nbclient=0.5.11=pyhd8ed1ab_0
|
116 |
+
- nbconvert=6.4.2=py37h89c1867_0
|
117 |
+
- nbformat=5.1.3=pyhd8ed1ab_0
|
118 |
+
- ncurses=6.2=h58526e2_4
|
119 |
+
- nest-asyncio=1.5.4=pyhd8ed1ab_0
|
120 |
+
- nomkl=1.0=h5ca1d4c_0
|
121 |
+
- notebook=6.4.8=pyha770c72_0
|
122 |
+
- nspr=4.32=h9c3ff4c_1
|
123 |
+
- nss=3.74=hb5efdd6_0
|
124 |
+
- numba=0.55.1=py37h2d894fd_0
|
125 |
+
- numexpr=2.8.0=py37hfe5f03c_101
|
126 |
+
- numpy=1.21.5=py37hf2998dd_0
|
127 |
+
- openjpeg=2.4.0=hb52868f_1
|
128 |
+
- openssl=1.1.1o=h166bdaf_0
|
129 |
+
- packaging=21.3=pyhd8ed1ab_0
|
130 |
+
- pandas=1.3.5=py37h8c16a72_0
|
131 |
+
- pandoc=2.17.1.1=ha770c72_0
|
132 |
+
- pandocfilters=1.5.0=pyhd8ed1ab_0
|
133 |
+
- parso=0.8.3=pyhd8ed1ab_0
|
134 |
+
- patsy=0.5.2=pyhd8ed1ab_0
|
135 |
+
- pcre=8.45=h9c3ff4c_0
|
136 |
+
- pexpect=4.8.0=pyh9f0ad1d_2
|
137 |
+
- pickleshare=0.7.5=py_1003
|
138 |
+
- pip=22.0.3=pyhd8ed1ab_0
|
139 |
+
- prometheus_client=0.13.1=pyhd8ed1ab_0
|
140 |
+
- prompt-toolkit=3.0.27=pyha770c72_0
|
141 |
+
- pthread-stubs=0.4=h36c2ea0_1001
|
142 |
+
- ptyprocess=0.7.0=pyhd3deb0d_0
|
143 |
+
- pycparser=2.21=pyhd8ed1ab_0
|
144 |
+
- pygments=2.11.2=pyhd8ed1ab_0
|
145 |
+
- pyopenssl=22.0.0=pyhd8ed1ab_0
|
146 |
+
- pyparsing=3.0.7=pyhd8ed1ab_0
|
147 |
+
- pyqt=5.12.3=py37h89c1867_8
|
148 |
+
- pyqt-impl=5.12.3=py37hac37412_8
|
149 |
+
- pyqt5-sip=4.19.18=py37hcd2ae1e_8
|
150 |
+
- pyqtchart=5.12=py37he336c9b_8
|
151 |
+
- pyqtwebengine=5.12.1=py37he336c9b_8
|
152 |
+
- pyrsistent=0.18.1=py37h5e8e339_0
|
153 |
+
- pysocks=1.7.1=py37h89c1867_4
|
154 |
+
- python=3.7.9=hffdb5ce_100_cpython
|
155 |
+
- python-dateutil=2.8.2=pyhd8ed1ab_0
|
156 |
+
- python_abi=3.7=2_cp37m
|
157 |
+
- pytz=2021.3=pyhd8ed1ab_0
|
158 |
+
- pyzmq=22.3.0=py37h336d617_1
|
159 |
+
- qt=5.12.9=hda022c4_4
|
160 |
+
- readline=8.1=h46c0cb4_0
|
161 |
+
- resampy=0.2.2=py_0
|
162 |
+
- scipy=1.7.3=py37hf2a6cf1_0
|
163 |
+
- seaborn=0.11.2=hd8ed1ab_0
|
164 |
+
- seaborn-base=0.11.2=pyhd8ed1ab_0
|
165 |
+
- send2trash=1.8.0=pyhd8ed1ab_0
|
166 |
+
- six=1.16.0=pyh6c4a22f_0
|
167 |
+
- sniffio=1.2.0=py37h89c1867_2
|
168 |
+
- sqlite=3.37.0=h9cd32fc_0
|
169 |
+
- statsmodels=0.13.2=py37hb1e94ed_0
|
170 |
+
- terminado=0.13.1=py37h89c1867_0
|
171 |
+
- testpath=0.5.0=pyhd8ed1ab_0
|
172 |
+
- tk=8.6.12=h27826a3_0
|
173 |
+
- tornado=6.1=py37h5e8e339_2
|
174 |
+
- traitlets=5.1.1=pyhd8ed1ab_0
|
175 |
+
- typing_extensions=4.1.1=pyha770c72_0
|
176 |
+
- unicodedata2=14.0.0=py37h5e8e339_0
|
177 |
+
- wcwidth=0.2.5=pyh9f0ad1d_2
|
178 |
+
- webencodings=0.5.1=py_1
|
179 |
+
- websocket-client=1.2.3=pyhd8ed1ab_0
|
180 |
+
- wheel=0.37.1=pyhd8ed1ab_0
|
181 |
+
- xorg-libxau=1.0.9=h7f98852_0
|
182 |
+
- xorg-libxdmcp=1.1.3=h7f98852_0
|
183 |
+
- xz=5.2.5=h516909a_1
|
184 |
+
- zeromq=4.3.4=h9c3ff4c_1
|
185 |
+
- zlib=1.2.11=h36c2ea0_1013
|
186 |
+
- zstd=1.5.2=ha95c52a_0
|
187 |
+
- pip:
|
188 |
+
- absl-py==1.0.0
|
189 |
+
- addict==2.4.0
|
190 |
+
- aiohttp==3.8.1
|
191 |
+
- aiosignal==1.2.0
|
192 |
+
- analytics-python==1.4.0
|
193 |
+
- appdirs==1.4.4
|
194 |
+
- asgiref==3.5.2
|
195 |
+
- async-timeout==4.0.2
|
196 |
+
- asynctest==0.13.0
|
197 |
+
- audioread==2.1.9
|
198 |
+
- backoff==1.10.0
|
199 |
+
- bcrypt==3.2.2
|
200 |
+
- beautifulsoup4==4.10.0
|
201 |
+
- cachetools==4.2.4
|
202 |
+
- certifi==2021.10.8
|
203 |
+
- charset-normalizer==2.0.9
|
204 |
+
- click==8.0.3
|
205 |
+
- decorator==4.4.2
|
206 |
+
- decord==0.6.0
|
207 |
+
- einops==0.4.0
|
208 |
+
- fastapi==0.78.0
|
209 |
+
- ffmpeg==1.4
|
210 |
+
- ffmpy==0.3.0
|
211 |
+
- filelock==3.4.0
|
212 |
+
- frozenlist==1.3.0
|
213 |
+
- fsspec==2022.1.0
|
214 |
+
- future==0.18.2
|
215 |
+
- fvcore==0.1.5.post20221221
|
216 |
+
- gdown==4.2.0
|
217 |
+
- google-auth==2.3.3
|
218 |
+
- google-auth-oauthlib==0.4.6
|
219 |
+
- gradio==3.0.2
|
220 |
+
- grpcio==1.43.0
|
221 |
+
- h11==0.13.0
|
222 |
+
- imageio==2.23.0
|
223 |
+
- imageio-ffmpeg==0.4.7
|
224 |
+
- importlib-metadata==4.10.0
|
225 |
+
- iopath==0.1.10
|
226 |
+
- ipywidgets==8.0.4
|
227 |
+
- joblib==1.1.0
|
228 |
+
- jupyterlab-widgets==3.0.5
|
229 |
+
- librosa==0.9.1
|
230 |
+
- linkify-it-py==1.0.3
|
231 |
+
- lmdb==1.4.1
|
232 |
+
- markdown==3.3.6
|
233 |
+
- markdown-it-py==2.1.0
|
234 |
+
- mdit-py-plugins==0.3.0
|
235 |
+
- mdurl==0.1.1
|
236 |
+
- mmaction2==0.24.1
|
237 |
+
- mmcv==1.7.0
|
238 |
+
- mmcv-full==1.4.6
|
239 |
+
- monotonic==1.6
|
240 |
+
- moviepy==1.0.3
|
241 |
+
- multidict==5.2.0
|
242 |
+
- oauthlib==3.1.1
|
243 |
+
- opencv-contrib-python==4.7.0.68
|
244 |
+
- opencv-python==4.5.5.62
|
245 |
+
- orjson==3.6.8
|
246 |
+
- paramiko==2.11.0
|
247 |
+
- pillow==8.3.2
|
248 |
+
- pooch==1.6.0
|
249 |
+
- portalocker==2.7.0
|
250 |
+
- proglog==0.1.10
|
251 |
+
- protobuf==3.19.3
|
252 |
+
- pyasn1==0.4.8
|
253 |
+
- pyasn1-modules==0.2.8
|
254 |
+
- pycryptodome==3.14.1
|
255 |
+
- pydantic==1.9.0
|
256 |
+
- pydeprecate==0.3.1
|
257 |
+
- pydub==0.25.1
|
258 |
+
- pynacl==1.5.0
|
259 |
+
- python-box==6.0.2
|
260 |
+
- python-multipart==0.0.5
|
261 |
+
- python-speech-features==0.6
|
262 |
+
- pytorch-lightning==1.5.8
|
263 |
+
- pyyaml==6.0
|
264 |
+
- requests==2.26.0
|
265 |
+
- requests-oauthlib==1.3.0
|
266 |
+
- rsa==4.8
|
267 |
+
- scenedetect==0.5.6.1
|
268 |
+
- scikit-learn==1.0.1
|
269 |
+
- setuptools==60.9.3
|
270 |
+
- soundfile==0.10.3.post1
|
271 |
+
- soupsieve==2.3.1
|
272 |
+
- starlette==0.19.1
|
273 |
+
- tabulate==0.9.0
|
274 |
+
- tensorboard==2.7.0
|
275 |
+
- tensorboard-data-server==0.6.1
|
276 |
+
- tensorboard-plugin-wit==1.8.1
|
277 |
+
- termcolor==2.2.0
|
278 |
+
- threadpoolctl==3.0.0
|
279 |
+
- timm==0.4.5
|
280 |
+
- torch==1.10.1
|
281 |
+
- torchaudio==0.10.1
|
282 |
+
- torchlibrosa==0.0.9
|
283 |
+
- torchmetrics==0.7.0
|
284 |
+
- torchvision==0.11.2
|
285 |
+
- tqdm==4.62.3
|
286 |
+
- typing-extensions==4.0.1
|
287 |
+
- uc-micro-py==1.0.1
|
288 |
+
- urllib3==1.26.7
|
289 |
+
- uvicorn==0.17.6
|
290 |
+
- warmup-scheduler-pytorch==0.1.2
|
291 |
+
- werkzeug==2.0.2
|
292 |
+
- wget==3.2
|
293 |
+
- widgetsnbextension==4.0.5
|
294 |
+
- yacs==0.1.8
|
295 |
+
- yapf==0.32.0
|
296 |
+
- yarl==1.7.2
|
297 |
+
- youtube-dl==2021.12.17
|
298 |
+
- zipp==3.6.0
|
legacy/talkNet_multi_multicard.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
import sys, time, numpy, os, subprocess, pandas, tqdm
|
6 |
+
|
7 |
+
from loss_multi import lossAV, lossA, lossV
|
8 |
+
from model.talkNetModel import talkNetModel
|
9 |
+
|
10 |
+
import pytorch_lightning as pl
|
11 |
+
from torch import distributed as dist
|
12 |
+
|
13 |
+
|
14 |
+
class talkNet(pl.LightningModule):
|
15 |
+
|
16 |
+
def __init__(self, cfg):
|
17 |
+
super(talkNet, self).__init__()
|
18 |
+
self.model = talkNetModel().cuda()
|
19 |
+
self.cfg = cfg
|
20 |
+
self.lossAV = lossAV().cuda()
|
21 |
+
self.lossA = lossA().cuda()
|
22 |
+
self.lossV = lossV().cuda()
|
23 |
+
print(
|
24 |
+
time.strftime("%m-%d %H:%M:%S") + " Model para number = %.2f" %
|
25 |
+
(sum(param.numel() for param in self.model.parameters()) / 1024 / 1024))
|
26 |
+
|
27 |
+
def configure_optimizers(self):
|
28 |
+
optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.SOLVER.BASE_LR)
|
29 |
+
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
|
30 |
+
step_size=1,
|
31 |
+
gamma=self.cfg.SOLVER.SCHEDULER.GAMMA)
|
32 |
+
return {"optimizer": optimizer, "lr_scheduler": scheduler}
|
33 |
+
|
34 |
+
def training_step(self, batch, batch_idx):
|
35 |
+
audioFeature, visualFeature, labels, masks = batch
|
36 |
+
b, s, t = visualFeature.shape[0], visualFeature.shape[1], visualFeature.shape[2]
|
37 |
+
audioFeature = audioFeature.repeat(1, s, 1, 1)
|
38 |
+
audioFeature = audioFeature.view(b * s, *audioFeature.shape[2:])
|
39 |
+
visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
|
40 |
+
labels = labels.view(b * s, *labels.shape[2:])
|
41 |
+
masks = masks.view(b * s, *masks.shape[2:])
|
42 |
+
|
43 |
+
audioEmbed = self.model.forward_audio_frontend(audioFeature) # feedForward
|
44 |
+
visualEmbed = self.model.forward_visual_frontend(visualFeature)
|
45 |
+
audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
|
46 |
+
outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
|
47 |
+
outsA = self.model.forward_audio_backend(audioEmbed)
|
48 |
+
outsV = self.model.forward_visual_backend(visualEmbed)
|
49 |
+
labels = labels.reshape((-1))
|
50 |
+
nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels, masks)
|
51 |
+
nlossA = self.lossA.forward(outsA, labels, masks)
|
52 |
+
nlossV = self.lossV.forward(outsV, labels, masks)
|
53 |
+
loss = nlossAV + 0.4 * nlossA + 0.4 * nlossV
|
54 |
+
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
|
55 |
+
return loss
|
56 |
+
|
57 |
+
def training_epoch_end(self, training_step_outputs):
|
58 |
+
self.saveParameters(
|
59 |
+
os.path.join(self.cfg.WORKSPACE, "model", "{}.pth".format(self.current_epoch)))
|
60 |
+
|
61 |
+
def evaluate_network(self, loader):
|
62 |
+
self.eval()
|
63 |
+
predScores = []
|
64 |
+
self.model = self.model.cuda()
|
65 |
+
self.lossAV = self.lossAV.cuda()
|
66 |
+
self.lossA = self.lossA.cuda()
|
67 |
+
self.lossV = self.lossV.cuda()
|
68 |
+
evalCsvSave = self.cfg.evalCsvSave
|
69 |
+
evalOrig = self.cfg.evalOrig
|
70 |
+
for audioFeature, visualFeature, labels, masks in tqdm.tqdm(loader):
|
71 |
+
with torch.no_grad():
|
72 |
+
b, s = visualFeature.shape[0], visualFeature.shape[1]
|
73 |
+
t = visualFeature.shape[2]
|
74 |
+
audioFeature = audioFeature.repeat(1, s, 1, 1)
|
75 |
+
audioFeature = audioFeature.view(b * s, *audioFeature.shape[2:])
|
76 |
+
visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
|
77 |
+
labels = labels.view(b * s, *labels.shape[2:])
|
78 |
+
masks = masks.view(b * s, *masks.shape[2:])
|
79 |
+
audioEmbed = self.model.forward_audio_frontend(audioFeature.cuda())
|
80 |
+
visualEmbed = self.model.forward_visual_frontend(visualFeature.cuda())
|
81 |
+
audioEmbed, visualEmbed = self.model.forward_cross_attention(
|
82 |
+
audioEmbed, visualEmbed)
|
83 |
+
outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
|
84 |
+
labels = labels.reshape((-1)).cuda()
|
85 |
+
outsAV = outsAV.view(b, s, t, -1)[:, 0, :, :].view(b * t, -1)
|
86 |
+
labels = labels.view(b, s, t)[:, 0, :].view(b * t)
|
87 |
+
masks = masks.view(b, s, t)[:, 0, :].view(b * t)
|
88 |
+
_, predScore, _, _ = self.lossAV.forward(outsAV, labels, masks)
|
89 |
+
predScore = predScore.detach().cpu().numpy()
|
90 |
+
predScores.extend(predScore)
|
91 |
+
evalLines = open(evalOrig).read().splitlines()[1:]
|
92 |
+
labels = []
|
93 |
+
labels = pandas.Series(['SPEAKING_AUDIBLE' for line in evalLines])
|
94 |
+
scores = pandas.Series(predScores)
|
95 |
+
evalRes = pandas.read_csv(evalOrig)
|
96 |
+
evalRes['score'] = scores
|
97 |
+
evalRes['label'] = labels
|
98 |
+
evalRes.drop(['label_id'], axis=1, inplace=True)
|
99 |
+
evalRes.drop(['instance_id'], axis=1, inplace=True)
|
100 |
+
evalRes.to_csv(evalCsvSave, index=False)
|
101 |
+
cmd = "python -O utils/get_ava_active_speaker_performance.py -g %s -p %s " % (evalOrig,
|
102 |
+
evalCsvSave)
|
103 |
+
mAP = float(
|
104 |
+
str(subprocess.run(cmd, shell=True, capture_output=True).stdout).split(' ')[2][:5])
|
105 |
+
return mAP
|
106 |
+
|
107 |
+
def saveParameters(self, path):
|
108 |
+
torch.save(self.state_dict(), path)
|
109 |
+
|
110 |
+
def loadParameters(self, path):
|
111 |
+
selfState = self.state_dict()
|
112 |
+
loadedState = torch.load(path)
|
113 |
+
for name, param in loadedState.items():
|
114 |
+
origName = name
|
115 |
+
if name not in selfState:
|
116 |
+
name = name.replace("module.", "")
|
117 |
+
if name not in selfState:
|
118 |
+
print("%s is not in the model." % origName)
|
119 |
+
continue
|
120 |
+
if selfState[name].size() != loadedState[origName].size():
|
121 |
+
sys.stderr.write("Wrong parameter length: %s, model: %s, loaded: %s" %
|
122 |
+
(origName, selfState[name].size(), loadedState[origName].size()))
|
123 |
+
continue
|
124 |
+
selfState[name].copy_(param)
|
legacy/talkNet_multicard.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
import sys, time, numpy, os, subprocess, pandas, tqdm
|
6 |
+
|
7 |
+
from loss import lossAV, lossA, lossV
|
8 |
+
from model.talkNetModel import talkNetModel
|
9 |
+
|
10 |
+
import pytorch_lightning as pl
|
11 |
+
from torch import distributed as dist
|
12 |
+
|
13 |
+
|
14 |
+
class talkNet(pl.LightningModule):
|
15 |
+
|
16 |
+
def __init__(self, cfg):
|
17 |
+
super(talkNet, self).__init__()
|
18 |
+
self.cfg = cfg
|
19 |
+
self.model = talkNetModel()
|
20 |
+
self.lossAV = lossAV()
|
21 |
+
self.lossA = lossA()
|
22 |
+
self.lossV = lossV()
|
23 |
+
print(
|
24 |
+
time.strftime("%m-%d %H:%M:%S") + " Model para number = %.2f" %
|
25 |
+
(sum(param.numel() for param in self.model.parameters()) / 1024 / 1024))
|
26 |
+
|
27 |
+
def configure_optimizers(self):
|
28 |
+
optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.SOLVER.BASE_LR)
|
29 |
+
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
|
30 |
+
step_size=1,
|
31 |
+
gamma=self.cfg.SOLVER.SCHEDULER.GAMMA)
|
32 |
+
return {"optimizer": optimizer, "lr_scheduler": scheduler}
|
33 |
+
|
34 |
+
def training_step(self, batch, batch_idx):
|
35 |
+
audioFeature, visualFeature, labels = batch
|
36 |
+
audioEmbed = self.model.forward_audio_frontend(audioFeature[0]) # feedForward
|
37 |
+
visualEmbed = self.model.forward_visual_frontend(visualFeature[0])
|
38 |
+
audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
|
39 |
+
outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
|
40 |
+
outsA = self.model.forward_audio_backend(audioEmbed)
|
41 |
+
outsV = self.model.forward_visual_backend(visualEmbed)
|
42 |
+
labels = labels[0].reshape((-1))
|
43 |
+
nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels)
|
44 |
+
nlossA = self.lossA.forward(outsA, labels)
|
45 |
+
nlossV = self.lossV.forward(outsV, labels)
|
46 |
+
loss = nlossAV + 0.4 * nlossA + 0.4 * nlossV
|
47 |
+
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
|
48 |
+
|
49 |
+
return loss
|
50 |
+
|
51 |
+
def training_epoch_end(self, training_step_outputs):
|
52 |
+
self.saveParameters(
|
53 |
+
os.path.join(self.cfg.WORKSPACE, "model", "{}.pth".format(self.current_epoch)))
|
54 |
+
|
55 |
+
def validation_step(self, batch, batch_idx):
|
56 |
+
audioFeature, visualFeature, labels, indices = batch
|
57 |
+
audioEmbed = self.model.forward_audio_frontend(audioFeature[0])
|
58 |
+
visualEmbed = self.model.forward_visual_frontend(visualFeature[0])
|
59 |
+
audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
|
60 |
+
outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
|
61 |
+
labels = labels[0].reshape((-1))
|
62 |
+
loss, predScore, _, _ = self.lossAV.forward(outsAV, labels)
|
63 |
+
predScore = predScore[:, -1:].detach().cpu().numpy()
|
64 |
+
# self.log("val_loss", loss)
|
65 |
+
|
66 |
+
return predScore
|
67 |
+
|
68 |
+
def validation_epoch_end(self, validation_step_outputs):
|
69 |
+
evalCsvSave = self.cfg.evalCsvSave
|
70 |
+
evalOrig = self.cfg.evalOrig
|
71 |
+
predScores = []
|
72 |
+
|
73 |
+
for out in validation_step_outputs: # batch size =1
|
74 |
+
predScores.extend(out)
|
75 |
+
|
76 |
+
evalLines = open(evalOrig).read().splitlines()[1:]
|
77 |
+
labels = []
|
78 |
+
labels = pandas.Series(['SPEAKING_AUDIBLE' for line in evalLines])
|
79 |
+
scores = pandas.Series(predScores)
|
80 |
+
evalRes = pandas.read_csv(evalOrig)
|
81 |
+
print(len(evalRes), len(predScores), len(evalLines))
|
82 |
+
evalRes['score'] = scores
|
83 |
+
evalRes['label'] = labels
|
84 |
+
evalRes.drop(['label_id'], axis=1, inplace=True)
|
85 |
+
evalRes.drop(['instance_id'], axis=1, inplace=True)
|
86 |
+
evalRes.to_csv(evalCsvSave, index=False)
|
87 |
+
cmd = "python -O utils/get_ava_active_speaker_performance.py -g %s -p %s " % (evalOrig,
|
88 |
+
evalCsvSave)
|
89 |
+
mAP = float(
|
90 |
+
str(subprocess.run(cmd, shell=True, capture_output=True).stdout).split(' ')[2][:5])
|
91 |
+
print("validation mAP: {}".format(mAP))
|
92 |
+
|
93 |
+
def saveParameters(self, path):
|
94 |
+
torch.save(self.state_dict(), path)
|
95 |
+
|
96 |
+
def loadParameters(self, path):
|
97 |
+
selfState = self.state_dict()
|
98 |
+
loadedState = torch.load(path, map_location='cpu')
|
99 |
+
for name, param in loadedState.items():
|
100 |
+
origName = name
|
101 |
+
if name not in selfState:
|
102 |
+
name = name.replace("module.", "")
|
103 |
+
if name not in selfState:
|
104 |
+
print("%s is not in the model." % origName)
|
105 |
+
continue
|
106 |
+
if selfState[name].size() != loadedState[origName].size():
|
107 |
+
sys.stderr.write("Wrong parameter length: %s, model: %s, loaded: %s" %
|
108 |
+
(origName, selfState[name].size(), loadedState[origName].size()))
|
109 |
+
continue
|
110 |
+
selfState[name].copy_(param)
|
111 |
+
|
112 |
+
def evaluate_network(self, loader):
|
113 |
+
self.eval()
|
114 |
+
self.model = self.model.cuda()
|
115 |
+
self.lossAV = self.lossAV.cuda()
|
116 |
+
self.lossA = self.lossA.cuda()
|
117 |
+
self.lossV = self.lossV.cuda()
|
118 |
+
predScores = []
|
119 |
+
evalCsvSave = self.cfg.evalCsvSave
|
120 |
+
evalOrig = self.cfg.evalOrig
|
121 |
+
for audioFeature, visualFeature, labels in tqdm.tqdm(loader):
|
122 |
+
with torch.no_grad():
|
123 |
+
audioEmbed = self.model.forward_audio_frontend(audioFeature[0].cuda())
|
124 |
+
visualEmbed = self.model.forward_visual_frontend(visualFeature[0].cuda())
|
125 |
+
audioEmbed, visualEmbed = self.model.forward_cross_attention(
|
126 |
+
audioEmbed, visualEmbed)
|
127 |
+
outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
|
128 |
+
labels = labels[0].reshape((-1)).cuda()
|
129 |
+
_, predScore, _, _ = self.lossAV.forward(outsAV, labels)
|
130 |
+
predScore = predScore[:, 1].detach().cpu().numpy()
|
131 |
+
predScores.extend(predScore)
|
132 |
+
evalLines = open(evalOrig).read().splitlines()[1:]
|
133 |
+
labels = []
|
134 |
+
labels = pandas.Series(['SPEAKING_AUDIBLE' for line in evalLines])
|
135 |
+
scores = pandas.Series(predScores)
|
136 |
+
evalRes = pandas.read_csv(evalOrig)
|
137 |
+
evalRes['score'] = scores
|
138 |
+
evalRes['label'] = labels
|
139 |
+
evalRes.drop(['label_id'], axis=1, inplace=True)
|
140 |
+
evalRes.drop(['instance_id'], axis=1, inplace=True)
|
141 |
+
evalRes.to_csv(evalCsvSave, index=False)
|
142 |
+
cmd = "python -O utils/get_ava_active_speaker_performance.py -g %s -p %s " % (evalOrig,
|
143 |
+
evalCsvSave)
|
144 |
+
mAP = float(
|
145 |
+
str(subprocess.run(cmd, shell=True, capture_output=True).stdout).split(' ')[2][:5])
|
146 |
+
return mAP
|
legacy/talkNet_orig.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
import sys, time, numpy, os, subprocess, pandas, tqdm
|
6 |
+
|
7 |
+
from loss import lossAV, lossA, lossV
|
8 |
+
from model.talkNetModel import talkNetModel
|
9 |
+
|
10 |
+
|
11 |
+
class talkNet(nn.Module):
|
12 |
+
|
13 |
+
def __init__(self, lr=0.0001, lrDecay=0.95, **kwargs):
|
14 |
+
super(talkNet, self).__init__()
|
15 |
+
self.model = talkNetModel().cuda()
|
16 |
+
self.lossAV = lossAV().cuda()
|
17 |
+
self.lossA = lossA().cuda()
|
18 |
+
self.lossV = lossV().cuda()
|
19 |
+
self.optim = torch.optim.Adam(self.parameters(), lr=lr)
|
20 |
+
self.scheduler = torch.optim.lr_scheduler.StepLR(self.optim, step_size=1, gamma=lrDecay)
|
21 |
+
print(
|
22 |
+
time.strftime("%m-%d %H:%M:%S") + " Model para number = %.2f" %
|
23 |
+
(sum(param.numel() for param in self.model.parameters()) / 1024 / 1024))
|
24 |
+
|
25 |
+
def train_network(self, loader, epoch, **kwargs):
|
26 |
+
self.train()
|
27 |
+
self.scheduler.step(epoch - 1)
|
28 |
+
index, top1, loss = 0, 0, 0
|
29 |
+
lr = self.optim.param_groups[0]['lr']
|
30 |
+
for num, (audioFeature, visualFeature, labels) in enumerate(loader, start=1):
|
31 |
+
self.zero_grad()
|
32 |
+
audioEmbed = self.model.forward_audio_frontend(audioFeature[0].cuda()) # feedForward
|
33 |
+
visualEmbed = self.model.forward_visual_frontend(visualFeature[0].cuda())
|
34 |
+
audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
|
35 |
+
outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
|
36 |
+
outsA = self.model.forward_audio_backend(audioEmbed)
|
37 |
+
outsV = self.model.forward_visual_backend(visualEmbed)
|
38 |
+
labels = labels[0].reshape((-1)).cuda() # Loss
|
39 |
+
nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels)
|
40 |
+
nlossA = self.lossA.forward(outsA, labels)
|
41 |
+
nlossV = self.lossV.forward(outsV, labels)
|
42 |
+
nloss = nlossAV + 0.4 * nlossA + 0.4 * nlossV
|
43 |
+
loss += nloss.detach().cpu().numpy()
|
44 |
+
top1 += prec
|
45 |
+
nloss.backward()
|
46 |
+
self.optim.step()
|
47 |
+
index += len(labels)
|
48 |
+
sys.stderr.write(time.strftime("%m-%d %H:%M:%S") + \
|
49 |
+
" [%2d] Lr: %5f, Training: %.2f%%, " %(epoch, lr, 100 * (num / loader.__len__())) + \
|
50 |
+
" Loss: %.5f, ACC: %2.2f%% \r" %(loss/(num), 100 * (top1/index)))
|
51 |
+
sys.stderr.flush()
|
52 |
+
sys.stdout.write("\n")
|
53 |
+
return loss / num, lr
|
54 |
+
|
55 |
+
def evaluate_network(self, loader, evalCsvSave, evalOrig, **kwargs):
|
56 |
+
self.eval()
|
57 |
+
predScores = []
|
58 |
+
for audioFeature, visualFeature, labels in tqdm.tqdm(loader):
|
59 |
+
with torch.no_grad():
|
60 |
+
audioEmbed = self.model.forward_audio_frontend(audioFeature[0].cuda())
|
61 |
+
visualEmbed = self.model.forward_visual_frontend(visualFeature[0].cuda())
|
62 |
+
audioEmbed, visualEmbed = self.model.forward_cross_attention(
|
63 |
+
audioEmbed, visualEmbed)
|
64 |
+
outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
|
65 |
+
labels = labels[0].reshape((-1)).cuda()
|
66 |
+
_, predScore, _, _ = self.lossAV.forward(outsAV, labels)
|
67 |
+
predScore = predScore[:, 1].detach().cpu().numpy()
|
68 |
+
predScores.extend(predScore)
|
69 |
+
evalLines = open(evalOrig).read().splitlines()[1:]
|
70 |
+
labels = []
|
71 |
+
labels = pandas.Series(['SPEAKING_AUDIBLE' for line in evalLines])
|
72 |
+
scores = pandas.Series(predScores)
|
73 |
+
evalRes = pandas.read_csv(evalOrig)
|
74 |
+
evalRes['score'] = scores
|
75 |
+
evalRes['label'] = labels
|
76 |
+
evalRes.drop(['label_id'], axis=1, inplace=True)
|
77 |
+
evalRes.drop(['instance_id'], axis=1, inplace=True)
|
78 |
+
evalRes.to_csv(evalCsvSave, index=False)
|
79 |
+
cmd = "python -O utils/get_ava_active_speaker_performance.py -g %s -p %s " % (evalOrig,
|
80 |
+
evalCsvSave)
|
81 |
+
mAP = float(
|
82 |
+
str(subprocess.run(cmd, shell=True, capture_output=True).stdout).split(' ')[2][:5])
|
83 |
+
return mAP
|
84 |
+
|
85 |
+
def saveParameters(self, path):
|
86 |
+
torch.save(self.state_dict(), path)
|
87 |
+
|
88 |
+
def loadParameters(self, path):
|
89 |
+
selfState = self.state_dict()
|
90 |
+
loadedState = torch.load(path)
|
91 |
+
for name, param in loadedState.items():
|
92 |
+
origName = name
|
93 |
+
if name not in selfState:
|
94 |
+
name = name.replace("module.", "")
|
95 |
+
if name not in selfState:
|
96 |
+
print("%s is not in the model." % origName)
|
97 |
+
continue
|
98 |
+
if selfState[name].size() != loadedState[origName].size():
|
99 |
+
sys.stderr.write("Wrong parameter length: %s, model: %s, loaded: %s" %
|
100 |
+
(origName, selfState[name].size(), loadedState[origName].size()))
|
101 |
+
continue
|
102 |
+
selfState[name].copy_(param)
|
legacy/trainTalkNet_multicard.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time, os, torch, argparse, warnings, glob
|
2 |
+
|
3 |
+
from utils.tools import *
|
4 |
+
from dlhammer import bootstrap
|
5 |
+
import pytorch_lightning as pl
|
6 |
+
from pytorch_lightning import Trainer, seed_everything
|
7 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
8 |
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
9 |
+
|
10 |
+
|
11 |
+
class MyCollator(object):
|
12 |
+
|
13 |
+
def __init__(self, cfg):
|
14 |
+
self.cfg = cfg
|
15 |
+
|
16 |
+
def __call__(self, data):
|
17 |
+
audiofeatures = [item[0] for item in data]
|
18 |
+
visualfeatures = [item[1] for item in data]
|
19 |
+
labels = [item[2] for item in data]
|
20 |
+
masks = [item[3] for item in data]
|
21 |
+
cut_limit = self.cfg.MODEL.CLIP_LENGTH
|
22 |
+
# pad audio
|
23 |
+
lengths = torch.tensor([t.shape[1] for t in audiofeatures])
|
24 |
+
max_len = max(lengths)
|
25 |
+
padded_audio = torch.stack([
|
26 |
+
torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1], i.shape[2]))], 1)
|
27 |
+
for i in audiofeatures
|
28 |
+
], 0)
|
29 |
+
|
30 |
+
if max_len > cut_limit * 4:
|
31 |
+
padded_audio = padded_audio[:, :, :cut_limit * 4, ...]
|
32 |
+
|
33 |
+
# pad video
|
34 |
+
lengths = torch.tensor([t.shape[1] for t in visualfeatures])
|
35 |
+
max_len = max(lengths)
|
36 |
+
padded_video = torch.stack([
|
37 |
+
torch.cat(
|
38 |
+
[i, i.new_zeros((i.shape[0], max_len - i.shape[1], i.shape[2], i.shape[3]))], 1)
|
39 |
+
for i in visualfeatures
|
40 |
+
], 0)
|
41 |
+
padded_labels = torch.stack(
|
42 |
+
[torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1]))], 1) for i in labels], 0)
|
43 |
+
padded_masks = torch.stack(
|
44 |
+
[torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1]))], 1) for i in masks], 0)
|
45 |
+
|
46 |
+
if max_len > cut_limit:
|
47 |
+
padded_video = padded_video[:, :, :cut_limit, ...]
|
48 |
+
padded_labels = padded_labels[:, :, :cut_limit, ...]
|
49 |
+
padded_masks = padded_masks[:, :, :cut_limit, ...]
|
50 |
+
return padded_audio, padded_video, padded_labels, padded_masks
|
51 |
+
|
52 |
+
|
53 |
+
class DataPrep(pl.LightningDataModule):
|
54 |
+
|
55 |
+
def __init__(self, cfg):
|
56 |
+
self.cfg = cfg
|
57 |
+
|
58 |
+
def train_dataloader(self):
|
59 |
+
cfg = self.cfg
|
60 |
+
|
61 |
+
if self.cfg.MODEL.NAME == "baseline":
|
62 |
+
from dataLoader import train_loader, val_loader
|
63 |
+
loader = train_loader(trialFileName = cfg.trainTrialAVA, \
|
64 |
+
audioPath = os.path.join(cfg.audioPathAVA , 'train'), \
|
65 |
+
visualPath = os.path.join(cfg.visualPathAVA, 'train'), \
|
66 |
+
batchSize=2500
|
67 |
+
)
|
68 |
+
elif self.cfg.MODEL.NAME == "multi":
|
69 |
+
from dataLoader_multiperson import train_loader, val_loader
|
70 |
+
loader = train_loader(trialFileName = cfg.trainTrialAVA, \
|
71 |
+
audioPath = os.path.join(cfg.audioPathAVA , 'train'), \
|
72 |
+
visualPath = os.path.join(cfg.visualPathAVA, 'train'), \
|
73 |
+
num_speakers=cfg.MODEL.NUM_SPEAKERS,
|
74 |
+
)
|
75 |
+
if cfg.MODEL.NAME == "baseline":
|
76 |
+
trainLoader = torch.utils.data.DataLoader(
|
77 |
+
loader,
|
78 |
+
batch_size=1,
|
79 |
+
shuffle=True,
|
80 |
+
num_workers=4,
|
81 |
+
)
|
82 |
+
elif cfg.MODEL.NAME == "multi":
|
83 |
+
collator = MyCollator(cfg)
|
84 |
+
trainLoader = torch.utils.data.DataLoader(loader,
|
85 |
+
batch_size=1,
|
86 |
+
shuffle=True,
|
87 |
+
num_workers=4,
|
88 |
+
collate_fn=collator)
|
89 |
+
|
90 |
+
return trainLoader
|
91 |
+
|
92 |
+
def val_dataloader(self):
|
93 |
+
cfg = self.cfg
|
94 |
+
loader = val_loader(trialFileName = cfg.evalTrialAVA, \
|
95 |
+
audioPath = os.path.join(cfg.audioPathAVA , cfg.evalDataType), \
|
96 |
+
visualPath = os.path.join(cfg.visualPathAVA, cfg.evalDataType), \
|
97 |
+
)
|
98 |
+
valLoader = torch.utils.data.DataLoader(loader,
|
99 |
+
batch_size=cfg.VAL.BATCH_SIZE,
|
100 |
+
shuffle=False,
|
101 |
+
num_workers=16)
|
102 |
+
return valLoader
|
103 |
+
|
104 |
+
|
105 |
+
def main():
|
106 |
+
# The structure of this code is learnt from https://github.com/clovaai/voxceleb_trainer
|
107 |
+
cfg = bootstrap(print_cfg=False)
|
108 |
+
print(cfg)
|
109 |
+
|
110 |
+
warnings.filterwarnings("ignore")
|
111 |
+
seed_everything(42, workers=True)
|
112 |
+
|
113 |
+
cfg = init_args(cfg)
|
114 |
+
|
115 |
+
# checkpoint_callback = ModelCheckpoint(dirpath=os.path.join(cfg.WORKSPACE, "model"),
|
116 |
+
# save_top_k=-1,
|
117 |
+
# filename='{epoch}')
|
118 |
+
|
119 |
+
data = DataPrep(cfg)
|
120 |
+
|
121 |
+
trainer = Trainer(
|
122 |
+
gpus=int(cfg.TRAIN.TRAINER_GPU),
|
123 |
+
precision=32,
|
124 |
+
# callbacks=[checkpoint_callback],
|
125 |
+
max_epochs=25,
|
126 |
+
replace_sampler_ddp=True)
|
127 |
+
# val_trainer = Trainer(deterministic=True, num_sanity_val_steps=-1, gpus=1)
|
128 |
+
if cfg.downloadAVA == True:
|
129 |
+
preprocess_AVA(cfg)
|
130 |
+
quit()
|
131 |
+
|
132 |
+
# if cfg.RESUME:
|
133 |
+
# modelfiles = glob.glob('%s/model_0*.model' % cfg.modelSavePath)
|
134 |
+
# modelfiles.sort()
|
135 |
+
# if len(modelfiles) >= 1:
|
136 |
+
# print("Model %s loaded from previous state!" % modelfiles[-1])
|
137 |
+
# epoch = int(os.path.splitext(os.path.basename(modelfiles[-1]))[0][6:]) + 1
|
138 |
+
# s = talkNet(cfg)
|
139 |
+
# s.loadParameters(modelfiles[-1])
|
140 |
+
# else:
|
141 |
+
# epoch = 1
|
142 |
+
# s = talkNet(cfg)
|
143 |
+
epoch = 1
|
144 |
+
if cfg.MODEL.NAME == "baseline":
|
145 |
+
from talkNet_multicard import talkNet
|
146 |
+
elif cfg.MODEL.NAME == "multi":
|
147 |
+
from talkNet_multi import talkNet
|
148 |
+
|
149 |
+
s = talkNet(cfg)
|
150 |
+
|
151 |
+
# scoreFile = open(cfg.scoreSavePath, "a+")
|
152 |
+
|
153 |
+
trainer.fit(s, train_dataloaders=data.train_dataloader())
|
154 |
+
|
155 |
+
modelfiles = glob.glob('%s/*.pth' % os.path.join(cfg.WORKSPACE, "model"))
|
156 |
+
|
157 |
+
modelfiles.sort()
|
158 |
+
for path in modelfiles:
|
159 |
+
s.loadParameters(path)
|
160 |
+
prec = trainer.validate(s, data.val_dataloader())
|
161 |
+
|
162 |
+
# if epoch % cfg.testInterval == 0:
|
163 |
+
# s.saveParameters(cfg.modelSavePath + "/model_%04d.model" % epoch)
|
164 |
+
# trainer.validate(dataloaders=valLoader)
|
165 |
+
# print(time.strftime("%Y-%m-%d %H:%M:%S"), "%d epoch, mAP %2.2f%%" % (epoch, mAPs[-1]))
|
166 |
+
# scoreFile.write("%d epoch, LOSS %f, mAP %2.2f%%\n" % (epoch, loss, mAPs[-1]))
|
167 |
+
# scoreFile.flush()
|
168 |
+
|
169 |
+
|
170 |
+
if __name__ == '__main__':
|
171 |
+
main()
|
legacy/train_multi.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import time, os, torch, argparse, warnings, glob

from dataLoader_multiperson import train_loader, val_loader
from utils.tools import *
from talkNet_multi import talkNet


def collate_fn_padding(data):
    audiofeatures = [item[0] for item in data]
    visualfeatures = [item[1] for item in data]
    labels = [item[2] for item in data]
    masks = [item[3] for item in data]
    cut_limit = 200
    # pad audio
    lengths = torch.tensor([t.shape[1] for t in audiofeatures])
    max_len = max(lengths)
    padded_audio = torch.stack([
        torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1], i.shape[2]))], 1)
        for i in audiofeatures
    ], 0)

    if max_len > cut_limit * 4:
        padded_audio = padded_audio[:, :, :cut_limit * 4, ...]

    # pad video
    lengths = torch.tensor([t.shape[1] for t in visualfeatures])
    max_len = max(lengths)
    padded_video = torch.stack([
        torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1], i.shape[2], i.shape[3]))], 1)
        for i in visualfeatures
    ], 0)
    padded_labels = torch.stack(
        [torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1]))], 1) for i in labels], 0)
    padded_masks = torch.stack(
        [torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1]))], 1) for i in masks], 0)

    if max_len > cut_limit:
        padded_video = padded_video[:, :, :cut_limit, ...]
        padded_labels = padded_labels[:, :, :cut_limit, ...]
        padded_masks = padded_masks[:, :, :cut_limit, ...]
    # print(padded_audio.shape, padded_video.shape, padded_labels.shape, padded_masks.shape)
    return padded_audio, padded_video, padded_labels, padded_masks


def main():
    # The structure of this code is learnt from https://github.com/clovaai/voxceleb_trainer
    warnings.filterwarnings("ignore")

    parser = argparse.ArgumentParser(description="TalkNet Training")
    # Training details
    parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
    parser.add_argument('--lrDecay', type=float, default=0.95, help='Learning rate decay rate')
    parser.add_argument('--maxEpoch', type=int, default=25, help='Maximum number of epochs')
    parser.add_argument('--testInterval',
                        type=int,
                        default=1,
                        help='Test and save every [testInterval] epochs')
    parser.add_argument(
        '--batchSize',
        type=int,
        default=2500,
        help=
        'Dynamic batch size, default is 2500 frames, other batchsize (such as 1500) will not affect the performance'
    )
    parser.add_argument('--batch_size', type=int, default=1, help='batch_size')
    parser.add_argument('--num_speakers', type=int, default=5, help='num_speakers')
    parser.add_argument('--nDataLoaderThread', type=int, default=4, help='Number of loader threads')
    # Data path
    parser.add_argument('--dataPathAVA',
                        type=str,
                        default="/data08/AVA",
                        help='Save path of AVA dataset')
    parser.add_argument('--savePath', type=str, default="exps/exp1")
    # Data selection
    parser.add_argument('--evalDataType',
                        type=str,
                        default="val",
                        help='Only for AVA, to choose the dataset for evaluation, val or test')
    # For download dataset only, for evaluation only
    parser.add_argument('--downloadAVA',
                        dest='downloadAVA',
                        action='store_true',
                        help='Only download AVA dataset and do related preprocess')
    parser.add_argument('--evaluation',
                        dest='evaluation',
                        action='store_true',
                        help='Only do evaluation by using pretrained model [pretrain_AVA.model]')
    args = parser.parse_args()
    # Data loader
    args = init_args(args)

    if args.downloadAVA == True:
        preprocess_AVA(args)
        quit()

    loader = train_loader(trialFileName = args.trainTrialAVA, \
                          audioPath = os.path.join(args.audioPathAVA , 'train'), \
                          visualPath = os.path.join(args.visualPathAVA, 'train'), \
                          # num_speakers = args.num_speakers, \
                          **vars(args))
    trainLoader = torch.utils.data.DataLoader(loader,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.nDataLoaderThread,
                                              collate_fn=collate_fn_padding)

    loader = val_loader(trialFileName = args.evalTrialAVA, \
                        audioPath = os.path.join(args.audioPathAVA , args.evalDataType), \
                        visualPath = os.path.join(args.visualPathAVA, args.evalDataType), \
                        # num_speakers = args.num_speakers, \
                        **vars(args))
    valLoader = torch.utils.data.DataLoader(loader, batch_size=1, shuffle=False, num_workers=16)

    if args.evaluation == True:
        download_pretrain_model_AVA()
        s = talkNet(**vars(args))
        s.loadParameters('pretrain_AVA.model')
        print("Model %s loaded from previous state!" % ('pretrain_AVA.model'))
        mAP = s.evaluate_network(loader=valLoader, **vars(args))
        print("mAP %2.2f%%" % (mAP))
        quit()

    modelfiles = glob.glob('%s/model_0*.model' % args.modelSavePath)
    modelfiles.sort()
    if len(modelfiles) >= 1:
        print("Model %s loaded from previous state!" % modelfiles[-1])
        epoch = int(os.path.splitext(os.path.basename(modelfiles[-1]))[0][6:]) + 1
        s = talkNet(epoch=epoch, **vars(args))
        s.loadParameters(modelfiles[-1])
    else:
        epoch = 1
        s = talkNet(epoch=epoch, **vars(args))

    mAPs = []
    scoreFile = open(args.scoreSavePath, "a+")

    while (1):
        loss, lr = s.train_network(epoch=epoch, loader=trainLoader, **vars(args))

        if epoch % args.testInterval == 0:
            s.saveParameters(args.modelSavePath + "/model_%04d.model" % epoch)
            mAPs.append(s.evaluate_network(epoch=epoch, loader=valLoader, **vars(args)))
            print(time.strftime("%Y-%m-%d %H:%M:%S"),
                  "%d epoch, mAP %2.2f%%, bestmAP %2.2f%%" % (epoch, mAPs[-1], max(mAPs)))
            scoreFile.write("%d epoch, LR %f, LOSS %f, mAP %2.2f%%, bestmAP %2.2f%%\n" %
                            (epoch, lr, loss, mAPs[-1], max(mAPs)))
            scoreFile.flush()

        if epoch >= args.maxEpoch:
            quit()

        epoch += 1


if __name__ == '__main__':
    main()

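For reference, a minimal standalone sketch (not part of the repo) of the padding pattern used by `collate_fn_padding` above, assuming per-clip label/mask tensors shaped `(num_speakers, T)` as implied by the function: each clip is zero-padded along the time axis to the longest clip in the batch and then stacked.

```python
# Standalone illustration of collate_fn_padding's label/mask padding; only torch is required.
import torch

def pad_time_dim(tensors):
    # pad every clip along the time axis to the longest clip in the batch
    max_len = max(t.shape[1] for t in tensors)
    return torch.stack(
        [torch.cat([t, t.new_zeros((t.shape[0], max_len - t.shape[1]))], 1) for t in tensors], 0)

labels_a = torch.zeros(3, 120)    # 3 speakers, 120 frames
labels_b = torch.ones(3, 80)      # 3 speakers, 80 frames
batch = pad_time_dim([labels_a, labels_b])
print(batch.shape)                # torch.Size([2, 3, 120])
```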
loconet.py
ADDED
@@ -0,0 +1,182 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import sys, time, numpy, os, subprocess, pandas, tqdm

from loss_multi import lossAV, lossA, lossV
from model.loconet_encoder import locoencoder

import torch.distributed as dist
from xxlib.utils.distributed import all_gather, all_reduce


class Loconet(nn.Module):

    def __init__(self, cfg):
        super(Loconet, self).__init__()
        self.cfg = cfg
        self.model = locoencoder(cfg)
        self.lossAV = lossAV()
        self.lossA = lossA()
        self.lossV = lossV()

    def forward(self, audioFeature, visualFeature, labels, masks):
        b, s, t = visualFeature.shape[:3]
        visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
        labels = labels.view(b * s, *labels.shape[2:])
        masks = masks.view(b * s, *masks.shape[2:])

        audioEmbed = self.model.forward_audio_frontend(audioFeature)    # B, C, T, 4
        visualEmbed = self.model.forward_visual_frontend(visualFeature)
        audioEmbed = audioEmbed.repeat(s, 1, 1)

        audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
        outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed, b, s)
        outsA = self.model.forward_audio_backend(audioEmbed)
        outsV = self.model.forward_visual_backend(visualEmbed)

        labels = labels.reshape((-1))
        masks = masks.reshape((-1))
        nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels, masks)
        nlossA = self.lossA.forward(outsA, labels, masks)
        nlossV = self.lossV.forward(outsV, labels, masks)

        nloss = nlossAV + 0.4 * nlossA + 0.4 * nlossV

        num_frames = masks.sum()
        return nloss, prec, num_frames


class loconet(nn.Module):

    def __init__(self, cfg, rank=None, device=None):
        super(loconet, self).__init__()
        self.cfg = cfg
        self.rank = rank
        if rank != None:
            self.rank = rank
            self.device = device

            self.model = Loconet(cfg).to(device)
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            self.model = nn.parallel.DistributedDataParallel(self.model,
                                                             device_ids=[rank],
                                                             output_device=rank,
                                                             find_unused_parameters=False)
            self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.SOLVER.BASE_LR)
            self.scheduler = torch.optim.lr_scheduler.StepLR(self.optim,
                                                             step_size=1,
                                                             gamma=self.cfg.SOLVER.SCHEDULER.GAMMA)
        else:
            self.model = locoencoder(cfg).cuda()
            self.lossAV = lossAV().cuda()
            self.lossA = lossA().cuda()
            self.lossV = lossV().cuda()

        print(
            time.strftime("%m-%d %H:%M:%S") + " Model para number = %.2f" %
            (sum(param.numel() for param in self.model.parameters()) / 1024 / 1024))

    def train_network(self, epoch, loader):
        self.model.train()
        self.scheduler.step(epoch - 1)
        index, top1, loss = 0, 0, 0
        lr = self.optim.param_groups[0]['lr']
        loader.sampler.set_epoch(epoch)
        device = self.device

        pbar = enumerate(loader, start=1)
        if self.rank == 0:
            pbar = tqdm.tqdm(pbar, total=loader.__len__())

        for num, (audioFeature, visualFeature, labels, masks) in pbar:

            audioFeature = audioFeature.to(device)
            visualFeature = visualFeature.to(device)
            labels = labels.to(device)
            masks = masks.to(device)
            nloss, prec, num_frames = self.model(
                audioFeature,
                visualFeature,
                labels,
                masks,
            )

            self.optim.zero_grad()
            nloss.backward()
            self.optim.step()

            [nloss, prec, num_frames] = all_reduce([nloss, prec, num_frames], average=False)
            top1 += prec.detach().cpu().numpy()
            loss += nloss.detach().cpu().numpy()
            index += int(num_frames.detach().cpu().item())
            if self.rank == 0:
                pbar.set_postfix(
                    dict(epoch=epoch,
                         lr=lr,
                         loss=loss / (num * self.cfg.NUM_GPUS),
                         acc=(top1 / index)))
        dist.barrier()
        return loss / num, lr

    def evaluate_network(self, epoch, loader):
        self.eval()
        predScores = []
        evalCsvSave = os.path.join(self.cfg.WORKSPACE, "{}_res.csv".format(epoch))
        evalOrig = self.cfg.evalOrig
        for audioFeature, visualFeature, labels, masks in tqdm.tqdm(loader):
            with torch.no_grad():
                audioFeature = audioFeature.cuda()
                visualFeature = visualFeature.cuda()
                labels = labels.cuda()
                masks = masks.cuda()
                b, s, t = visualFeature.shape[0], visualFeature.shape[1], visualFeature.shape[2]
                visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
                labels = labels.view(b * s, *labels.shape[2:])
                masks = masks.view(b * s, *masks.shape[2:])
                audioEmbed = self.model.forward_audio_frontend(audioFeature)
                visualEmbed = self.model.forward_visual_frontend(visualFeature)
                audioEmbed = audioEmbed.repeat(s, 1, 1)
                audioEmbed, visualEmbed = self.model.forward_cross_attention(
                    audioEmbed, visualEmbed)
                outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed, b, s)
                labels = labels.reshape((-1))
                masks = masks.reshape((-1))
                outsAV = outsAV.view(b, s, t, -1)[:, 0, :, :].view(b * t, -1)
                labels = labels.view(b, s, t)[:, 0, :].view(b * t).cuda()
                masks = masks.view(b, s, t)[:, 0, :].view(b * t)
                _, predScore, _, _ = self.lossAV.forward(outsAV, labels, masks)
                predScore = predScore[:, 1].detach().cpu().numpy()
                predScores.extend(predScore)
        evalLines = open(evalOrig).read().splitlines()[1:]
        labels = []
        labels = pandas.Series(['SPEAKING_AUDIBLE' for line in evalLines])
        scores = pandas.Series(predScores)
        evalRes = pandas.read_csv(evalOrig)
        evalRes['score'] = scores
        evalRes['label'] = labels
        evalRes.drop(['label_id'], axis=1, inplace=True)
        evalRes.drop(['instance_id'], axis=1, inplace=True)
        evalRes.to_csv(evalCsvSave, index=False)
        cmd = "python -O utils/get_ava_active_speaker_performance.py -g %s -p %s " % (evalOrig,
                                                                                      evalCsvSave)
        mAP = float(
            str(subprocess.run(cmd, shell=True, capture_output=True).stdout).split(' ')[2][:5])
        return mAP

    def saveParameters(self, path):
        torch.save(self.state_dict(), path)

    def loadParameters(self, path):
        selfState = self.state_dict()
        loadedState = torch.load(path, map_location='cpu')
        if self.rank != None:
            info = self.load_state_dict(loadedState)
        else:
            new_state = {}

            for k, v in loadedState.items():
                new_state[k.replace("model.module.", "")] = v
            info = self.load_state_dict(new_state, strict=False)
        print(info)

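A side note on `loadParameters` above: checkpoints saved from the DDP-wrapped model carry a `model.module.` prefix in their state-dict keys, which is stripped before loading into a non-distributed copy. A minimal standalone sketch of that remapping, with made-up key names (the keys below are illustrative only, not the real state-dict layout):

```python
# Hypothetical state-dict keys; only the "model.module." stripping mirrors loconet.loadParameters.
import torch

ddp_state = {
    "model.module.visualFrontend.conv1.weight": torch.randn(4, 4),
    "lossAV.FC.weight": torch.randn(2, 256),
}
single_state = {k.replace("model.module.", ""): v for k, v in ddp_state.items()}
print(list(single_state.keys()))
# ['visualFrontend.conv1.weight', 'lossAV.FC.weight']
```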
loss_multi.py
ADDED
@@ -0,0 +1,72 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils.distributed as du


class lossAV(nn.Module):

    def __init__(self):
        super(lossAV, self).__init__()
        self.criterion = nn.CrossEntropyLoss(reduction='none')
        self.FC = nn.Linear(256, 2)

    def forward(self, x, labels=None, masks=None):
        x = x.squeeze(1)
        x = self.FC(x)
        if labels == None:
            predScore = x[:, 1]
            predScore = predScore.t()
            predScore = predScore.view(-1).detach().cpu().numpy()
            return predScore
        else:
            nloss = self.criterion(x, labels) * masks

            num_valid = masks.sum().float()
            if self.training:
                [num_valid] = du.all_reduce([num_valid], average=True)
            nloss = torch.sum(nloss) / num_valid

            predScore = F.softmax(x, dim=-1)
            predLabel = torch.round(F.softmax(x, dim=-1))[:, 1]
            correctNum = ((predLabel == labels) * masks).sum().float()
            return nloss, predScore, predLabel, correctNum


class lossA(nn.Module):

    def __init__(self):
        super(lossA, self).__init__()
        self.criterion = nn.CrossEntropyLoss(reduction='none')
        self.FC = nn.Linear(128, 2)

    def forward(self, x, labels, masks=None):
        x = x.squeeze(1)
        x = self.FC(x)
        nloss = self.criterion(x, labels) * masks
        num_valid = masks.sum().float()
        if self.training:
            [num_valid] = du.all_reduce([num_valid], average=True)
        nloss = torch.sum(nloss) / num_valid
        # nloss = torch.sum(nloss) / torch.sum(masks)
        return nloss


class lossV(nn.Module):

    def __init__(self):
        super(lossV, self).__init__()

        self.criterion = nn.CrossEntropyLoss(reduction='none')
        self.FC = nn.Linear(128, 2)

    def forward(self, x, labels, masks=None):
        x = x.squeeze(1)
        x = self.FC(x)
        nloss = self.criterion(x, labels) * masks
        # nloss = torch.sum(nloss) / torch.sum(masks)
        num_valid = masks.sum().float()
        if self.training:
            [num_valid] = du.all_reduce([num_valid], average=True)
        nloss = torch.sum(nloss) / num_valid
        return nloss

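The three losses above share one idea: per-frame cross-entropy is zeroed out on padded frames via the mask and averaged over valid frames only (with an extra all-reduce of the valid count when training distributed). A minimal standalone sketch of that masked averaging on a single GPU (made-up logits, no all-reduce):

```python
# Standalone illustration of the masked cross-entropy used by lossAV/lossA/lossV.
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(reduction='none')
logits = torch.randn(6, 2)                      # 6 frames, 2 classes
labels = torch.tensor([0, 1, 1, 0, 0, 1])
masks = torch.tensor([1., 1., 1., 1., 0., 0.])  # last two frames are padding

per_frame = criterion(logits, labels) * masks   # padded frames contribute 0
loss = per_frame.sum() / masks.sum()            # average over valid frames only
print(loss)
```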
metrics/AverageMeter.py
ADDED
@@ -0,0 +1,18 @@
# taken from pytorch imagenet example
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

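A quick usage sketch (assuming the repo root is on `PYTHONPATH`, as the setup instructions require): track a running average of per-batch loss values weighted by batch size.

```python
from metrics.AverageMeter import AverageMeter

meter = AverageMeter()
meter.update(0.9, n=32)   # batch of 32 samples with mean loss 0.9
meter.update(0.7, n=32)
print(meter.avg)          # 0.8
```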
metrics/__pycache__/.nfs000000035f4a8257000000eb
ADDED
Binary file (896 Bytes)

metrics/__pycache__/AverageMeter.cpython-36.pyc
ADDED
Binary file (897 Bytes)

metrics/__pycache__/AverageMeter.cpython-38.pyc
ADDED
Binary file (908 Bytes)

metrics/__pycache__/accuracy.cpython-36.pyc
ADDED
Binary file (870 Bytes)

metrics/__pycache__/accuracy.cpython-38.pyc
ADDED
Binary file (876 Bytes)

metrics/accuracy.py
ADDED
@@ -0,0 +1,20 @@
import torch

accuracy = lambda output, target: acc_topk(output, target)[0]


# taken from pytorch imagenet example
def acc_topk(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(1.0 / batch_size))
        return res

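A quick usage sketch with made-up predictions (again assuming the repo root is on `PYTHONPATH`); note the results are fractions in [0, 1], not percentages:

```python
import torch
from metrics.accuracy import accuracy, acc_topk

output = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])  # 3 samples, 2 classes
target = torch.tensor([0, 1, 1])
print(accuracy(output, target))           # top-1 accuracy, tensor([0.6667])
print(acc_topk(output, target, (1, 2)))   # top-1 and top-2 accuracies
```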
model/.DS_Store
ADDED
Binary file (6.15 kB)

model/__init__.py
ADDED
@@ -0,0 +1,5 @@
from model.transformer.position_encoding import PositionalEncoding
from model.transformer.transformer import Transformer
from model.transformer.transformer import TransformerEncoder, TransformerEncoderLayer
from model.transformer.transformer import TransformerDecoder, TransformerDecoderLayer
from model.transformer.utils import layer_norm, generate_square_subsequent_mask, generate_proposal_mask

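Since `model/__init__.py` only re-exports the transformer building blocks, downstream code can import them from the package root; a trivial sketch (assuming the repo root is on `PYTHONPATH`):

```python
from model import Transformer, TransformerEncoder, PositionalEncoding  # re-exported above
```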
model/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (561 Bytes)

model/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (573 Bytes)

model/__pycache__/attentionLayer.cpython-37.pyc
ADDED
Binary file (1.38 kB)

model/__pycache__/convLayer.cpython-37.pyc
ADDED
Binary file (1.32 kB)

model/__pycache__/loconet_encoder.cpython-37.pyc
ADDED
Binary file (3.21 kB)

model/__pycache__/position_encoding.cpython-36.pyc
ADDED
Binary file (1.26 kB)

model/__pycache__/talkNetModel.cpython-37.pyc
ADDED
Binary file (6.33 kB)

model/__pycache__/transformer.cpython-36.pyc
ADDED
Binary file (8.84 kB)

model/__pycache__/utils.cpython-36.pyc
ADDED
Binary file (1.08 kB)