"""Training entry point: fits a video classifier and records per-run artifacts."""

import torch
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
import os
import logging
import csv
import json
from torch.optim.lr_scheduler import CosineAnnealingLR
import math

import sys
# Add the parent of this file's directory to sys.path so the project-local
# imports below (`src`, `visualization`) can be resolved.
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from src.utils.utils import create_run_directory
from src.dataset.dataset import VideoDataset
from src.models.model import create_model
from src.dataset.video_utils import create_transform
from visualization.visualize import run_visualization
from visualization.miscalculations_report import analyze_misclassifications

# Module-level logger so it is always bound, including in the error handler of
# train_and_evaluate when a failure happens before logging is configured.
logger = logging.getLogger(__name__)


def _calculate_accuracy(outputs, labels):
    """Return the fraction of samples whose arg-max over dim 1 equals its label."""
    _, predicted = torch.max(outputs, 1)
    correct = (predicted == labels).sum().item()
    total = labels.size(0)
    return correct / total


def _log_misclassifications(outputs, labels, video_paths, dataset, misclassified_videos):
    """Append one record per wrongly predicted sample to ``misclassified_videos``.

    Class indices are translated to names through ``dataset.label_map``.
    """
    _, predicted = torch.max(outputs, 1)
    for pred, label, video_path in zip(predicted, labels, video_paths):
        if pred != label:
            true_label = dataset.label_map[label.item()]
            predicted_label = dataset.label_map[pred.item()]
            misclassified_videos.append({
                'video_path': video_path,
                'true_label': true_label,
                'predicted_label': predicted_label
            })


