import json
import logging
import multiprocessing
import os
import random
import shutil
import subprocess
import sys
import threading
import time
from threading import Lock

# Make the local sam2 and utils packages importable when running this file directly.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import numpy as np
import torch
from flask import Flask, render_template, request, jsonify
from flask_socketio import SocketIO
from PIL import Image
from ultralytics import YOLO

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from utils.helpers import (
    blend_mask_with_image,
    save_mask_as_png,
    convert_mask_to_yolo,
)
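
# Flask + Socket.IO application: interactive SAM2 point-prompt segmentation of chip
# and void regions, YOLO-based automatic segmentation, and an in-app YOLO retraining
# workflow with live progress updates pushed to the browser.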


class Predictor:
    """Thin wrapper around SAM2ImagePredictor that tracks the currently loaded image."""

    def __init__(self, model_cfg, checkpoint, device):
        self.device = device
        self.model = build_sam2(model_cfg, checkpoint, device=device)
        self.predictor = SAM2ImagePredictor(self.model)
        self.image_set = False

    def set_image(self, image):
        """Set the image for SAM prediction."""
        self.image = image
        self.predictor.set_image(image)
        self.image_set = True

    def predict(self, point_coords, point_labels, multimask_output=False):
        """Run SAM prediction."""
        if not self.image_set:
            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
        return self.predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=multimask_output
        )


app = Flask(__name__)
socketio = SocketIO(app)

BASE_DIR = os.path.abspath(os.path.dirname(__file__))
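
# Working folders for the current session: uploaded inputs, per-class raw masks,
# and blended segmentation previews served back to the browser.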
UPLOAD_FOLDERS = {
    'input': os.path.join(BASE_DIR, 'static/uploads/input'),
    'segmented_voids': os.path.join(BASE_DIR, 'static/uploads/segmented/voids'),
    'segmented_chips': os.path.join(BASE_DIR, 'static/uploads/segmented/chips'),
    'mask_voids': os.path.join(BASE_DIR, 'static/uploads/mask/voids'),
    'mask_chips': os.path.join(BASE_DIR, 'static/uploads/mask/chips'),
    'automatic_segmented': os.path.join(BASE_DIR, 'static/uploads/segmented/automatic'),
}
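
# Per-image history that doubles as the retraining pool: saved input images plus
# their chip and void masks.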
HISTORY_FOLDERS = {
    'images': os.path.join(BASE_DIR, 'static/history/images'),
    'masks_chip': os.path.join(BASE_DIR, 'static/history/masks/chip'),
    'masks_void': os.path.join(BASE_DIR, 'static/history/masks/void'),
}
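
# YOLO dataset layout and model storage; temp_backup holds a snapshot of the history
# folders so they can be restored if retraining fails or is cancelled.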
DATASET_FOLDERS = {
    'train_images': os.path.join(BASE_DIR, 'dataset/train/images'),
    'train_labels': os.path.join(BASE_DIR, 'dataset/train/labels'),
    'val_images': os.path.join(BASE_DIR, 'dataset/val/images'),
    'val_labels': os.path.join(BASE_DIR, 'dataset/val/labels'),
    'temp_backup': os.path.join(BASE_DIR, 'temp_backup'),
    'models': os.path.join(BASE_DIR, 'models'),
    'models_old': os.path.join(BASE_DIR, 'models/old'),
}

# Configure logging before the folder-creation loop below logs anything.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(os.path.join(BASE_DIR, "app.log"))
    ]
)

for folder_name, folder_path in {**UPLOAD_FOLDERS, **HISTORY_FOLDERS, **DATASET_FOLDERS}.items():
    os.makedirs(folder_path, exist_ok=True)
    logging.info(f"Ensured folder exists: {folder_name} -> {folder_path}")

training_process = None


def initialize_training_status():
    """Initialize global training status."""
    global training_status
    training_status = {'running': False, 'cancelled': False}


def persist_training_status():
    """Save training status to a file."""
    with open(os.path.join(BASE_DIR, 'training_status.json'), 'w') as status_file:
        json.dump(training_status, status_file)


def load_training_status():
    """Load training status from a file."""
    global training_status
    status_path = os.path.join(BASE_DIR, 'training_status.json')
    if os.path.exists(status_path):
        with open(status_path, 'r') as status_file:
            training_status = json.load(status_file)
    else:
        training_status = {'running': False, 'cancelled': False}


load_training_status()
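
# Disabling PyTorch's cuDNN scaled-dot-product-attention backend; assumed here to be a
# workaround for SAM2 attention kernel issues on some GPU setups.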
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0"

# SAM2 model configuration and checkpoint; a single predictor instance is shared across requests.
MODEL_CFG = r"sam2/sam2_hiera_l.yaml"
CHECKPOINT = r"sam2/checkpoints/sam2.1_hiera_large.pt"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
predictor = Predictor(MODEL_CFG, CHECKPOINT, DEVICE)
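
# YOLO model used by /automatic_segment; retraining replaces models/best.pt in place.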
YOLO_CFG = os.path.join(DATASET_FOLDERS['models'], "best.pt")
yolo_model = YOLO(YOLO_CFG)


@app.route('/')
def index():
    """Serve the main UI."""
    return render_template('index.html')


@app.route('/upload', methods=['POST'])
def upload_image():
    """Handle image uploads."""
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No file selected'}), 400

    # Keep only the basename so a crafted filename cannot escape the upload folder.
    filename = os.path.basename(file.filename)
    input_path = os.path.join(UPLOAD_FOLDERS['input'], filename)
    file.save(input_path)

    image = np.array(Image.open(input_path).convert("RGB"))
    predictor.set_image(image)

    web_accessible_url = f"/static/uploads/input/{filename}"
    print(f"Image uploaded and set for prediction: {input_path}")
    return jsonify({'image_url': web_accessible_url})


@app.route('/segment', methods=['POST'])
def segment():
    """Perform segmentation and return the blended image URL."""
    try:
        data = request.json
        points = np.array(data.get('points', []))
        labels = np.array(data.get('labels', []))
        current_class = data.get('class', 'voids')

        if not predictor.image_set:
            raise ValueError("No image set for prediction.")

        if points.size == 0 or labels.size == 0:
            raise ValueError("At least one point prompt and its label must be provided.")

        masks, _, _ = predictor.predict(
            point_coords=points,
            point_labels=labels,
            multimask_output=False
        )

        if masks is None or masks.size == 0:
            raise RuntimeError("No masks were generated by the predictor.")

        mask_folder = UPLOAD_FOLDERS.get(f'mask_{current_class}')
        segmented_folder = UPLOAD_FOLDERS.get(f'segmented_{current_class}')

        if not mask_folder or not segmented_folder:
            raise ValueError(f"Invalid class '{current_class}' provided.")

        os.makedirs(mask_folder, exist_ok=True)
        os.makedirs(segmented_folder, exist_ok=True)

        mask_path = os.path.join(mask_folder, 'raw_mask.png')
        save_mask_as_png(masks[0], mask_path)

        # Forest-green overlay for voids, dodger-blue overlay for chips.
        blend_color = [34, 139, 34] if current_class == 'voids' else [30, 144, 255]
        blended_image = blend_mask_with_image(predictor.image, masks[0], blend_color)

        blended_filename = f"blended_{current_class}.png"
        blended_path = os.path.join(segmented_folder, blended_filename)
        Image.fromarray(blended_image).save(blended_path)

        segmented_url = f"/static/uploads/segmented/{current_class}/{blended_filename}"
        logging.info(f"Segmentation completed for {current_class}. Points: {points}, Labels: {labels}")
        return jsonify({'segmented_url': segmented_url})

    except ValueError as ve:
        logging.error(f"Value error during segmentation: {ve}")
        return jsonify({'error': str(ve)}), 400

    except Exception as e:
        logging.error(f"Unexpected error during segmentation: {e}")
        return jsonify({'error': 'Segmentation failed', 'details': str(e)}), 500


@app.route('/automatic_segment', methods=['POST'])
def automatic_segment():
    """Perform automatic segmentation using YOLO."""
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No file selected'}), 400

    input_path = os.path.join(UPLOAD_FOLDERS['input'], os.path.basename(file.filename))
    file.save(input_path)

    try:
        results = yolo_model.predict(input_path, save=False, save_txt=False)
        output_folder = UPLOAD_FOLDERS['automatic_segmented']
        os.makedirs(output_folder, exist_ok=True)

        chips_data = []
        chips = []
        voids = []

        for result in results:
            # result.plot() returns a BGR array (OpenCV convention); flip to RGB for PIL.
            annotated_image = result.plot()[..., ::-1]
            result_filename = f"{file.filename.rsplit('.', 1)[0]}_pred.jpg"
            result_path = os.path.join(output_folder, result_filename)
            Image.fromarray(annotated_image).save(result_path)

            for i, label in enumerate(result.boxes.cls):
                label_name = result.names[int(label)]
                box = result.boxes.xyxy[i].cpu().numpy()
                area = float((box[2] - box[0]) * (box[3] - box[1]))

                if label_name == 'chip':
                    chips.append({'box': box, 'area': area, 'voids': []})
                elif label_name == 'void':
                    voids.append({'box': box, 'area': area})
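
        # Assign each void to the first chip whose bounding box contains the void's centroid.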
        for void in voids:
            void_centroid = [
                (void['box'][0] + void['box'][2]) / 2,
                (void['box'][1] + void['box'][3]) / 2
            ]
            for chip in chips:
                if (chip['box'][0] <= void_centroid[0] <= chip['box'][2] and
                        chip['box'][1] <= void_centroid[1] <= chip['box'][3]):
                    chip['voids'].append(void)
                    break
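
        # Per-chip statistics: total void coverage and the largest single void, each as a
        # percentage of the chip's bounding-box area.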
        for idx, chip in enumerate(chips):
            chip_area = chip['area']
            total_void_area = sum([float(void['area']) for void in chip['voids']])
            max_void_area = max([float(void['area']) for void in chip['voids']], default=0)

            void_percentage = (total_void_area / chip_area) * 100 if chip_area > 0 else 0
            max_void_percentage = (max_void_area / chip_area) * 100 if chip_area > 0 else 0

            chips_data.append({
                "chip_number": int(idx + 1),
                "chip_area": round(chip_area, 2),
                "void_percentage": round(void_percentage, 2),
                "max_void_percentage": round(max_void_percentage, 2)
            })

        segmented_url = f"/static/uploads/segmented/automatic/{result_filename}"
        return jsonify({
            "segmented_url": segmented_url,
            "table_data": {
                "image_name": file.filename,
                "chips": chips_data
            }
        })

    except Exception as e:
        print(f"Error in automatic segmentation: {e}")
        return jsonify({'error': 'Segmentation failed.'}), 500


@app.route('/save_both', methods=['POST'])
def save_both():
    """Save both the image and masks into the history folders."""
    data = request.json
    image_name = data.get('image_name')

    if not image_name:
        return jsonify({'error': 'Image name not provided'}), 400

    try:
        image_name = os.path.basename(image_name)
        print(f"Sanitized Image Name: {image_name}")

        input_image_path = os.path.join(UPLOAD_FOLDERS['input'], image_name)
        if not os.path.exists(input_image_path):
            print(f"Input image does not exist: {input_image_path}")
            return jsonify({'error': f'Input image not found: {input_image_path}'}), 404

        image_history_path = os.path.join(HISTORY_FOLDERS['images'], image_name)
        os.makedirs(os.path.dirname(image_history_path), exist_ok=True)
        shutil.copy(input_image_path, image_history_path)
        print(f"Image saved to history: {image_history_path}")

        void_mask_path = os.path.join(UPLOAD_FOLDERS['mask_voids'], 'raw_mask.png')
        if os.path.exists(void_mask_path):
            void_mask_history_path = os.path.join(HISTORY_FOLDERS['masks_void'], f"{os.path.splitext(image_name)[0]}.png")
            os.makedirs(os.path.dirname(void_mask_history_path), exist_ok=True)
            shutil.copy(void_mask_path, void_mask_history_path)
            print(f"Voids mask saved to history: {void_mask_history_path}")
        else:
            print(f"Voids mask not found: {void_mask_path}")

        chip_mask_path = os.path.join(UPLOAD_FOLDERS['mask_chips'], 'raw_mask.png')
        if os.path.exists(chip_mask_path):
            chip_mask_history_path = os.path.join(HISTORY_FOLDERS['masks_chip'], f"{os.path.splitext(image_name)[0]}.png")
            os.makedirs(os.path.dirname(chip_mask_history_path), exist_ok=True)
            shutil.copy(chip_mask_path, chip_mask_history_path)
            print(f"Chips mask saved to history: {chip_mask_history_path}")
        else:
            print(f"Chips mask not found: {chip_mask_path}")

        return jsonify({'message': 'Image and masks saved successfully!'}), 200

    except Exception as e:
        print(f"Error saving files: {e}")
        return jsonify({'error': 'Failed to save files.', 'details': str(e)}), 500


@app.route('/get_history', methods=['GET'])
def get_history():
    try:
        saved_images = os.listdir(HISTORY_FOLDERS['images'])
        return jsonify({'status': 'success', 'images': saved_images}), 200
    except Exception as e:
        return jsonify({'status': 'error', 'message': f'Failed to fetch history: {e}'}), 500


@app.route('/delete_history_item', methods=['POST'])
def delete_history_item():
    data = request.json
    image_name = data.get('image_name')

    if not image_name:
        return jsonify({'error': 'Image name not provided'}), 400

    try:
        # Use only the basename so arbitrary paths cannot be deleted.
        image_name = os.path.basename(image_name)
        image_path = os.path.join(HISTORY_FOLDERS['images'], image_name)
        if os.path.exists(image_path):
            os.remove(image_path)

        void_mask_path = os.path.join(HISTORY_FOLDERS['masks_void'], f"{os.path.splitext(image_name)[0]}.png")
        if os.path.exists(void_mask_path):
            os.remove(void_mask_path)

        chip_mask_path = os.path.join(HISTORY_FOLDERS['masks_chip'], f"{os.path.splitext(image_name)[0]}.png")
        if os.path.exists(chip_mask_path):
            os.remove(chip_mask_path)

        return jsonify({'message': f'{image_name} and associated masks deleted successfully.'}), 200
    except Exception as e:
        return jsonify({'error': f'Failed to delete files: {e}'}), 500
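

# Lock guarding training_status, which is read by request handlers and updated from the
# background training thread.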
status_lock = Lock()


def update_training_status(key, value):
    """Thread-safe update for training status."""
    with status_lock:
        training_status[key] = value


@app.route('/retrain_model', methods=['POST'])
def retrain_model():
    """Handle retrain model workflow."""
    global training_status

    if training_status.get('running', False):
        return jsonify({'error': 'Training is already in progress'}), 400

    try:
        update_training_status('running', True)
        update_training_status('cancelled', False)
        logging.info("Training status updated. Starting training workflow.")

        backup_masks_and_images()
        logging.info("Backup completed successfully.")

        prepare_yolo_labels()
        logging.info("YOLO labels prepared successfully.")

        threading.Thread(target=run_yolo_training).start()
        return jsonify({'message': 'Training started successfully!'}), 200

    except Exception as e:
        logging.error(f"Error during training preparation: {e}")
        update_training_status('running', False)
        return jsonify({'error': f"Failed to start training: {e}"}), 500


def prepare_yolo_labels():
    """Convert all masks into YOLO-compatible labels and copy images to the dataset folder."""
    images_folder = HISTORY_FOLDERS['images']
    train_labels_folder = DATASET_FOLDERS['train_labels']
    train_images_folder = DATASET_FOLDERS['train_images']
    val_labels_folder = DATASET_FOLDERS['val_labels']
    val_images_folder = DATASET_FOLDERS['val_images']

    os.makedirs(train_labels_folder, exist_ok=True)
    os.makedirs(train_images_folder, exist_ok=True)
    os.makedirs(val_labels_folder, exist_ok=True)
    os.makedirs(val_images_folder, exist_ok=True)

    try:
        all_images = [img for img in os.listdir(images_folder) if img.endswith(('.jpg', '.png'))]
        random.shuffle(all_images)
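
        # 80/20 split between training and validation images.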
        split_idx = int(len(all_images) * 0.8)

        train_images = all_images[:split_idx]
        val_images = all_images[split_idx:]

        for image_name in train_images:
            process_image_and_mask(
                image_name,
                source_images_folder=images_folder,
                dest_images_folder=train_images_folder,
                dest_labels_folder=train_labels_folder
            )

        for image_name in val_images:
            process_image_and_mask(
                image_name,
                source_images_folder=images_folder,
                dest_images_folder=val_images_folder,
                dest_labels_folder=val_labels_folder
            )

        logging.info("YOLO labels prepared, and images split into train and validation successfully.")

    except Exception as e:
        logging.error(f"Error in preparing YOLO labels: {e}")
        raise


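# The YOLO label files written below use class id 0 for voids and 1 for chips; these ids
# are assumed to match the class names listed in models/data.yaml.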
def process_image_and_mask(image_name, source_images_folder, dest_images_folder, dest_labels_folder):
    """
    Process a single image and its masks, saving them in the appropriate YOLO format.
    """
    try:
        image_path = os.path.join(source_images_folder, image_name)
        label_file_path = os.path.join(dest_labels_folder, f"{os.path.splitext(image_name)[0]}.txt")

        shutil.copy(image_path, os.path.join(dest_images_folder, image_name))

        # Start from a clean label file for this image.
        if os.path.exists(label_file_path):
            os.remove(label_file_path)

        void_mask_path = os.path.join(HISTORY_FOLDERS['masks_void'], f"{os.path.splitext(image_name)[0]}.png")
        if os.path.exists(void_mask_path):
            convert_mask_to_yolo(
                mask_path=void_mask_path,
                image_path=image_path,
                class_id=0,
                output_path=label_file_path
            )

        chip_mask_path = os.path.join(HISTORY_FOLDERS['masks_chip'], f"{os.path.splitext(image_name)[0]}.png")
        if os.path.exists(chip_mask_path):
            convert_mask_to_yolo(
                mask_path=chip_mask_path,
                image_path=image_path,
                class_id=1,
                output_path=label_file_path,
                append=True
            )

        logging.info(f"Processed {image_name} into YOLO format.")
    except Exception as e:
        logging.error(f"Error processing {image_name}: {e}")
        raise


def backup_masks_and_images():
    """Backup current masks and images from history folders."""
    temp_backup_paths = {
        'voids': os.path.join(DATASET_FOLDERS['temp_backup'], 'masks/voids'),
        'chips': os.path.join(DATASET_FOLDERS['temp_backup'], 'masks/chips'),
        'images': os.path.join(DATASET_FOLDERS['temp_backup'], 'images')
    }

    for path in temp_backup_paths.values():
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path, exist_ok=True)

    try:
        for file in os.listdir(HISTORY_FOLDERS['images']):
            src_image_path = os.path.join(HISTORY_FOLDERS['images'], file)
            dst_image_path = os.path.join(temp_backup_paths['images'], file)
            shutil.copy(src_image_path, dst_image_path)

        for file in os.listdir(HISTORY_FOLDERS['masks_void']):
            src_void_path = os.path.join(HISTORY_FOLDERS['masks_void'], file)
            dst_void_path = os.path.join(temp_backup_paths['voids'], file)
            shutil.copy(src_void_path, dst_void_path)

        for file in os.listdir(HISTORY_FOLDERS['masks_chip']):
            src_chip_path = os.path.join(HISTORY_FOLDERS['masks_chip'], file)
            dst_chip_path = os.path.join(temp_backup_paths['chips'], file)
            shutil.copy(src_chip_path, dst_chip_path)

        logging.info("Masks and images backed up successfully from history.")
    except Exception as e:
        logging.error(f"Error during backup: {e}")
        raise RuntimeError("Backup process failed.")


def run_yolo_training(num_epochs=10):
    """Run YOLO training process."""
    global training_process

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        data_cfg_path = os.path.join(BASE_DIR, "models/data.yaml")

        logging.info(f"Starting YOLO training on {device} with {num_epochs} epochs.")
        logging.info(f"Using dataset configuration: {data_cfg_path}")
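
        # Training is launched through the Ultralytics CLI in a subprocess so its stdout can
        # be streamed line by line to the browser via Socket.IO. models/data.yaml is assumed
        # to follow the standard Ultralytics dataset layout, roughly:
        #   path: <BASE_DIR>/dataset
        #   train: train/images
        #   val: val/images
        #   names:
        #     0: void
        #     1: chip
        # (class ids matching those written by process_image_and_mask).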
        training_command = [
            "yolo",
            "train",
            f"data={data_cfg_path}",
            f"model={os.path.join(DATASET_FOLDERS['models'], 'best.pt')}",
            f"device={device}",
            f"epochs={num_epochs}",
            "project=runs",
            "name=train"
        ]

        training_process = subprocess.Popen(
            training_command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            env=os.environ.copy(),
        )

        for line in iter(training_process.stdout.readline, ''):
            print(line.strip())
            logging.info(line.strip())
            socketio.emit('training_update', {'message': line.strip()})

        training_process.wait()

        if training_process.returncode == 0:
            finalize_training()
        else:
            raise RuntimeError("YOLO training process failed. Check logs for details.")
    except Exception as e:
        logging.error(f"Training error: {e}")
        restore_backup()

        socketio.emit('training_status', {'status': 'error', 'message': f"Training failed: {str(e)}"})
    finally:
        update_training_status('running', False)
        training_process = None


@socketio.on('cancel_training')
def handle_cancel_training():
    """Cancel the YOLO training process."""
    global training_process, training_status

    if not training_status.get('running', False):
        socketio.emit('button_update', {'action': 'retrain'})
        return

    try:
        # The subprocess may not have been spawned yet if cancellation arrives early.
        if training_process is not None:
            training_process.terminate()
            training_process.wait()
        training_status['running'] = False
        training_status['cancelled'] = True

        restore_backup()
        cleanup_train_val_directories()

        socketio.emit('button_update', {'action': 'retrain'})
        socketio.emit('training_status', {'status': 'cancelled', 'message': 'Training was canceled by the user.'})
    except Exception as e:
        logging.error(f"Error cancelling training: {e}")
        socketio.emit('training_status', {'status': 'error', 'message': str(e)})


def finalize_training():
    """Finalize training by promoting the new model and cleaning up."""
    try:
        runs_dir = os.path.join(BASE_DIR, 'runs')
        if not os.path.exists(runs_dir):
            raise FileNotFoundError("Training runs directory does not exist.")
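
        # Promote weights/best.pt from the most recently modified run directory.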
        latest_run = max(
            [os.path.join(runs_dir, d) for d in os.listdir(runs_dir)],
            key=os.path.getmtime
        )
        weights_dir = os.path.join(latest_run, 'weights')
        best_model_path = os.path.join(weights_dir, 'best.pt')

        if not os.path.exists(best_model_path):
            raise FileNotFoundError(f"'best.pt' not found in {weights_dir}.")

        old_model_folder = DATASET_FOLDERS['models_old']
        os.makedirs(old_model_folder, exist_ok=True)
        existing_best_model = os.path.join(DATASET_FOLDERS['models'], 'best.pt')

        if os.path.exists(existing_best_model):
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            shutil.move(existing_best_model, os.path.join(old_model_folder, f"old_{timestamp}.pt"))
            logging.info(f"Old model backed up to {old_model_folder}.")

        new_model_dest = os.path.join(DATASET_FOLDERS['models'], 'best.pt')
        shutil.move(best_model_path, new_model_dest)
        logging.info(f"New model saved to {new_model_dest}.")

        socketio.emit('training_status', {
            'status': 'completed',
            'message': 'Training completed successfully! Model saved as best.pt.'
        })

        cleanup_train_val_directories()
        logging.info("Train and validation directories cleaned up successfully.")

    except Exception as e:
        logging.error(f"Error finalizing training: {e}")
        socketio.emit('training_status', {'status': 'error', 'message': f"Error finalizing training: {str(e)}"})


def restore_backup():
    """Restore the dataset and masks from the backup."""
    try:
        temp_backup = DATASET_FOLDERS['temp_backup']
        shutil.copytree(os.path.join(temp_backup, 'masks/voids'), UPLOAD_FOLDERS['mask_voids'], dirs_exist_ok=True)
        shutil.copytree(os.path.join(temp_backup, 'masks/chips'), UPLOAD_FOLDERS['mask_chips'], dirs_exist_ok=True)
        shutil.copytree(os.path.join(temp_backup, 'images'), UPLOAD_FOLDERS['input'], dirs_exist_ok=True)
        logging.info("Backup restored successfully.")
    except Exception as e:
        logging.error(f"Error restoring backup: {e}")


@app.route('/cancel_training', methods=['POST'])
def cancel_training():
    global training_process

    if training_process is None:
        logging.error("No active training process to terminate.")
        return jsonify({'error': 'No active training process to cancel.'}), 400

    try:
        training_process.terminate()
        training_process.wait()
        training_process = None

        update_training_status('running', False)
        update_training_status('cancelled', True)

        best_model_path = os.path.join(DATASET_FOLDERS['models'], 'best.pt')
        if os.path.exists(best_model_path):
            logging.info(f"Model already saved as best.pt at {best_model_path}.")
            socketio.emit('button_update', {'action': 'revert'})
        else:
            logging.info("Training canceled, but no new model was saved.")

        restore_backup()
        cleanup_train_val_directories()

        socketio.emit('training_status', {'status': 'cancelled', 'message': 'Training was canceled by the user.'})
        return jsonify({'message': 'Training canceled and data restored successfully.'}), 200

    except Exception as e:
        logging.error(f"Error cancelling training: {e}")
        return jsonify({'error': f"Failed to cancel training: {e}"}), 500


@app.route('/clear_history', methods=['POST'])
def clear_history():
    try:
        for folder in [HISTORY_FOLDERS['images'], HISTORY_FOLDERS['masks_chip'], HISTORY_FOLDERS['masks_void']]:
            shutil.rmtree(folder, ignore_errors=True)
            os.makedirs(folder, exist_ok=True)
        return jsonify({'message': 'History cleared successfully!'}), 200
    except Exception as e:
        return jsonify({'error': f'Failed to clear history: {e}'}), 500


@app.route('/training_status', methods=['GET'])
def get_training_status():
    """Return the current training status."""
    if training_status.get('running', False):
        return jsonify({'status': 'running', 'message': 'Training in progress.'}), 200
    elif training_status.get('cancelled', False):
        return jsonify({'status': 'cancelled', 'message': 'Training was cancelled.'}), 200
    return jsonify({'status': 'idle', 'message': 'No training is currently running.'}), 200


def cleanup_train_val_directories():
    """Clear the train and validation directories."""
    try:
        for folder in [DATASET_FOLDERS['train_images'], DATASET_FOLDERS['train_labels'],
                       DATASET_FOLDERS['val_images'], DATASET_FOLDERS['val_labels']]:
            shutil.rmtree(folder, ignore_errors=True)
            os.makedirs(folder, exist_ok=True)
        logging.info("Train and validation directories cleaned up successfully.")
    except Exception as e:
        logging.error(f"Error cleaning up train/val directories: {e}")


if __name__ == '__main__':
    multiprocessing.set_start_method('spawn')
    # socketio.run starts the Flask app together with the Socket.IO server used for live
    # training updates; plain app.run would not serve the Socket.IO traffic reliably.
    socketio.run(app, debug=True, use_reloader=False)