from flask import Flask, render_template, request, jsonify from flask_socketio import SocketIO import sys import os sys.path.append(os.path.dirname(os.path.abspath(__file__))) import shutil import numpy as np from PIL import Image from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor class Predictor: def __init__(self, model_cfg, checkpoint, device): self.device = device self.model = build_sam2(model_cfg, checkpoint, device=device) self.predictor = SAM2ImagePredictor(self.model) self.image_set = False def set_image(self, image): """Set the image for SAM prediction.""" self.image = image self.predictor.set_image(image) self.image_set = True def predict(self, point_coords, point_labels, multimask_output=False): """Run SAM prediction.""" if not self.image_set: raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") return self.predictor.predict( point_coords=point_coords, point_labels=point_labels, multimask_output=multimask_output ) from utils.helpers import ( blend_mask_with_image, save_mask_as_png, convert_mask_to_yolo, ) import torch from ultralytics import YOLO import threading from threading import Lock import subprocess import time import logging import multiprocessing import json # Initialize Flask app and SocketIO app = Flask(__name__) socketio = SocketIO(app) # Define Base Directory BASE_DIR = os.path.abspath(os.path.dirname(__file__)) # Folder structure with absolute paths UPLOAD_FOLDERS = { 'input': os.path.join(BASE_DIR, 'static/uploads/input'), 'segmented_voids': os.path.join(BASE_DIR, 'static/uploads/segmented/voids'), 'segmented_chips': os.path.join(BASE_DIR, 'static/uploads/segmented/chips'), 'mask_voids': os.path.join(BASE_DIR, 'static/uploads/mask/voids'), 'mask_chips': os.path.join(BASE_DIR, 'static/uploads/mask/chips'), 'automatic_segmented': os.path.join(BASE_DIR, 'static/uploads/segmented/automatic'), } HISTORY_FOLDERS = { 'images': os.path.join(BASE_DIR, 'static/history/images'), 'masks_chip': os.path.join(BASE_DIR, 'static/history/masks/chip'), 'masks_void': os.path.join(BASE_DIR, 'static/history/masks/void'), } DATASET_FOLDERS = { 'train_images': os.path.join(BASE_DIR, 'dataset/train/images'), 'train_labels': os.path.join(BASE_DIR, 'dataset/train/labels'), 'val_images': os.path.join(BASE_DIR, 'dataset/val/images'), 'val_labels': os.path.join(BASE_DIR, 'dataset/val/labels'), 'temp_backup': os.path.join(BASE_DIR, 'temp_backup'), 'models': os.path.join(BASE_DIR, 'models'), 'models_old': os.path.join(BASE_DIR, 'models/old'), } # Ensure all folders exist for folder_name, folder_path in {**UPLOAD_FOLDERS, **HISTORY_FOLDERS, **DATASET_FOLDERS}.items(): os.makedirs(folder_path, exist_ok=True) logging.info(f"Ensured folder exists: {folder_name} -> {folder_path}") training_process = None def initialize_training_status(): """Initialize global training status.""" global training_status training_status = {'running': False, 'cancelled': False} def persist_training_status(): """Save training status to a file.""" with open(os.path.join(BASE_DIR, 'training_status.json'), 'w') as status_file: json.dump(training_status, status_file) def load_training_status(): """Load training status from a file.""" global training_status status_path = os.path.join(BASE_DIR, 'training_status.json') if os.path.exists(status_path): with open(status_path, 'r') as status_file: training_status = json.load(status_file) else: training_status = {'running': False, 'cancelled': False} load_training_status() os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0" # Initialize SAM Predictor MODEL_CFG = r"sam2/sam2_hiera_l.yaml" CHECKPOINT = r"sam2/checkpoints/sam2.1_hiera_large.pt" DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") predictor = Predictor(MODEL_CFG, CHECKPOINT, DEVICE) # Initialize YOLO-seg YOLO_CFG = os.path.join(DATASET_FOLDERS['models'], "best.pt") yolo_model = YOLO(YOLO_CFG) # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s', handlers=[ logging.StreamHandler(), logging.FileHandler(os.path.join(BASE_DIR, "app.log")) # Log to a file ] ) @app.route('/') def index(): """Serve the main UI.""" return render_template('index.html') @app.route('/upload', methods=['POST']) def upload_image(): """Handle image uploads.""" if 'file' not in request.files: return jsonify({'error': 'No file uploaded'}), 400 file = request.files['file'] if file.filename == '': return jsonify({'error': 'No file selected'}), 400 # Save the uploaded file to the input folder input_path = os.path.join(UPLOAD_FOLDERS['input'], file.filename) file.save(input_path) # Set the uploaded image in the predictor image = np.array(Image.open(input_path).convert("RGB")) predictor.set_image(image) # Return a web-accessible URL instead of the file system path web_accessible_url = f"/static/uploads/input/{file.filename}" print(f"Image uploaded and set for prediction: {input_path}") return jsonify({'image_url': web_accessible_url}) @app.route('/segment', methods=['POST']) def segment(): """ Perform segmentation and return the blended image URL. """ try: # Extract data from request data = request.json points = np.array(data.get('points', [])) labels = np.array(data.get('labels', [])) current_class = data.get('class', 'voids') # Default to 'voids' if class not provided # Ensure predictor has an image set if not predictor.image_set: raise ValueError("No image set for prediction.") # Perform SAM prediction masks, _, _ = predictor.predict( point_coords=points, point_labels=labels, multimask_output=False ) # Check if masks exist and have non-zero elements if masks is None or masks.size == 0: raise RuntimeError("No masks were generated by the predictor.") # Define output paths based on class mask_folder = UPLOAD_FOLDERS.get(f'mask_{current_class}') segmented_folder = UPLOAD_FOLDERS.get(f'segmented_{current_class}') if not mask_folder or not segmented_folder: raise ValueError(f"Invalid class '{current_class}' provided.") os.makedirs(mask_folder, exist_ok=True) os.makedirs(segmented_folder, exist_ok=True) # Save the raw mask mask_path = os.path.join(mask_folder, 'raw_mask.png') save_mask_as_png(masks[0], mask_path) # Generate blended image blend_color = [34, 139, 34] if current_class == 'voids' else [30, 144, 255] # Green for voids, blue for chips blended_image = blend_mask_with_image(predictor.image, masks[0], blend_color) # Save blended image blended_filename = f"blended_{current_class}.png" blended_path = os.path.join(segmented_folder, blended_filename) Image.fromarray(blended_image).save(blended_path) # Return URL for frontend access segmented_url = f"/static/uploads/segmented/{current_class}/{blended_filename}" logging.info(f"Segmentation completed for {current_class}. Points: {points}, Labels: {labels}") return jsonify({'segmented_url': segmented_url}) except ValueError as ve: logging.error(f"Value error during segmentation: {ve}") return jsonify({'error': str(ve)}), 400 except Exception as e: logging.error(f"Unexpected error during segmentation: {e}") return jsonify({'error': 'Segmentation failed', 'details': str(e)}), 500 @app.route('/automatic_segment', methods=['POST']) def automatic_segment(): """Perform automatic segmentation using YOLO.""" if 'file' not in request.files: return jsonify({'error': 'No file uploaded'}), 400 file = request.files['file'] if file.filename == '': return jsonify({'error': 'No file selected'}), 400 input_path = os.path.join(UPLOAD_FOLDERS['input'], file.filename) file.save(input_path) try: # Perform YOLO segmentation results = yolo_model.predict(input_path, save=False, save_txt=False) output_folder = UPLOAD_FOLDERS['automatic_segmented'] os.makedirs(output_folder, exist_ok=True) chips_data = [] chips = [] voids = [] # Process results and save segmented images for result in results: annotated_image = result.plot() result_filename = f"{file.filename.rsplit('.', 1)[0]}_pred.jpg" result_path = os.path.join(output_folder, result_filename) Image.fromarray(annotated_image).save(result_path) # Separate chips and voids for i, label in enumerate(result.boxes.cls): # YOLO labels label_name = result.names[int(label)] # Get label name (e.g., 'chip' or 'void') box = result.boxes.xyxy[i].cpu().numpy() # Bounding box (x1, y1, x2, y2) area = float((box[2] - box[0]) * (box[3] - box[1])) # Calculate area if label_name == 'chip': chips.append({'box': box, 'area': area, 'voids': []}) elif label_name == 'void': voids.append({'box': box, 'area': area}) # Assign voids to chips based on proximity for void in voids: void_centroid = [ (void['box'][0] + void['box'][2]) / 2, # x centroid (void['box'][1] + void['box'][3]) / 2 # y centroid ] for chip in chips: # Check if void centroid is within chip bounding box if (chip['box'][0] <= void_centroid[0] <= chip['box'][2] and chip['box'][1] <= void_centroid[1] <= chip['box'][3]): chip['voids'].append(void) break # Calculate metrics for each chip for idx, chip in enumerate(chips): chip_area = chip['area'] total_void_area = sum([float(void['area']) for void in chip['voids']]) max_void_area = max([float(void['area']) for void in chip['voids']], default=0) void_percentage = (total_void_area / chip_area) * 100 if chip_area > 0 else 0 max_void_percentage = (max_void_area / chip_area) * 100 if chip_area > 0 else 0 chips_data.append({ "chip_number": int(idx + 1), "chip_area": round(chip_area, 2), "void_percentage": round(void_percentage, 2), "max_void_percentage": round(max_void_percentage, 2) }) # Return the segmented image URL and table data segmented_url = f"/static/uploads/segmented/automatic/{result_filename}" return jsonify({ "segmented_url": segmented_url, # Use the URL for frontend access "table_data": { "image_name": file.filename, "chips": chips_data } }) except Exception as e: print(f"Error in automatic segmentation: {e}") return jsonify({'error': 'Segmentation failed.'}), 500 @app.route('/save_both', methods=['POST']) def save_both(): """Save both the image and masks into the history folders.""" data = request.json image_name = data.get('image_name') if not image_name: return jsonify({'error': 'Image name not provided'}), 400 try: # Ensure image_name is a pure file name image_name = os.path.basename(image_name) # Strip any directory path print(f"Sanitized Image Name: {image_name}") # Correctly resolve the input image path input_image_path = os.path.join(UPLOAD_FOLDERS['input'], image_name) if not os.path.exists(input_image_path): print(f"Input image does not exist: {input_image_path}") return jsonify({'error': f'Input image not found: {input_image_path}'}), 404 # Copy the image to history/images image_history_path = os.path.join(HISTORY_FOLDERS['images'], image_name) os.makedirs(os.path.dirname(image_history_path), exist_ok=True) shutil.copy(input_image_path, image_history_path) print(f"Image saved to history: {image_history_path}") # Backup void mask void_mask_path = os.path.join(UPLOAD_FOLDERS['mask_voids'], 'raw_mask.png') if os.path.exists(void_mask_path): void_mask_history_path = os.path.join(HISTORY_FOLDERS['masks_void'], f"{os.path.splitext(image_name)[0]}.png") os.makedirs(os.path.dirname(void_mask_history_path), exist_ok=True) shutil.copy(void_mask_path, void_mask_history_path) print(f"Voids mask saved to history: {void_mask_history_path}") else: print(f"Voids mask not found: {void_mask_path}") # Backup chip mask chip_mask_path = os.path.join(UPLOAD_FOLDERS['mask_chips'], 'raw_mask.png') if os.path.exists(chip_mask_path): chip_mask_history_path = os.path.join(HISTORY_FOLDERS['masks_chip'], f"{os.path.splitext(image_name)[0]}.png") os.makedirs(os.path.dirname(chip_mask_history_path), exist_ok=True) shutil.copy(chip_mask_path, chip_mask_history_path) print(f"Chips mask saved to history: {chip_mask_history_path}") else: print(f"Chips mask not found: {chip_mask_path}") return jsonify({'message': 'Image and masks saved successfully!'}), 200 except Exception as e: print(f"Error saving files: {e}") return jsonify({'error': 'Failed to save files.', 'details': str(e)}), 500 @app.route('/get_history', methods=['GET']) def get_history(): try: saved_images = os.listdir(HISTORY_FOLDERS['images']) return jsonify({'status': 'success', 'images': saved_images}), 200 except Exception as e: return jsonify({'status': 'error', 'message': f'Failed to fetch history: {e}'}), 500 @app.route('/delete_history_item', methods=['POST']) def delete_history_item(): data = request.json image_name = data.get('image_name') if not image_name: return jsonify({'error': 'Image name not provided'}), 400 try: image_path = os.path.join(HISTORY_FOLDERS['images'], image_name) if os.path.exists(image_path): os.remove(image_path) void_mask_path = os.path.join(HISTORY_FOLDERS['masks_void'], f"{os.path.splitext(image_name)[0]}.png") if os.path.exists(void_mask_path): os.remove(void_mask_path) chip_mask_path = os.path.join(HISTORY_FOLDERS['masks_chip'], f"{os.path.splitext(image_name)[0]}.png") if os.path.exists(chip_mask_path): os.remove(chip_mask_path) return jsonify({'message': f'{image_name} and associated masks deleted successfully.'}), 200 except Exception as e: return jsonify({'error': f'Failed to delete files: {e}'}), 500 # Lock for training status updates status_lock = Lock() def update_training_status(key, value): """Thread-safe update for training status.""" with status_lock: training_status[key] = value @app.route('/retrain_model', methods=['POST']) def retrain_model(): """Handle retrain model workflow.""" global training_status if training_status.get('running', False): return jsonify({'error': 'Training is already in progress'}), 400 try: # Update training status update_training_status('running', True) update_training_status('cancelled', False) logging.info("Training status updated. Starting training workflow.") # Backup masks and images backup_masks_and_images() logging.info("Backup completed successfully.") # Prepare YOLO labels prepare_yolo_labels() logging.info("YOLO labels prepared successfully.") # Start YOLO training in a separate thread threading.Thread(target=run_yolo_training).start() return jsonify({'message': 'Training started successfully!'}), 200 except Exception as e: logging.error(f"Error during training preparation: {e}") update_training_status('running', False) return jsonify({'error': f"Failed to start training: {e}"}), 500 def prepare_yolo_labels(): """Convert all masks into YOLO-compatible labels and copy images to the dataset folder.""" images_folder = HISTORY_FOLDERS['images'] # Use history images as the source train_labels_folder = DATASET_FOLDERS['train_labels'] train_images_folder = DATASET_FOLDERS['train_images'] val_labels_folder = DATASET_FOLDERS['val_labels'] val_images_folder = DATASET_FOLDERS['val_images'] # Ensure destination directories exist os.makedirs(train_labels_folder, exist_ok=True) os.makedirs(train_images_folder, exist_ok=True) os.makedirs(val_labels_folder, exist_ok=True) os.makedirs(val_images_folder, exist_ok=True) try: all_images = [img for img in os.listdir(images_folder) if img.endswith(('.jpg', '.png'))] random.shuffle(all_images) # Shuffle the images for randomness # Determine split index split_idx = int(len(all_images) * 0.8) # 80% for training, 20% for validation # Split images into train and validation sets train_images = all_images[:split_idx] val_images = all_images[split_idx:] # Process training images for image_name in train_images: process_image_and_mask( image_name, source_images_folder=images_folder, dest_images_folder=train_images_folder, dest_labels_folder=train_labels_folder ) # Process validation images for image_name in val_images: process_image_and_mask( image_name, source_images_folder=images_folder, dest_images_folder=val_images_folder, dest_labels_folder=val_labels_folder ) logging.info("YOLO labels prepared, and images split into train and validation successfully.") except Exception as e: logging.error(f"Error in preparing YOLO labels: {e}") raise import random def prepare_yolo_labels(): """Convert all masks into YOLO-compatible labels and copy images to the dataset folder.""" images_folder = HISTORY_FOLDERS['images'] # Use history images as the source train_labels_folder = DATASET_FOLDERS['train_labels'] train_images_folder = DATASET_FOLDERS['train_images'] val_labels_folder = DATASET_FOLDERS['val_labels'] val_images_folder = DATASET_FOLDERS['val_images'] # Ensure destination directories exist os.makedirs(train_labels_folder, exist_ok=True) os.makedirs(train_images_folder, exist_ok=True) os.makedirs(val_labels_folder, exist_ok=True) os.makedirs(val_images_folder, exist_ok=True) try: all_images = [img for img in os.listdir(images_folder) if img.endswith(('.jpg', '.png'))] random.shuffle(all_images) # Shuffle the images for randomness # Determine split index split_idx = int(len(all_images) * 0.8) # 80% for training, 20% for validation # Split images into train and validation sets train_images = all_images[:split_idx] val_images = all_images[split_idx:] # Process training images for image_name in train_images: process_image_and_mask( image_name, source_images_folder=images_folder, dest_images_folder=train_images_folder, dest_labels_folder=train_labels_folder ) # Process validation images for image_name in val_images: process_image_and_mask( image_name, source_images_folder=images_folder, dest_images_folder=val_images_folder, dest_labels_folder=val_labels_folder ) logging.info("YOLO labels prepared, and images split into train and validation successfully.") except Exception as e: logging.error(f"Error in preparing YOLO labels: {e}") raise def process_image_and_mask(image_name, source_images_folder, dest_images_folder, dest_labels_folder): """ Process a single image and its masks, saving them in the appropriate YOLO format. """ try: image_path = os.path.join(source_images_folder, image_name) label_file_path = os.path.join(dest_labels_folder, f"{os.path.splitext(image_name)[0]}.txt") # Copy image to the destination images folder shutil.copy(image_path, os.path.join(dest_images_folder, image_name)) # Clear the label file if it exists if os.path.exists(label_file_path): os.remove(label_file_path) # Process void mask void_mask_path = os.path.join(HISTORY_FOLDERS['masks_void'], f"{os.path.splitext(image_name)[0]}.png") if os.path.exists(void_mask_path): convert_mask_to_yolo( mask_path=void_mask_path, image_path=image_path, class_id=0, # Void class output_path=label_file_path ) # Process chip mask chip_mask_path = os.path.join(HISTORY_FOLDERS['masks_chip'], f"{os.path.splitext(image_name)[0]}.png") if os.path.exists(chip_mask_path): convert_mask_to_yolo( mask_path=chip_mask_path, image_path=image_path, class_id=1, # Chip class output_path=label_file_path, append=True # Append chip annotations ) logging.info(f"Processed {image_name} into YOLO format.") except Exception as e: logging.error(f"Error processing {image_name}: {e}") raise def backup_masks_and_images(): """Backup current masks and images from history folders.""" temp_backup_paths = { 'voids': os.path.join(DATASET_FOLDERS['temp_backup'], 'masks/voids'), 'chips': os.path.join(DATASET_FOLDERS['temp_backup'], 'masks/chips'), 'images': os.path.join(DATASET_FOLDERS['temp_backup'], 'images') } # Prepare all backup directories for path in temp_backup_paths.values(): if os.path.exists(path): shutil.rmtree(path) os.makedirs(path, exist_ok=True) try: # Backup images from history for file in os.listdir(HISTORY_FOLDERS['images']): src_image_path = os.path.join(HISTORY_FOLDERS['images'], file) dst_image_path = os.path.join(temp_backup_paths['images'], file) shutil.copy(src_image_path, dst_image_path) # Backup void masks from history for file in os.listdir(HISTORY_FOLDERS['masks_void']): src_void_path = os.path.join(HISTORY_FOLDERS['masks_void'], file) dst_void_path = os.path.join(temp_backup_paths['voids'], file) shutil.copy(src_void_path, dst_void_path) # Backup chip masks from history for file in os.listdir(HISTORY_FOLDERS['masks_chip']): src_chip_path = os.path.join(HISTORY_FOLDERS['masks_chip'], file) dst_chip_path = os.path.join(temp_backup_paths['chips'], file) shutil.copy(src_chip_path, dst_chip_path) logging.info("Masks and images backed up successfully from history.") except Exception as e: logging.error(f"Error during backup: {e}") raise RuntimeError("Backup process failed.") def run_yolo_training(num_epochs=10): """Run YOLO training process.""" global training_process try: device = "cuda" if torch.cuda.is_available() else "cpu" data_cfg_path = os.path.join(BASE_DIR, "models/data.yaml") # Ensure correct YAML path logging.info(f"Starting YOLO training on {device} with {num_epochs} epochs.") logging.info(f"Using dataset configuration: {data_cfg_path}") training_command = [ "yolo", "train", f"data={data_cfg_path}", f"model={os.path.join(DATASET_FOLDERS['models'], 'best.pt')}", f"device={device}", f"epochs={num_epochs}", "project=runs", "name=train" ] training_process = subprocess.Popen( training_command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, env=os.environ.copy(), ) # Display and log output in real time for line in iter(training_process.stdout.readline, ''): print(line.strip()) logging.info(line.strip()) socketio.emit('training_update', {'message': line.strip()}) # Send updates to the frontend training_process.wait() if training_process.returncode == 0: finalize_training() # Finalize successfully completed training else: raise RuntimeError("YOLO training process failed. Check logs for details.") except Exception as e: logging.error(f"Training error: {e}") restore_backup() # Restore the dataset and masks # Emit training error event to the frontend socketio.emit('training_status', {'status': 'error', 'message': f"Training failed: {str(e)}"}) finally: update_training_status('running', False) training_process = None # Reset the process @socketio.on('cancel_training') def handle_cancel_training(): """Cancel the YOLO training process.""" global training_process, training_status if not training_status.get('running', False): socketio.emit('button_update', {'action': 'retrain'}) # Update button to retrain return try: training_process.terminate() training_process.wait() training_status['running'] = False training_status['cancelled'] = True restore_backup() cleanup_train_val_directories() # Emit button state change socketio.emit('button_update', {'action': 'retrain'}) socketio.emit('training_status', {'status': 'cancelled', 'message': 'Training was canceled by the user.'}) except Exception as e: logging.error(f"Error cancelling training: {e}") socketio.emit('training_status', {'status': 'error', 'message': str(e)}) def finalize_training(): """Finalize training by promoting the new model and cleaning up.""" try: # Locate the most recent training directory runs_dir = os.path.join(BASE_DIR, 'runs') if not os.path.exists(runs_dir): raise FileNotFoundError("Training runs directory does not exist.") # Get the latest training run folder latest_run = max( [os.path.join(runs_dir, d) for d in os.listdir(runs_dir)], key=os.path.getmtime ) weights_dir = os.path.join(latest_run, 'weights') best_model_path = os.path.join(weights_dir, 'best.pt') if not os.path.exists(best_model_path): raise FileNotFoundError(f"'best.pt' not found in {weights_dir}.") # Backup the old model old_model_folder = DATASET_FOLDERS['models_old'] os.makedirs(old_model_folder, exist_ok=True) existing_best_model = os.path.join(DATASET_FOLDERS['models'], 'best.pt') if os.path.exists(existing_best_model): timestamp = time.strftime("%Y%m%d_%H%M%S") shutil.move(existing_best_model, os.path.join(old_model_folder, f"old_{timestamp}.pt")) logging.info(f"Old model backed up to {old_model_folder}.") # Move the new model to the models directory new_model_dest = os.path.join(DATASET_FOLDERS['models'], 'best.pt') shutil.move(best_model_path, new_model_dest) logging.info(f"New model saved to {new_model_dest}.") # Notify frontend that training is completed socketio.emit('training_status', { 'status': 'completed', 'message': 'Training completed successfully! Model saved as best.pt.' }) # Clean up train/val directories cleanup_train_val_directories() logging.info("Train and validation directories cleaned up successfully.") except Exception as e: logging.error(f"Error finalizing training: {e}") # Emit error status to the frontend socketio.emit('training_status', {'status': 'error', 'message': f"Error finalizing training: {str(e)}"}) def restore_backup(): """Restore the dataset and masks from the backup.""" try: temp_backup = DATASET_FOLDERS['temp_backup'] shutil.copytree(os.path.join(temp_backup, 'masks/voids'), UPLOAD_FOLDERS['mask_voids'], dirs_exist_ok=True) shutil.copytree(os.path.join(temp_backup, 'masks/chips'), UPLOAD_FOLDERS['mask_chips'], dirs_exist_ok=True) shutil.copytree(os.path.join(temp_backup, 'images'), UPLOAD_FOLDERS['input'], dirs_exist_ok=True) logging.info("Backup restored successfully.") except Exception as e: logging.error(f"Error restoring backup: {e}") @app.route('/cancel_training', methods=['POST']) def cancel_training(): global training_process if training_process is None: logging.error("No active training process to terminate.") return jsonify({'error': 'No active training process to cancel.'}), 400 try: training_process.terminate() training_process.wait() training_process = None # Reset the process after termination # Update training status update_training_status('running', False) update_training_status('cancelled', True) # Check if the model is already saved as best.pt best_model_path = os.path.join(DATASET_FOLDERS['models'], 'best.pt') if os.path.exists(best_model_path): logging.info(f"Model already saved as best.pt at {best_model_path}.") socketio.emit('button_update', {'action': 'revert'}) # Notify frontend to revert button state else: logging.info("Training canceled, but no new model was saved.") # Restore backup if needed restore_backup() cleanup_train_val_directories() # Emit status update to frontend socketio.emit('training_status', {'status': 'cancelled', 'message': 'Training was canceled by the user.'}) return jsonify({'message': 'Training canceled and data restored successfully.'}), 200 except Exception as e: logging.error(f"Error cancelling training: {e}") return jsonify({'error': f"Failed to cancel training: {e}"}), 500 @app.route('/clear_history', methods=['POST']) def clear_history(): try: for folder in [HISTORY_FOLDERS['images'], HISTORY_FOLDERS['masks_chip'], HISTORY_FOLDERS['masks_void']]: shutil.rmtree(folder, ignore_errors=True) os.makedirs(folder, exist_ok=True) # Recreate the empty folder return jsonify({'message': 'History cleared successfully!'}), 200 except Exception as e: return jsonify({'error': f'Failed to clear history: {e}'}), 500 @app.route('/training_status', methods=['GET']) def get_training_status(): """Return the current training status.""" if training_status.get('running', False): return jsonify({'status': 'running', 'message': 'Training in progress.'}), 200 elif training_status.get('cancelled', False): return jsonify({'status': 'cancelled', 'message': 'Training was cancelled.'}), 200 return jsonify({'status': 'idle', 'message': 'No training is currently running.'}), 200 def cleanup_train_val_directories(): """Clear the train and validation directories.""" try: for folder in [DATASET_FOLDERS['train_images'], DATASET_FOLDERS['train_labels'], DATASET_FOLDERS['val_images'], DATASET_FOLDERS['val_labels']]: shutil.rmtree(folder, ignore_errors=True) # Remove folder contents os.makedirs(folder, exist_ok=True) # Recreate empty folders logging.info("Train and validation directories cleaned up successfully.") except Exception as e: logging.error(f"Error cleaning up train/val directories: {e}") if __name__ == '__main__': multiprocessing.set_start_method('spawn') # Required for multiprocessing on Windows app.run(debug=True, use_reloader=False)