More cross-comparison of trials across datasets
script/hyperparameter_tuning.py
CHANGED

@@ -84,8 +84,9 @@ def objective(trial, hyperparam_run_dir, data_path):
         'dataset_label': dataset_label,
         'trial_number': trial.number,
         'parameters': trial.params,
-        '
-        'visualization_dir': vis_dir
+        'accuracy': val_accuracy,
+        'visualization_dir': vis_dir,
+        'trial_dir': trial_dir
     }
 
     with open(os.path.join(trial_dir, 'trial_info.json'), 'w') as f:
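trial_info.json now carries the trial's validation accuracy and output directories alongside its search parameters (trial.params), so trials can be ranked after the fact without re-running evaluation. A minimal sketch of consuming these files (the run directory matches the one used later in this commit; the ranking logic is mine, not part of the commit):

import json
from pathlib import Path

# Gather every trial_info.json under a hyperparameter run (path is illustrative)
run_dir = Path("runs_hyperparam/hyperparam_20241106_124214")
trials = []
for info_file in run_dir.glob("search_*/trial_*/trial_info.json"):
    with open(info_file) as f:
        trials.append(json.load(f))

# 'accuracy' and 'trial_dir' are the fields this commit adds
best = max(trials, key=lambda t: t["accuracy"])
print(best["trial_dir"], best["accuracy"], best["parameters"])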
script/visualization/analyze_trials.py
ADDED

@@ -0,0 +1,225 @@
+import os
+import json
+import pandas as pd
+from pathlib import Path
+from collections import defaultdict
+
+def parse_error_analysis(vis_dir):
+    """Parse the error_analysis.txt file to get accuracy and misclassification details"""
+    metrics = {}
+    class_accuracies = {}
+    misclassified_files = []
+
+    with open(os.path.join(vis_dir, 'error_analysis.txt'), 'r') as f:
+        lines = f.readlines()
+        parsing_errors = False
+        header_found = False
+
+        for line in lines:
+            # Get overall accuracy
+            if line.startswith("Overall Accuracy:"):
+                metrics['overall_accuracy'] = float(line.split(":")[1].strip().rstrip('%')) / 100
+
+            # Parse per-class accuracy
+            if "samples)" in line and ":" in line:
+                class_name = line.split(":")[0].strip()
+                accuracy = float(line.split(":")[1].split("%")[0].strip()) / 100
+                samples = int(line.split("(")[1].split(" ")[0])
+                class_accuracies[class_name] = {
+                    'accuracy': accuracy,
+                    'samples': samples
+                }
+
+            # Parse misclassified files
+            if "Misclassified Videos:" in line:
+                parsing_errors = True
+                continue
+            if "Filename" in line and "True Class" in line:
+                header_found = True
+                continue
+            if parsing_errors and header_found and line.strip() and not line.startswith("-"):
+                try:
+                    # Split the line while preserving filename with spaces
+                    parts = line.strip().split()
+                    # Find the confidence value (last element with %)
+                    confidence_idx = next(i for i, part in enumerate(parts) if part.endswith('%'))
+                    # Everything before the last three elements is the filename
+                    filename = ' '.join(parts[:confidence_idx-2])
+                    true_class = parts[confidence_idx-2]
+                    pred_class = parts[confidence_idx-1]
+                    confidence = float(parts[confidence_idx].rstrip('%')) / 100
+
+                    misclassified_files.append({
+                        'filename': filename,
+                        'true_class': true_class,
+                        'predicted_class': pred_class,
+                        'confidence': confidence
+                    })
+                except Exception as e:
+                    print(f"Warning: Could not parse line: {line.strip()}")
+                    continue
+
+    metrics['class_accuracies'] = class_accuracies
+    metrics['misclassified_files'] = misclassified_files
+    return metrics
+
+def analyze_trial(trial_dir):
+    """Analyze all visualization directories in a trial and aggregate results"""
+    trial_metrics = {
+        'overall_accuracy': 0,
+        'total_samples': 0,
+        'class_accuracies': defaultdict(lambda: {'correct': 0, 'total': 0}),
+        'misclassified_files': []
+    }
+
+    # Find all visualization directories
+    vis_dirs = [d for d in trial_dir.iterdir() if d.is_dir() and d.name.startswith('visualization_')]
+    if not vis_dirs:
+        return None
+
+    for vis_dir in vis_dirs:
+        try:
+            metrics = parse_error_analysis(vis_dir)
+
+            # Add to total samples and weighted accuracy
+            samples = sum(m['samples'] for m in metrics['class_accuracies'].values())
+            trial_metrics['total_samples'] += samples
+            trial_metrics['overall_accuracy'] += metrics['overall_accuracy'] * samples
+
+            # Aggregate per-class metrics
+            for class_name, class_metrics in metrics['class_accuracies'].items():
+                trial_metrics['class_accuracies'][class_name]['correct'] += (
+                    class_metrics['accuracy'] * class_metrics['samples']
+                )
+                trial_metrics['class_accuracies'][class_name]['total'] += class_metrics['samples']
+
+            # Collect misclassified files with visualization directory info
+            for error in metrics['misclassified_files']:
+                error['vis_dir'] = vis_dir.name
+                trial_metrics['misclassified_files'].append(error)
+
+        except Exception as e:
+            print(f"Error processing visualization directory {vis_dir}: {e}")
+
+    # Calculate final metrics
+    if trial_metrics['total_samples'] > 0:
+        trial_metrics['overall_accuracy'] /= trial_metrics['total_samples']
+
+    for class_metrics in trial_metrics['class_accuracies'].values():
+        if class_metrics['total'] > 0:
+            class_metrics['accuracy'] = class_metrics['correct'] / class_metrics['total']
+
+    return trial_metrics
+
+def analyze_trials(hyperparam_dir):
+    results = {
+        'search_dirs': defaultdict(lambda: {
+            'best_overall': {'accuracy': 0, 'trial': None},
+            'best_per_class': defaultdict(lambda: {'accuracy': 0, 'trial': None}),
+            'misclassified_files': []
+        })
+    }
+
+    # Process each search directory
+    for search_dir in Path(hyperparam_dir).iterdir():
+        if not search_dir.is_dir() or not search_dir.name.startswith('search_'):
+            continue
+
+        # Process each trial directory
+        for trial_dir in search_dir.iterdir():
+            if not trial_dir.is_dir() or not trial_dir.name.startswith('trial_'):
+                continue
+
+            trial_metrics = analyze_trial(trial_dir)
+            if trial_metrics is None:
+                continue
+
+            search_results = results['search_dirs'][search_dir.name]
+
+            # Update overall best for this search directory
+            if trial_metrics['overall_accuracy'] > search_results['best_overall']['accuracy']:
+                search_results['best_overall']['accuracy'] = trial_metrics['overall_accuracy']
+                search_results['best_overall']['trial'] = trial_dir.name
+
+            # Update per-class bests for this search directory
+            for class_name, class_metrics in trial_metrics['class_accuracies'].items():
+                if class_metrics['accuracy'] > search_results['best_per_class'][class_name]['accuracy']:
+                    search_results['best_per_class'][class_name]['accuracy'] = class_metrics['accuracy']
+                    search_results['best_per_class'][class_name]['trial'] = trial_dir.name
+
+            # Collect misclassified files
+            search_results['misclassified_files'].extend(trial_metrics['misclassified_files'])
+
+    return results
+
+def save_analysis_report(results, hyperparam_dir):
+    output_file = os.path.join(hyperparam_dir, 'trial_analysis_report.txt')
+
+    with open(output_file, 'w') as f:
+        for search_dir, search_results in results['search_dirs'].items():
+            f.write(f"\n=== Results for {search_dir} ===\n")
+            f.write("-" * 80 + "\n")
+
+            # Best overall model
+            f.write("\nBest Overall Model:\n")
+            f.write(f"Trial: {search_results['best_overall']['trial']}\n")
+            f.write(f"Accuracy: {search_results['best_overall']['accuracy']:.2%}\n")
+
+            # Best model per class
+            f.write("\nBest Model Per Class:\n")
+            f.write(f"{'Class':<20} {'Accuracy':<10} {'Trial'}\n")
+            f.write("-" * 60 + "\n")
+            for class_name, data in search_results['best_per_class'].items():
+                f.write(f"{class_name:<20} {data['accuracy']:.2%} {data['trial']}\n")
+
+            # Most frequently misclassified files
+            f.write("\nMost Frequently Misclassified Files:\n")
+            f.write(f"{'Filename':<40} {'True Class':<15} {'Predicted':<15} {'Confidence':<10} {'Dataset'}\n")
+            f.write("-" * 100 + "\n")
+
+            # Sort misclassified files by confidence (ascending) to show most problematic cases first
+            misclassified = sorted(search_results['misclassified_files'],
+                                   key=lambda x: x['confidence'])
+            for error in misclassified[:10]:  # Show top 10 most problematic
+                f.write(f"{error['filename']:<40} {error['true_class']:<15} "
+                        f"{error['predicted_class']:<15} {error['confidence']:<10.2%} {error['vis_dir']}\n")
+
+            f.write("\n" + "=" * 80 + "\n")
+
+def print_results(results):
+    """Print a summary of the analysis results"""
+    for search_dir, search_results in results['search_dirs'].items():
+        print(f"\n=== Results for {search_dir} ===")
+        print("-" * 80)
+
+        # Best overall model
+        print(f"\nBest Overall Model:")
+        print(f"Trial: {search_results['best_overall']['trial']}")
+        print(f"Accuracy: {search_results['best_overall']['accuracy']:.2%}")
+
+        # Best model per class
+        print(f"\nBest Model Per Class:")
+        print(f"{'Class':<20} {'Accuracy':<10} {'Trial'}")
+        print("-" * 60)
+        for class_name, data in search_results['best_per_class'].items():
+            print(f"{class_name:<20} {data['accuracy']:.2%} {data['trial']}")
+
+        # Most frequently misclassified files (top 5)
+        print(f"\nTop 5 Most Problematic Files:")
+        print(f"{'Filename':<40} {'True Class':<15} {'Predicted':<15} {'Confidence'}")
+        print("-" * 80)
+        misclassified = sorted(search_results['misclassified_files'],
+                               key=lambda x: x['confidence'])[:5]
+        for error in misclassified:
+            print(f"{error['filename']:<40} {error['true_class']:<15} "
+                  f"{error['predicted_class']:<15} {error['confidence']:.2%}")
+
+if __name__ == "__main__":
+    hyperparam_dir = "runs_hyperparam/hyperparam_20241106_124214"
+    results = analyze_trials(hyperparam_dir)
+
+    # Print summary to console
+    print_results(results)
+
+    # Save detailed results to file
+    save_analysis_report(results, hyperparam_dir)
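parse_error_analysis is a text scraper over the report that generate_evaluation_metrics (in visualize.py, below) writes, so the two formats have to stay in sync. A quick self-contained sanity check against a fabricated report (class names, filename, and numbers are invented; assumes analyze_trials.py is importable from the working directory):

import os
import tempfile
from analyze_trials import parse_error_analysis

# Fabricated report mirroring the layout visualize.py writes; values are illustrative
sample = """Error Analysis for test set
================================================================================

Overall Accuracy: 85.00%

Per-Class Accuracy:
windmill: 90.00% (10 samples)
flare: 80.00% (10 samples)

Misclassified Videos:
--------------------------------------------------------------------------------
Filename                                 True Class           Predicted Class      Confidence
--------------------------------------------------------------------------------
clip 042.mp4                             flare                windmill             61.00%
"""

with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, "error_analysis.txt"), "w") as f:
        f.write(sample)
    metrics = parse_error_analysis(d)
    print(metrics["overall_accuracy"])        # 0.85
    print(metrics["misclassified_files"][0])  # filename 'clip 042.mp4' (space preserved), confidence 0.61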
script/visualization/visualize.py
CHANGED

@@ -45,9 +45,10 @@ def generate_evaluation_metrics(model, data_loader, device, output_dir, class_la
     all_preds = []
     all_labels = []
     all_probs = []
+    all_files = []
 
     with torch.no_grad():
-        for frames, labels,
+        for frames, labels, filenames in data_loader:
             frames = frames.to(device)
             labels = labels.to(device)
 
@@ -58,11 +59,44 @@ def generate_evaluation_metrics(model, data_loader, device, output_dir, class_la
             all_preds.extend(predicted.cpu().numpy())
             all_labels.extend(labels.cpu().numpy())
             all_probs.extend(probs.cpu().numpy())
+            all_files.extend(filenames)
 
     all_labels = np.array(all_labels)
     all_preds = np.array(all_preds)
     all_probs = np.array(all_probs)
 
+    # Generate error analysis file
+    error_file = os.path.join(output_dir, 'error_analysis.txt')
+    with open(error_file, 'w') as f:
+        f.write(f"Error Analysis for {data_info}\n")
+        f.write("=" * 80 + "\n\n")
+
+        # Overall accuracy
+        accuracy = (all_labels == all_preds).mean()
+        f.write(f"Overall Accuracy: {accuracy:.2%}\n\n")
+
+        # Per-class accuracy
+        f.write("Per-Class Accuracy:\n")
+        for i, class_name in enumerate(class_labels):
+            class_mask = all_labels == i
+            if class_mask.sum() > 0:
+                class_acc = (all_preds[class_mask] == i).mean()
+                f.write(f"{class_name}: {class_acc:.2%} ({(class_mask).sum()} samples)\n")
+        f.write("\n")
+
+        # Detailed error analysis
+        f.write("Misclassified Videos:\n")
+        f.write("-" * 80 + "\n")
+        f.write(f"{'Filename':<40} {'True Class':<20} {'Predicted Class':<20} Confidence\n")
+        f.write("-" * 80 + "\n")
+
+        for i, (true_label, pred_label, probs, filename) in enumerate(zip(all_labels, all_preds, all_probs, all_files)):
+            if true_label != pred_label:
+                true_class = class_labels[true_label]
+                pred_class = class_labels[pred_label]
+                confidence = probs[pred_label]
+                f.write(f"{filename:<40} {true_class:<20} {pred_class:<20} {confidence:.2%}\n")
+
     # Compute and plot confusion matrix
     cm = confusion_matrix(all_labels, all_preds)
     plt.figure(figsize=(10, 8))
@@ -124,7 +158,11 @@ def run_visualization(run_dir, data_path=None, test_csv=None):
 
     class_labels = config['class_labels']
    num_classes = config['num_classes']
-
+
+    # Update the config's data_path if provided
+    if data_path:
+        config['data_path'] = data_path
+    data_path = config['data_path']
 
     # Paths
     log_file = os.path.join(run_dir, 'training_log.csv')
@@ -136,6 +174,8 @@ def run_visualization(run_dir, data_path=None, test_csv=None):
     # Get the last directory of data_path and the file name
     last_dir = os.path.basename(os.path.normpath(data_path))
     file_name = os.path.basename(test_csv)
+
+    print(f"Running visualization for {data_path} with {test_csv} from CWD {os.getcwd()}")
 
     # Create a directory for visualization outputs
     vis_dir = os.path.join(run_dir, f'visualization_{last_dir}_{file_name.split(".")[0]}')
@@ -164,8 +204,9 @@ def run_visualization(run_dir, data_path=None, test_csv=None):
 
 if __name__ == "__main__":
     # Find the most recent run directory
-    run_dir = get_latest_run_dir()
+    # run_dir = get_latest_run_dir()
+    run_dir = "/home/bawolf/workspace/break/clip/runs_hyperparam/hyperparam_20241106_124214/search_combined_adjusted/trial_combined_adjusted_20241106-195023/"
    # run_dir = "/home/bawolf/workspace/break/clip/runs/run_20241024-150232_otherpeopleval_large_model"
    # run_dir = "/home/bawolf/workspace/break/clip/runs/run_20241022-122939_3moves_balanced"
-
-    run_visualization(run_dir)
+    data_path = "/home/bawolf/workspace/break/finetune/blog/combined/all"
+    run_visualization(run_dir, data_path=data_path)
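The evaluation loop now unpacks a third filenames element from each batch, so any DataLoader passed to generate_evaluation_metrics must be backed by a dataset yielding (frames, label, filename) triples. A minimal sketch of that contract with a stand-in dataset (class name and tensor shapes are illustrative, not the repo's VideoDataset):

import torch
from torch.utils.data import Dataset, DataLoader

class TinyVideoDataset(Dataset):
    """Illustrative stand-in: yields (frames, label, filename) as the loop expects."""
    def __init__(self, items):
        self.items = items  # list of (frames tensor, int label, str filename)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        frames, label, filename = self.items[idx]
        return frames, label, filename

items = [(torch.zeros(8, 3, 224, 224), 0, "clip_000.mp4"),
         (torch.zeros(8, 3, 224, 224), 1, "clip_001.mp4")]
loader = DataLoader(TinyVideoDataset(items), batch_size=2)

for frames, labels, filenames in loader:
    # default collation batches the tensors and returns the filenames as strings,
    # which is why all_files.extend(filenames) works in the diff above
    print(frames.shape, labels, list(filenames))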
script/visualization/viz_cross_compare.py
ADDED

@@ -0,0 +1,54 @@
+import os
+from pathlib import Path
+from visualize import run_visualization
+
+def get_opposite_dataset_path(run_folder):
+    # Map run folders to their corresponding opposite dataset training files
+    dataset_mapping = {
+        'search_bryant_adjusted': '../finetune/blog/youtube/adjusted',
+        'search_bryant_random': '../finetune/blog/youtube/random',
+        'search_youtube_adjusted': '../finetune/blog/bryant/adjusted',
+        'search_youtube_random': '../finetune/blog/bryant/random'
+    }
+
+    for folder_prefix, dataset_path in dataset_mapping.items():
+        if run_folder.startswith(folder_prefix):
+            return dataset_path
+    return None
+
+def process_runs(base_dir):
+    # Get the full path to the runs directory
+    runs_dir = Path(base_dir)
+
+    # Process each search directory
+    for search_dir in runs_dir.iterdir():
+        if not search_dir.is_dir() or search_dir.name == 'visualization':
+            continue
+
+        # Get the opposite dataset path for this search directory
+        opposite_dataset = get_opposite_dataset_path(search_dir.name)
+
+        if opposite_dataset is None:
+            print(f"Skipping {search_dir.name} - no matching dataset mapping")
+            continue
+
+        # Process each trial directory within the search directory
+        for trial_dir in search_dir.iterdir():
+            if not trial_dir.is_dir() or not trial_dir.name.startswith('trial_'):
+                continue
+
+            print(f"Processing {trial_dir} with {opposite_dataset}")
+            try:
+                vis_dir, cm = run_visualization(
+                    run_dir=str(trial_dir),
+                    data_path=opposite_dataset,
+                    test_csv=os.path.join(opposite_dataset, "train.csv")
+                )
+                print(f"Visualization complete: {vis_dir}")
+            except Exception as e:
+                print(f"Error processing {trial_dir}: {e}")
+
+if __name__ == "__main__":
+    # Example usage
+    runs_path = "runs_hyperparam/hyperparam_20241106_124214"
+    process_runs(runs_path)
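The runs path in __main__ is hardcoded; a small CLI wrapper would make the cross-comparison reusable across runs. A sketch (the --runs-dir flag is mine, not part of the commit; assumes viz_cross_compare.py is importable from the working directory):

import argparse
from viz_cross_compare import process_runs

# Hypothetical CLI entry point around process_runs
parser = argparse.ArgumentParser(description="Cross-evaluate each trial on the opposite dataset")
parser.add_argument("--runs-dir", default="runs_hyperparam/hyperparam_20241106_124214",
                    help="Hyperparameter run directory containing search_*/trial_* folders")
args = parser.parse_args()
process_runs(args.runs_dir)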
src/dataset/dataset.py
CHANGED

@@ -48,6 +48,11 @@ class VideoDataset(Dataset):
     def __getitem__(self, idx):
         video_path, label = self.data[idx]
 
+        if not os.path.exists(video_path):
+            print(f"File not found: {video_path}")
+            print(f"Absolute path attempt: {os.path.abspath(video_path)}")
+            raise FileNotFoundError(f"File not found: {video_path}")
+
         frames, success = extract_frames(video_path,
                                          {"max_frames": self.max_frames, "sigma": self.sigma},
                                          self.transform)