def train_and_evaluate(config):
    """Train the model described by ``config`` and evaluate it after every epoch.

    Args:
        config: Mutable dict of settings. Keys read here: ``num_classes``,
            ``clip_model``, ``unfreeze_layers``, ``data_path``, ``batch_size``,
            ``learning_rate``, ``weight_decay``, ``num_epochs``,
            ``gradient_clip_max_norm``, ``patience``, ``overfitting_threshold``
            (train/val accuracy gap as a fraction in [0, 1]) and
            ``class_labels``. ``run_dir`` is optional and created when absent.
            The dict is updated in place with ``best_model_path``,
            ``final_model_path``, ``csv_path`` and ``misclassifications_dir``.

    Returns:
        Tuple ``(avg_val_accuracy, vis_dir)``: the validation accuracy
        (fraction) of the last epoch that ran, and the visualization output
        directory, or ``None`` when visualization failed.

    Raises:
        ValueError: If the final validation accuracy is NaN or infinite
            (including when no epoch ran).
        Exception: Any other error is logged and re-raised unchanged.
    """
    try:
        # Create a run directory if it doesn't exist
        if "run_dir" not in config:
            config["run_dir"] = create_run_directory()

        # Update paths based on run_dir
        config.update({
            "best_model_path": os.path.join(config["run_dir"], 'best_model.pth'),
            "final_model_path": os.path.join(config["run_dir"], 'final_model.pth'),
            "csv_path": os.path.join(config["run_dir"], 'training_log.csv'),
            "misclassifications_dir": os.path.join(config["run_dir"], 'misclassifications'),
        })

        config_path = os.path.join(config["run_dir"], 'config.json')
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=2)

        # Set up logging
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s - %(levelname)s - %(message)s',
                            handlers=[logging.FileHandler(os.path.join(config["run_dir"], 'training.log')),
                                      logging.StreamHandler()])

        # Pick the GPU when one is available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        # Initialize variables
        best_val_loss = float('inf')
        epochs_without_improvement = 0
        # NaN until an epoch completes, so a zero-epoch run trips the explicit
        # ValueError at the end instead of an UnboundLocalError.
        avg_val_accuracy = float('nan')

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory/1e9:.2f}GB")
            print(f"Currently allocated: {torch.cuda.memory_allocated()/1e9:.2f}GB")

        model = create_model(config["num_classes"], config["clip_model"])
        # Unfreeze the configured number of vision-encoder layers
        model.unfreeze_vision_encoder(num_layers=config["unfreeze_layers"])
        model = model.to(device)

        # Ensure criterion is on the same device
        criterion = torch.nn.CrossEntropyLoss().to(device)

        # Load datasets
        train_dataset = VideoDataset(
            os.path.join(config['data_path'], 'train.csv'),
            config=config
        )

        # Validation uses an explicit transform built with training=False
        val_dataset = VideoDataset(
            os.path.join(config['data_path'], 'val.csv'),
            config=config,
            transform=create_transform(config, training=False)
        )

        # Create data loaders
        train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)

        # Define optimizer and learning rate scheduler
        optimizer = torch.optim.AdamW(model.parameters(), lr=config["learning_rate"],
                                      weight_decay=config["weight_decay"])
        scheduler = CosineAnnealingLR(optimizer, T_max=config["num_epochs"])

        # Start the CSV training log with its header row
        with open(config["csv_path"], 'w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["epoch", "train_loss", "train_accuracy", "val_loss", "val_accuracy"])

        # Create a subfolder for misclassification logs
        os.makedirs(config["misclassifications_dir"], exist_ok=True)

        # Training loop
        # NOTE(review): losses and accuracies below are averaged per batch, so a
        # smaller final batch is weighted the same as a full one.
        for epoch in range(config["num_epochs"]):
            model.train()
            total_loss = 0
            total_accuracy = 0

            for frames, labels, _ in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{config['num_epochs']}"):
                frames = frames.to(device)
                labels = labels.to(device)

                logits = model(frames)
                loss = criterion(logits, labels)
                accuracy = _calculate_accuracy(logits, labels)

                optimizer.zero_grad()
                loss.backward()
                clip_grad_norm_(model.parameters(), max_norm=config["gradient_clip_max_norm"])
                optimizer.step()

                total_loss += loss.item()
                total_accuracy += accuracy

            avg_train_loss = total_loss / len(train_loader)
            avg_train_accuracy = total_accuracy / len(train_loader)

            # Validation
            model.eval()
            val_loss = 0
            val_accuracy = 0
            misclassified_videos = []

            with torch.no_grad():
                for frames, labels, video_paths in val_loader:
                    frames = frames.to(device)
                    labels = labels.to(device)

                    logits = model(frames)
                    loss = criterion(logits, labels)
                    accuracy = _calculate_accuracy(logits, labels)

                    val_loss += loss.item()
                    val_accuracy += accuracy

                    # Log misclassifications
                    _log_misclassifications(logits, labels, video_paths, val_dataset, misclassified_videos)

            avg_val_loss = val_loss / len(val_loader)
            avg_val_accuracy = val_accuracy / len(val_loader)

            # Log misclassified videos
            if misclassified_videos:
                misclassified_log_path = os.path.join(config["misclassifications_dir"], f'epoch_{epoch+1}.json')
                with open(misclassified_log_path, 'w') as f:
                    json.dump(misclassified_videos, f, indent=2)
                logger.info(f"Logged {len(misclassified_videos)} misclassified videos to {misclassified_log_path}")

            # Log the metrics
            logger.info(f"Epoch [{epoch+1}/{config['num_epochs']}], "
                        f"Train Loss: {avg_train_loss:.4f}, Train Accuracy: {avg_train_accuracy*100:.2f}%, "
                        f"Val Loss: {avg_val_loss:.4f}, Val Accuracy: {avg_val_accuracy*100:.2f}%")

            # Write to CSV (accuracies are stored as percentages)
            with open(config["csv_path"], 'a', newline='') as file:
                writer = csv.writer(file)
                writer.writerow([epoch+1, avg_train_loss, avg_train_accuracy*100, avg_val_loss, avg_val_accuracy*100])

            # Learning rate scheduling
            scheduler.step()

            # Save the best model and check for early stopping
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save(model.state_dict(), config["best_model_path"])
                logger.info(f"Saved best model to {config['best_model_path']}")
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1

            # Early stopping check
            if epochs_without_improvement >= config["patience"]:
                logger.info(f"Early stopping triggered after {config['patience']} epochs without improvement")
                break

            # Overfitting detection: both accuracies are fractions in [0, 1]
            if avg_train_accuracy - avg_val_accuracy > config["overfitting_threshold"]:
                logger.warning("Possible overfitting detected")

        logger.info("Training finished!")

        # Save the final model
        torch.save(model.state_dict(), config["final_model_path"])
        logger.info(f"Saved final model to {config['final_model_path']}")

        # Save run information
        with open(os.path.join(config["run_dir"], 'run_info.txt'), 'w') as f:
            for key, value in config.items():
                f.write(f"{key}: {value}\n")
            f.write(f"Device: {device}\n")
            f.write(f"Model: {model.__class__.__name__}\n")
            f.write(f"Optimizer: {optimizer.__class__.__name__}\n")
            f.write(f"Scheduler: {scheduler.__class__.__name__}\n")
            f.write("Loss function: CrossEntropyLoss\n")
            f.write("Data augmentation: RandomHorizontalFlip, RandomRotation(5), ColorJitter\n")
            f.write(f"Mixed precision training: {'Enabled' if 'scaler' in locals() else 'Disabled'}\n")
            f.write(f"Train dataset size: {len(train_dataset)}\n")
            f.write(f"Validation dataset size: {len(val_dataset)}\n")
            f.write(f"Vision encoder frozen: {'Partially' if hasattr(model, 'unfreeze_vision_encoder') else 'Unknown'}\n")

        # Run visualization. It is best-effort: a failure is logged and must not
        # abort the run, so vis_dir needs a value for the return statement below.
        vis_dir = None
        try:
            logger.info("Running visualization...")
            vis_dir, confusion_matrix = run_visualization(config["run_dir"])
            logger.info(f"Visualization complete! Check the output directory: {vis_dir}")

            # Log confusion matrix results
            class_accuracies = confusion_matrix.diagonal() / confusion_matrix.sum(axis=1)
            overall_accuracy = confusion_matrix.diagonal().sum() / confusion_matrix.sum()

            logger.info("\nConfusion Matrix Results:")
            for label, accuracy in zip(config['class_labels'], class_accuracies):
                logger.info(f"{label}: {accuracy:.2%}")
            logger.info(f"Overall Accuracy: {overall_accuracy:.2%}")
        except Exception as e:
            logger.error(f"Error running visualization: {str(e)}")

        # Run misclassification analysis (also best-effort)
        try:
            analyze_misclassifications(config["run_dir"])
            logger.info(f"Misclassification analysis complete! Check the output directory: {config['run_dir']}")
        except Exception as e:
            logger.error(f"Error running misclassification analysis: {str(e)}")

        if math.isnan(avg_val_accuracy) or math.isinf(avg_val_accuracy):
            raise ValueError(f"Invalid validation accuracy: {avg_val_accuracy}")

        print("Script finished.")

        return avg_val_accuracy, vis_dir

    except Exception as e:
        logger.error(f"Training error: {str(e)}")
        raise  # Re-raise the exception to be caught by the hyperparameter tuning


def main():
    """Build a default configuration in a fresh run directory and train once."""
    # Create run directory
    run_dir = create_run_directory()

    class_labels = ["windmill", "halo", "swipe", "baby_mill"][:3]

    # Write configuration
    config = {
        "class_labels": class_labels,
        "num_classes": len(class_labels),
        "data_path": './data/blog/datasets/bryant/random',
        "batch_size": 8,
        "learning_rate": 2e-6,
        "weight_decay": 0.007,
        "num_epochs": 2,
        "patience": 10,  # for early stopping
        "max_frames": 10,
        "sigma": 0.3,
        "image_size": 224,
        "flip_probability": 0.5,
        "rotation_degrees": 15,
        "brightness_jitter": 0.2,
        "contrast_jitter": 0.2,
        "saturation_jitter": 0.2,
        "hue_jitter": 0.1,
        "crop_scale_min": 0.8,
        "crop_scale_max": 1.0,
        "normalization_mean": [0.485, 0.456, 0.406],
        "normalization_std": [0.229, 0.224, 0.225],
        "unfreeze_layers": 3,
        # "clip_model": "openai/clip-vit-large-patch14",
        "clip_model": "openai/clip-vit-base-patch32",
        "gradient_clip_max_norm": 1.0,
        # Train/val accuracy gap as a fraction: train_and_evaluate compares
        # accuracies in [0, 1], so the previous value of 10 could never trigger.
        "overfitting_threshold": 0.10,
        "run_dir": run_dir,
    }
    train_and_evaluate(config)


if __name__ == "__main__":
    main()