bawolf committed
Commit 5acfa1a
Parent: 72a4f99

more cross comparing

script/hyperparameter_tuning.py CHANGED
@@ -84,8 +84,9 @@ def objective(trial, hyperparam_run_dir, data_path):
         'dataset_label': dataset_label,
         'trial_number': trial.number,
         'parameters': trial.params,
-        'value': val_accuracy,
-        'visualization_dir': vis_dir
+        'accuracy': val_accuracy,
+        'visualization_dir': vis_dir,
+        'trial_dir': trial_dir
     }

     with open(os.path.join(trial_dir, 'trial_info.json'), 'w') as f:
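
Note: with this change each trial's trial_info.json renames 'value' to 'accuracy' and also records 'trial_dir', so downstream analysis can locate a trial without reconstructing its path. A minimal sketch of reading a record back (the trial directory name is taken from elsewhere in this commit; field values depend on the run):

    import json
    import os

    trial_dir = "runs_hyperparam/hyperparam_20241106_124214/search_combined_adjusted/trial_combined_adjusted_20241106-195023"
    with open(os.path.join(trial_dir, "trial_info.json")) as f:
        info = json.load(f)

    # 'accuracy' replaces the old 'value' key; 'trial_dir' is new in this commit.
    print(info["accuracy"], info["visualization_dir"], info["trial_dir"])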
script/visualization/analyze_trials.py ADDED
@@ -0,0 +1,225 @@
+import os
+import json
+import pandas as pd
+from pathlib import Path
+from collections import defaultdict
+
+def parse_error_analysis(vis_dir):
+    """Parse the error_analysis.txt file to get accuracy and misclassification details"""
+    metrics = {}
+    class_accuracies = {}
+    misclassified_files = []
+
+    with open(os.path.join(vis_dir, 'error_analysis.txt'), 'r') as f:
+        lines = f.readlines()
+        parsing_errors = False
+        header_found = False
+
+        for line in lines:
+            # Get overall accuracy
+            if line.startswith("Overall Accuracy:"):
+                metrics['overall_accuracy'] = float(line.split(":")[1].strip().rstrip('%')) / 100
+
+            # Parse per-class accuracy
+            if "samples)" in line and ":" in line:
+                class_name = line.split(":")[0].strip()
+                accuracy = float(line.split(":")[1].split("%")[0].strip()) / 100
+                samples = int(line.split("(")[1].split(" ")[0])
+                class_accuracies[class_name] = {
+                    'accuracy': accuracy,
+                    'samples': samples
+                }
+
+            # Parse misclassified files
+            if "Misclassified Videos:" in line:
+                parsing_errors = True
+                continue
+            if "Filename" in line and "True Class" in line:
+                header_found = True
+                continue
+            if parsing_errors and header_found and line.strip() and not line.startswith("-"):
+                try:
+                    # Split the line while preserving filename with spaces
+                    parts = line.strip().split()
+                    # Find the confidence value (last element with %)
+                    confidence_idx = next(i for i, part in enumerate(parts) if part.endswith('%'))
+                    # Everything before the last three elements is the filename
+                    filename = ' '.join(parts[:confidence_idx-2])
+                    true_class = parts[confidence_idx-2]
+                    pred_class = parts[confidence_idx-1]
+                    confidence = float(parts[confidence_idx].rstrip('%')) / 100
+
+                    misclassified_files.append({
+                        'filename': filename,
+                        'true_class': true_class,
+                        'predicted_class': pred_class,
+                        'confidence': confidence
+                    })
+                except Exception as e:
+                    print(f"Warning: Could not parse line: {line.strip()}")
+                    continue
+
+    metrics['class_accuracies'] = class_accuracies
+    metrics['misclassified_files'] = misclassified_files
+    return metrics
+
+def analyze_trial(trial_dir):
+    """Analyze all visualization directories in a trial and aggregate results"""
+    trial_metrics = {
+        'overall_accuracy': 0,
+        'total_samples': 0,
+        'class_accuracies': defaultdict(lambda: {'correct': 0, 'total': 0}),
+        'misclassified_files': []
+    }
+
+    # Find all visualization directories
+    vis_dirs = [d for d in trial_dir.iterdir() if d.is_dir() and d.name.startswith('visualization_')]
+    if not vis_dirs:
+        return None
+
+    for vis_dir in vis_dirs:
+        try:
+            metrics = parse_error_analysis(vis_dir)
+
+            # Add to total samples and weighted accuracy
+            samples = sum(m['samples'] for m in metrics['class_accuracies'].values())
+            trial_metrics['total_samples'] += samples
+            trial_metrics['overall_accuracy'] += metrics['overall_accuracy'] * samples
+
+            # Aggregate per-class metrics
+            for class_name, class_metrics in metrics['class_accuracies'].items():
+                trial_metrics['class_accuracies'][class_name]['correct'] += (
+                    class_metrics['accuracy'] * class_metrics['samples']
+                )
+                trial_metrics['class_accuracies'][class_name]['total'] += class_metrics['samples']
+
+            # Collect misclassified files with visualization directory info
+            for error in metrics['misclassified_files']:
+                error['vis_dir'] = vis_dir.name
+                trial_metrics['misclassified_files'].append(error)
+
+        except Exception as e:
+            print(f"Error processing visualization directory {vis_dir}: {e}")
+
+    # Calculate final metrics
+    if trial_metrics['total_samples'] > 0:
+        trial_metrics['overall_accuracy'] /= trial_metrics['total_samples']
+
+    for class_metrics in trial_metrics['class_accuracies'].values():
+        if class_metrics['total'] > 0:
+            class_metrics['accuracy'] = class_metrics['correct'] / class_metrics['total']
+
+    return trial_metrics
+
+def analyze_trials(hyperparam_dir):
+    results = {
+        'search_dirs': defaultdict(lambda: {
+            'best_overall': {'accuracy': 0, 'trial': None},
+            'best_per_class': defaultdict(lambda: {'accuracy': 0, 'trial': None}),
+            'misclassified_files': []
+        })
+    }
+
+    # Process each search directory
+    for search_dir in Path(hyperparam_dir).iterdir():
+        if not search_dir.is_dir() or not search_dir.name.startswith('search_'):
+            continue
+
+        # Process each trial directory
+        for trial_dir in search_dir.iterdir():
+            if not trial_dir.is_dir() or not trial_dir.name.startswith('trial_'):
+                continue
+
+            trial_metrics = analyze_trial(trial_dir)
+            if trial_metrics is None:
+                continue
+
+            search_results = results['search_dirs'][search_dir.name]
+
+            # Update overall best for this search directory
+            if trial_metrics['overall_accuracy'] > search_results['best_overall']['accuracy']:
+                search_results['best_overall']['accuracy'] = trial_metrics['overall_accuracy']
+                search_results['best_overall']['trial'] = trial_dir.name
+
+            # Update per-class bests for this search directory
+            for class_name, class_metrics in trial_metrics['class_accuracies'].items():
+                if class_metrics['accuracy'] > search_results['best_per_class'][class_name]['accuracy']:
+                    search_results['best_per_class'][class_name]['accuracy'] = class_metrics['accuracy']
+                    search_results['best_per_class'][class_name]['trial'] = trial_dir.name
+
+            # Collect misclassified files
+            search_results['misclassified_files'].extend(trial_metrics['misclassified_files'])
+
+    return results
+
+def save_analysis_report(results, hyperparam_dir):
+    output_file = os.path.join(hyperparam_dir, 'trial_analysis_report.txt')
+
+    with open(output_file, 'w') as f:
+        for search_dir, search_results in results['search_dirs'].items():
+            f.write(f"\n=== Results for {search_dir} ===\n")
+            f.write("-" * 80 + "\n")
+
+            # Best overall model
+            f.write("\nBest Overall Model:\n")
+            f.write(f"Trial: {search_results['best_overall']['trial']}\n")
+            f.write(f"Accuracy: {search_results['best_overall']['accuracy']:.2%}\n")
+
+            # Best model per class
+            f.write("\nBest Model Per Class:\n")
+            f.write(f"{'Class':<20} {'Accuracy':<10} {'Trial'}\n")
+            f.write("-" * 60 + "\n")
+            for class_name, data in search_results['best_per_class'].items():
+                f.write(f"{class_name:<20} {data['accuracy']:.2%} {data['trial']}\n")
+
+            # Most frequently misclassified files
+            f.write("\nMost Frequently Misclassified Files:\n")
+            f.write(f"{'Filename':<40} {'True Class':<15} {'Predicted':<15} {'Confidence':<10} {'Dataset'}\n")
+            f.write("-" * 100 + "\n")
+
+            # Sort misclassified files by confidence (ascending) to show most problematic cases first
+            misclassified = sorted(search_results['misclassified_files'],
+                                   key=lambda x: x['confidence'])
+            for error in misclassified[:10]:  # Show top 10 most problematic
+                f.write(f"{error['filename']:<40} {error['true_class']:<15} "
+                        f"{error['predicted_class']:<15} {error['confidence']:<10.2%} {error['vis_dir']}\n")
+
+            f.write("\n" + "=" * 80 + "\n")
+
+def print_results(results):
+    """Print a summary of the analysis results"""
+    for search_dir, search_results in results['search_dirs'].items():
+        print(f"\n=== Results for {search_dir} ===")
+        print("-" * 80)
+
+        # Best overall model
+        print(f"\nBest Overall Model:")
+        print(f"Trial: {search_results['best_overall']['trial']}")
+        print(f"Accuracy: {search_results['best_overall']['accuracy']:.2%}")
+
+        # Best model per class
+        print(f"\nBest Model Per Class:")
+        print(f"{'Class':<20} {'Accuracy':<10} {'Trial'}")
+        print("-" * 60)
+        for class_name, data in search_results['best_per_class'].items():
+            print(f"{class_name:<20} {data['accuracy']:.2%} {data['trial']}")
+
+        # Most frequently misclassified files (top 5)
+        print(f"\nTop 5 Most Problematic Files:")
+        print(f"{'Filename':<40} {'True Class':<15} {'Predicted':<15} {'Confidence'}")
+        print("-" * 80)
+        misclassified = sorted(search_results['misclassified_files'],
+                               key=lambda x: x['confidence'])[:5]
+        for error in misclassified:
+            print(f"{error['filename']:<40} {error['true_class']:<15} "
+                  f"{error['predicted_class']:<15} {error['confidence']:.2%}")
+
+if __name__ == "__main__":
+    hyperparam_dir = "runs_hyperparam/hyperparam_20241106_124214"
+    results = analyze_trials(hyperparam_dir)
+
+    # Print summary to console
+    print_results(results)
+
+    # Save detailed results to file
+    save_analysis_report(results, hyperparam_dir)
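
Note: parse_error_analysis is a positional text parser, so it is tightly coupled to the error_analysis.txt layout that visualize.py writes later in this commit. A self-contained sanity check (class names and numbers are made up; assumes analyze_trials is importable from script/visualization):

    import os
    import tempfile

    from analyze_trials import parse_error_analysis

    # Made-up report matching the layout visualize.py writes.
    sample = """Error Analysis for combined/all
    ================================================================================

    Overall Accuracy: 85.00%

    Per-Class Accuracy:
    windmill: 90.00% (10 samples)
    halo: 80.00% (10 samples)

    Misclassified Videos:
    --------------------------------------------------------------------------------
    Filename                                 True Class           Predicted Class      Confidence
    --------------------------------------------------------------------------------
    clip 001.mp4                             windmill             halo                 62.00%
    """

    with tempfile.TemporaryDirectory() as d:
        with open(os.path.join(d, "error_analysis.txt"), "w") as f:
            f.write(sample)
        metrics = parse_error_analysis(d)

    assert abs(metrics["overall_accuracy"] - 0.85) < 1e-9
    assert metrics["class_accuracies"]["windmill"]["samples"] == 10
    # Filenames with spaces survive because everything before the last three
    # columns is rejoined into the filename.
    assert metrics["misclassified_files"][0]["filename"] == "clip 001.mp4"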
script/visualization/visualize.py CHANGED
@@ -45,9 +45,10 @@ def generate_evaluation_metrics(model, data_loader, device, output_dir, class_la
     all_preds = []
     all_labels = []
     all_probs = []
+    all_files = []

     with torch.no_grad():
-        for frames, labels, _ in data_loader:
+        for frames, labels, filenames in data_loader:
             frames = frames.to(device)
             labels = labels.to(device)

@@ -58,11 +59,44 @@ def generate_evaluation_metrics(model, data_loader, device, output_dir, class_la
             all_preds.extend(predicted.cpu().numpy())
             all_labels.extend(labels.cpu().numpy())
             all_probs.extend(probs.cpu().numpy())
+            all_files.extend(filenames)

     all_labels = np.array(all_labels)
     all_preds = np.array(all_preds)
     all_probs = np.array(all_probs)

+    # Generate error analysis file
+    error_file = os.path.join(output_dir, 'error_analysis.txt')
+    with open(error_file, 'w') as f:
+        f.write(f"Error Analysis for {data_info}\n")
+        f.write("=" * 80 + "\n\n")
+
+        # Overall accuracy
+        accuracy = (all_labels == all_preds).mean()
+        f.write(f"Overall Accuracy: {accuracy:.2%}\n\n")
+
+        # Per-class accuracy
+        f.write("Per-Class Accuracy:\n")
+        for i, class_name in enumerate(class_labels):
+            class_mask = all_labels == i
+            if class_mask.sum() > 0:
+                class_acc = (all_preds[class_mask] == i).mean()
+                f.write(f"{class_name}: {class_acc:.2%} ({class_mask.sum()} samples)\n")
+        f.write("\n")
+
+        # Detailed error analysis
+        f.write("Misclassified Videos:\n")
+        f.write("-" * 80 + "\n")
+        f.write(f"{'Filename':<40} {'True Class':<20} {'Predicted Class':<20} Confidence\n")
+        f.write("-" * 80 + "\n")
+
+        for i, (true_label, pred_label, probs, filename) in enumerate(zip(all_labels, all_preds, all_probs, all_files)):
+            if true_label != pred_label:
+                true_class = class_labels[true_label]
+                pred_class = class_labels[pred_label]
+                confidence = probs[pred_label]
+                f.write(f"{filename:<40} {true_class:<20} {pred_class:<20} {confidence:.2%}\n")
+
     # Compute and plot confusion matrix
     cm = confusion_matrix(all_labels, all_preds)
     plt.figure(figsize=(10, 8))

@@ -124,7 +158,11 @@ def run_visualization(run_dir, data_path=None, test_csv=None):

     class_labels = config['class_labels']
     num_classes = config['num_classes']
-    data_path = data_path or config['data_path']
+
+    # Update the config's data_path if provided
+    if data_path:
+        config['data_path'] = data_path
+    data_path = config['data_path']

     # Paths
     log_file = os.path.join(run_dir, 'training_log.csv')

@@ -136,6 +174,8 @@ def run_visualization(run_dir, data_path=None, test_csv=None):
     # Get the last directory of data_path and the file name
     last_dir = os.path.basename(os.path.normpath(data_path))
     file_name = os.path.basename(test_csv)
+
+    print(f"Running visualization for {data_path} with {test_csv} from CWD {os.getcwd()}")

     # Create a directory for visualization outputs
     vis_dir = os.path.join(run_dir, f'visualization_{last_dir}_{file_name.split(".")[0]}')

@@ -164,8 +204,9 @@ def run_visualization(run_dir, data_path=None, test_csv=None):

 if __name__ == "__main__":
     # Find the most recent run directory
-    run_dir = get_latest_run_dir()
+    # run_dir = get_latest_run_dir()
+    run_dir = "/home/bawolf/workspace/break/clip/runs_hyperparam/hyperparam_20241106_124214/search_combined_adjusted/trial_combined_adjusted_20241106-195023/"
     # run_dir = "/home/bawolf/workspace/break/clip/runs/run_20241024-150232_otherpeopleval_large_model"
     # run_dir = "/home/bawolf/workspace/break/clip/runs/run_20241022-122939_3moves_balanced"
-
-    run_visualization(run_dir)
+    data_path = "/home/bawolf/workspace/break/finetune/blog/combined/all"
+    run_visualization(run_dir, data_path=data_path)
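
Note: the data_path change is more than a refactor of `data_path = data_path or config['data_path']`: an explicit argument is now written back into config, so anything downstream that reads config['data_path'] sees the override. A before/after sketch with illustrative paths:

    config = {'data_path': '/data/bryant/adjusted'}   # illustrative run config
    override = '/data/youtube/adjusted'               # illustrative argument

    # Old behavior: the override wins locally, but config still names the original dataset.
    data_path = override or config['data_path']
    assert config['data_path'] == '/data/bryant/adjusted'

    # New behavior: the override is recorded, so later reads of config agree.
    if override:
        config['data_path'] = override
    data_path = config['data_path']
    assert config['data_path'] == '/data/youtube/adjusted'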
script/visualization/viz_cross_compare.py ADDED
@@ -0,0 +1,54 @@
+import os
+from pathlib import Path
+from visualize import run_visualization
+
+def get_opposite_dataset_path(run_folder):
+    # Map run folders to their corresponding opposite dataset training files
+    dataset_mapping = {
+        'search_bryant_adjusted': '../finetune/blog/youtube/adjusted',
+        'search_bryant_random': '../finetune/blog/youtube/random',
+        'search_youtube_adjusted': '../finetune/blog/bryant/adjusted',
+        'search_youtube_random': '../finetune/blog/bryant/random'
+    }
+
+    for folder_prefix, dataset_path in dataset_mapping.items():
+        if run_folder.startswith(folder_prefix):
+            return dataset_path
+    return None
+
+def process_runs(base_dir):
+    # Get the full path to the runs directory
+    runs_dir = Path(base_dir)
+
+    # Process each search directory
+    for search_dir in runs_dir.iterdir():
+        if not search_dir.is_dir() or search_dir.name == 'visualization':
+            continue
+
+        # Get the opposite dataset path for this search directory
+        opposite_dataset = get_opposite_dataset_path(search_dir.name)
+
+        if opposite_dataset is None:
+            print(f"Skipping {search_dir.name} - no matching dataset mapping")
+            continue
+
+        # Process each trial directory within the search directory
+        for trial_dir in search_dir.iterdir():
+            if not trial_dir.is_dir() or not trial_dir.name.startswith('trial_'):
+                continue
+
+            print(f"Processing {trial_dir} with {opposite_dataset}")
+            try:
+                vis_dir, cm = run_visualization(
+                    run_dir=str(trial_dir),
+                    data_path=opposite_dataset,
+                    test_csv=os.path.join(opposite_dataset, "train.csv")
+                )
+                print(f"Visualization complete: {vis_dir}")
+            except Exception as e:
+                print(f"Error processing {trial_dir}: {e}")
+
+if __name__ == "__main__":
+    # Example usage
+    runs_path = "runs_hyperparam/hyperparam_20241106_124214"
+    process_runs(runs_path)
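
Note: the mapping pairs each search directory with the other source's dataset, so a model tuned on one dataset is evaluated on footage it never trained on (using the opposite dataset's train.csv as the evaluation set). As committed, the None check was inverted, which would have skipped every mapped directory; it is corrected above. Because matching is by prefix, timestamped directory names still resolve; a quick check with illustrative names:

    # Hypothetical directory names; matching is startswith on the search_ prefix.
    assert get_opposite_dataset_path('search_bryant_adjusted_20241106') == '../finetune/blog/youtube/adjusted'
    assert get_opposite_dataset_path('search_combined_adjusted') is None  # no cross-pair defined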
src/dataset/dataset.py CHANGED
@@ -48,6 +48,11 @@ class VideoDataset(Dataset):
     def __getitem__(self, idx):
         video_path, label = self.data[idx]

+        if not os.path.exists(video_path):
+            print(f"File not found: {video_path}")
+            print(f"Absolute path attempt: {os.path.abspath(video_path)}")
+            raise FileNotFoundError(f"File not found: {video_path}")
+
         frames, success = extract_frames(video_path,
                                          {"max_frames": self.max_frames, "sigma": self.sigma},
                                          self.transform)
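
Note: this hunk assumes dataset.py already imports os (the context lines don't show the imports). The guard matters for the cross-comparison runs above because viz_cross_compare.py passes relative dataset paths, which only resolve from the expected working directory; printing the absolute path makes that failure mode obvious. A sketch of what the guard surfaces (path hypothetical):

    import os

    video_path = "../finetune/blog/youtube/adjusted/windmill/clip_001.mp4"  # hypothetical
    if not os.path.exists(video_path):
        print(f"File not found: {video_path}")
        print(f"Absolute path attempt: {os.path.abspath(video_path)}")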