from flask import Flask, request, jsonify, render_template, url_for
from flask_socketio import SocketIO
import threading
from ultralytics import YOLO
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os
from werkzeug.utils import secure_filename
import logging
import json
import shutil
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
app = Flask(__name__)
socketio = SocketIO(app)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration
class Config:
    BASE_DIR = os.path.abspath(os.path.dirname(__file__))
    UPLOAD_FOLDER = os.path.join(BASE_DIR, 'static', 'uploads')
    SAM_RESULT_FOLDER = os.path.join(BASE_DIR, 'static', 'sam', 'sam_results')
    YOLO_RESULT_FOLDER = os.path.join(BASE_DIR, 'static', 'yolo', 'yolo_results')
    YOLO_TRAIN_IMAGE_FOLDER = os.path.join(BASE_DIR, 'static', 'yolo', 'dataset_yolo', 'train', 'images')
    YOLO_TRAIN_LABEL_FOLDER = os.path.join(BASE_DIR, 'static', 'yolo', 'dataset_yolo', 'train', 'labels')
    AREA_DATA_FOLDER = os.path.join(BASE_DIR, 'static', 'yolo', 'area_data')
    ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
    MAX_CONTENT_LENGTH = 16 * 1024 * 1024  # 16 MB max file size
    SAM_2 = os.path.join(BASE_DIR, 'static', 'sam', 'sam2.1_hiera_tiny.pt')
    YOLO_PATH = os.path.join(BASE_DIR, 'static', 'yolo', 'model_yolo.pt')
    RETRAINED_MODEL_PATH = os.path.join(BASE_DIR, 'static', 'yolo', 'model_retrained.pt')
    DATA_PATH = os.path.join(BASE_DIR, 'static', 'yolo', 'dataset_yolo', 'data.yaml')

app.config.from_object(Config)

# Ensure directories exist
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['SAM_RESULT_FOLDER'], exist_ok=True)
os.makedirs(app.config['YOLO_RESULT_FOLDER'], exist_ok=True)
os.makedirs(app.config['YOLO_TRAIN_IMAGE_FOLDER'], exist_ok=True)
os.makedirs(app.config['YOLO_TRAIN_LABEL_FOLDER'], exist_ok=True)
os.makedirs(app.config['AREA_DATA_FOLDER'], exist_ok=True)
# Initialize the YOLO model
try:
    model = YOLO(app.config['YOLO_PATH'])
except Exception as e:
    logger.error(f"Failed to initialize YOLO model: {str(e)}")
    raise

# Initialize the SAM2 predictor
try:
    sam2_checkpoint = app.config['SAM_2']
    model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
    predictor = SAM2ImagePredictor(sam2_model)
except Exception as e:
    logger.error(f"Failed to initialize SAM2 model: {str(e)}")
    raise
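
# Both models above are pinned to the CPU. A minimal sketch of dynamic device
# selection, assuming torch is available (it is a dependency of both
# ultralytics and sam2); an optional variant, not part of the original setup:
#
#     import torch
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)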
def allowed_file(filename):
    """Return True if the filename has an allowed image extension."""
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in app.config['ALLOWED_EXTENSIONS']
def scale_coordinates(coords, original_dims, target_dims):
    """
    Scale point coordinates from one dimension space to another.

    Args:
        coords: List of [x, y] coordinates.
        original_dims: Tuple of (width, height) of the original space.
        target_dims: Tuple of (width, height) of the target space.

    Returns:
        List of scaled [x, y] coordinates.
    """
    scale_x = target_dims[0] / original_dims[0]
    scale_y = target_dims[1] / original_dims[1]
    return [
        [int(coord[0] * scale_x), int(coord[1] * scale_y)]
        for coord in coords
    ]
def scale_box(box, original_dims, target_dims):
    """
    Scale bounding box coordinates from one dimension space to another.

    Args:
        box: List of [x1, y1, x2, y2] coordinates.
        original_dims: Tuple of (width, height) of the original space.
        target_dims: Tuple of (width, height) of the target space.

    Returns:
        List of scaled [x1, y1, x2, y2] coordinates.
    """
    scale_x = target_dims[0] / original_dims[0]
    scale_y = target_dims[1] / original_dims[1]
    return [
        int(box[0] * scale_x),  # x1
        int(box[1] * scale_y),  # y1
        int(box[2] * scale_x),  # x2
        int(box[3] * scale_y)   # y2
    ]
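
# Neither helper is called elsewhere in this file; they appear intended for
# mapping browser-canvas coordinates to image pixels. A quick illustration
# with hypothetical values:
#
#     # A click at (100, 50) on a 640x480 preview of a 1280x960 image
#     scale_coordinates([[100, 50]], (640, 480), (1280, 960))   # -> [[200, 100]]
#     scale_box([10, 20, 110, 120], (640, 480), (1280, 960))    # -> [20, 40, 220, 240]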
def retrain_model_fn():
    """Retrain the YOLO model one epoch at a time, emitting progress over Socket.IO."""
    # Parameters for retraining
    data_path = app.config['DATA_PATH']
    epochs = 5
    img_size = 640
    batch_size = 8
    # Train one epoch per call so progress can be reported after each epoch.
    # Note: each model.train() call launches a separate one-epoch run rather
    # than resuming a single training schedule, so this is a simplification.
    for epoch in range(epochs):
        model.train(
            data=data_path,
            epochs=1,  # One epoch per call to get individual progress
            imgsz=img_size,
            batch=batch_size,
            device="cpu"  # Adjust based on system capabilities
        )
        # Emit an update to the client after each epoch
        socketio.emit('training_update', {
            'epoch': epoch + 1,
            'status': f"Epoch {epoch + 1} complete"
        })
    # Emit a message once training is complete
    socketio.emit('training_complete', {'status': "Retraining complete"})
    model.save(app.config['YOLO_PATH'])
    logger.info("Model retrained successfully")
@app.route('/')  # assumed route path
def index():
    return render_template('index.html')

@app.route('/yolo')  # assumed route path
def yolo():
    return render_template('yolo.html')
@app.route('/upload_sam', methods=['POST'])  # assumed route path
def upload_sam_file():
    """
    Handle a SAM image upload and embed the image into the predictor instance.

    Returns:
        JSON response with 'message', 'image_url', 'filename', and 'dimensions' keys
        on success, or an 'error' key with an appropriate error message on failure.
    """
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file part'}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No selected file'}), 400
        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file type. Allowed types: PNG, JPG, JPEG'}), 400
        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)
        # Set the image on the predictor right after upload (this computes the
        # SAM2 image embedding once, so later point/box prompts are fast)
        image = cv2.imread(filepath)
        if image is None:
            return jsonify({'error': 'Failed to load uploaded image'}), 500
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        predictor.set_image(image)
        logger.info("Image embedded successfully")
        # Get image dimensions
        height, width = image.shape[:2]
        image_url = url_for('static', filename=f'uploads/{filename}')
        logger.info(f"File uploaded successfully: {filepath}")
        return jsonify({
            'message': 'File uploaded successfully',
            'image_url': image_url,
            'filename': filename,
            'dimensions': {
                'width': width,
                'height': height
            }
        })
    except Exception as e:
        logger.error(f"Upload error: {str(e)}")
        return jsonify({'error': 'Server error during upload'}), 500
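
# A minimal sketch of exercising the upload endpoint with the requests
# library (the '/upload_sam' path matches the assumed decorator above;
# 'board.png' is a hypothetical test image):
#
#     import requests
#
#     with open('board.png', 'rb') as f:
#         resp = requests.post('http://localhost:5001/upload_sam',
#                              files={'file': f})
#     print(resp.json())   # {'message': ..., 'image_url': ..., 'dimensions': {...}}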
@app.route('/upload_yolo', methods=['POST'])  # assumed route path
def upload_yolo_file():
    """
    Upload an image file for YOLO classification.

    This endpoint accepts a POST request containing a single image file. The
    file is saved to the uploads folder for later classification.

    Returns a JSON response with the following keys:
    - message: a success message
    - image_url: the URL of the uploaded image
    - filename: the name of the uploaded file

    If an error occurs, the JSON response contains an 'error' key with a
    descriptive error message.
    """
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file part'}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No selected file'}), 400
        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file type. Allowed types: PNG, JPG, JPEG'}), 400
        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)
        image_url = url_for('static', filename=f'uploads/{filename}')
        logger.info(f"File uploaded successfully: {filepath}")
        return jsonify({
            'message': 'File uploaded successfully',
            'image_url': image_url,
            'filename': filename,
        })
    except Exception as e:
        logger.error(f"Upload error: {str(e)}")
        return jsonify({'error': 'Server error during upload'}), 500
@app.route('/generate_mask', methods=['POST'])  # assumed route path
def generate_mask():
    """
    Generate segmentation masks for a given image using the SAM2 predictor.

    Request JSON keys:
    - filename: the name of the image file
    - void_points: a list of normalized [x, y] points marking voids
    - component_boxes: a list of normalized [x1, y1, x2, y2] boxes marking components

    Response JSON keys:
    - status: a string indicating the status of the request
    - train_image_url: the URL of the saved training image
    - result_path: the URL of the saved result image
    """
    try:
        data = request.json
        normalized_void_points = data.get('void_points', [])
        normalized_component_boxes = data.get('component_boxes', [])
        filename = data.get('filename', '')
        if not filename:
            return jsonify({'error': 'No filename provided'}), 400
        image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        if not os.path.exists(image_path):
            return jsonify({'error': 'Image file not found'}), 404
        # Read image
        image = cv2.imread(image_path)
        if image is None:
            return jsonify({'error': 'Failed to load image'}), 500
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_height, image_width = image.shape[:2]
        # Denormalize coordinates back to pixel values
        void_points = [
            [int(point[0] * image_width), int(point[1] * image_height)]
            for point in normalized_void_points
        ]
        logger.info(f"Void points: {void_points}")
        component_boxes = [
            [
                int(box[0] * image_width),
                int(box[1] * image_height),
                int(box[2] * image_width),
                int(box[3] * image_height)
            ]
            for box in normalized_component_boxes
        ]
        logger.info(f"Component boxes: {component_boxes}")
        # Collect individual void masks
        void_masks = []
        # Process void points one by one
        for point in void_points:
            # Convert the point to the expected format: an [N, 2] array
            point_coord = np.array([[point[0], point[1]]])
            point_label = np.array([1])  # Single foreground label
            masks, scores, _ = predictor.predict(
                point_coords=point_coord,
                point_labels=point_label,
                multimask_output=True  # Get multiple candidate masks
            )
            if len(masks) > 0:  # Check if any masks were generated
                # Keep the mask with the highest score
                best_mask_idx = np.argmax(scores)
                void_masks.append(masks[best_mask_idx])
                logger.info(f"Processed void point {point} with score {scores[best_mask_idx]}")
        # Process component boxes
        component_masks = []
        if component_boxes:
            for box in component_boxes:
                # Convert the box to a [2, 2] array of corner points
                box_np = np.array([[box[0], box[1]], [box[2], box[3]]])
                masks, scores, _ = predictor.predict(
                    box=box_np,
                    multimask_output=True
                )
                if len(masks) > 0:
                    best_mask_idx = np.argmax(scores)
                    component_masks.append(masks[best_mask_idx])
                    logger.info(f"Processed component box {box}")
        # Create a visualization overlaying the masks on the image
        combined_image = image.copy()
        # Font settings for labels
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.6
        font_color = (0, 0, 0)  # Black text
        font_thickness = 1
        background_color = (255, 255, 255)  # White background for text

        # Helper function to get the bounding box coordinates of a mask
        def get_bounding_box(mask):
            # np.where returns (row, col) indices, i.e. (y, x)
            coords = np.column_stack(np.where(mask))
            y_min, x_min = coords.min(axis=0)
            y_max, x_max = coords.max(axis=0)
            return (x_min, y_min, x_max, y_max)

        # Helper function to add text with a filled background
        def put_text_with_background(img, text, pos):
            # Calculate text size
            (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, font_thickness)
            # Define the rectangle coordinates for the background
            background_tl = (pos[0], pos[1] - text_h - 2)
            background_br = (pos[0] + text_w, pos[1] + 2)
            # Draw a white rectangle as the background
            cv2.rectangle(img, background_tl, background_br, background_color, -1)
            # Put the text over the background rectangle
            cv2.putText(img, text, pos, font, font_scale, font_color, font_thickness, cv2.LINE_AA)

        def get_safe_label_position(x_min, y_min, x_max, y_max, text_w, text_h, img_width, img_height):
            # Default to the top-right of the bounding box
            x_pos = min(x_max, img_width - text_w - 10)  # Keep a 10px margin from the right
            y_pos = max(y_min + text_h + 5, text_h + 5)  # Keep a 5px margin from the top
            return x_pos, y_pos
        # Apply void masks with a red overlay
        for mask in void_masks:
            mask = mask.astype(bool)
            combined_image[mask, 0] = np.clip(0.5 * image[mask, 0] + 0.5 * 255, 0, 255)  # Boost red channel
            combined_image[mask, 1] = np.clip(0.5 * image[mask, 1], 0, 255)  # Reduce green channel
            combined_image[mask, 2] = np.clip(0.5 * image[mask, 2], 0, 255)  # Reduce blue channel
            logger.info("Void mask drawn")
        # Apply component masks with a green overlay
        for mask in component_masks:
            mask = mask.astype(bool)
            # Only apply green where there is no red (void) overlay
            if void_masks:
                non_red_area = mask & ~np.any(void_masks, axis=0)
            else:
                non_red_area = mask
            combined_image[non_red_area, 0] = np.clip(0.5 * image[non_red_area, 0], 0, 255)  # Reduce red channel
            combined_image[non_red_area, 1] = np.clip(0.5 * image[non_red_area, 1] + 0.5 * 255, 0, 255)  # Boost green channel
            combined_image[non_red_area, 2] = np.clip(0.5 * image[non_red_area, 2], 0, 255)  # Reduce blue channel
            logger.info("Component mask drawn")
        # Add labels on top of the masks
        for i, mask in enumerate(void_masks):
            x_min, y_min, x_max, y_max = get_bounding_box(mask)
            (text_w, text_h), _ = cv2.getTextSize("Void", font, font_scale, font_thickness)
            label_position = get_safe_label_position(x_min, y_min, x_max, y_max, text_w, text_h,
                                                     combined_image.shape[1], combined_image.shape[0])
            put_text_with_background(combined_image, f"Void {i + 1}", label_position)
        for i, mask in enumerate(component_masks):
            x_min, y_min, x_max, y_max = get_bounding_box(mask)
            (text_w, text_h), _ = cv2.getTextSize("Component", font, font_scale, font_thickness)
            label_position = get_safe_label_position(x_min, y_min, x_max, y_max, text_w, text_h,
                                                     combined_image.shape[1], combined_image.shape[0])
            put_text_with_background(combined_image, f"Component {i + 1}", label_position)
        # Build YOLO-format segmentation rows: class_id followed by normalized polygon points
        mask_coordinates = []
        for mask in void_masks:
            # Get contours from the mask
            contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            # Image dimensions
            height, width = mask.shape
            # For each contour, extract the normalized coordinates
            for contour in contours:
                contour_points = contour.reshape(-1, 2)  # Flatten to (N, 2), where N is the number of points
                normalized_points = contour_points / [width, height]  # Normalize to (0, 1)
                class_id = 1  # 1 for voids
                row = [class_id] + normalized_points.flatten().tolist()  # Flatten and prepend the class
                mask_coordinates.append(row)
        for mask in component_masks:
            # Get contours from the mask
            contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            # Keep only the largest contour
            contours = sorted(contours, key=cv2.contourArea, reverse=True)
            largest_contour = [contours[0]] if contours else []
            # Image dimensions
            height, width = mask.shape
            # For each contour, extract the normalized coordinates
            for contour in largest_contour:
                contour_points = contour.reshape(-1, 2)  # Flatten to (N, 2), where N is the number of points
                normalized_points = contour_points / [width, height]  # Normalize to (0, 1)
                class_id = 0  # 0 for components
                row = [class_id] + normalized_points.flatten().tolist()  # Flatten and prepend the class
                mask_coordinates.append(row)
        # YOLO expects the label file to share the image's base name with a .txt extension
        mask_coordinates_filename = f'{os.path.splitext(filename)[0]}.txt'
        mask_coordinates_path = os.path.join(app.config['YOLO_TRAIN_LABEL_FOLDER'], mask_coordinates_filename)
        with open(mask_coordinates_path, "w") as file:
            for row in mask_coordinates:
                # Join row elements with spaces and write one annotation per line
                file.write(" ".join(map(str, row)) + "\n")
        # Save the training image
        train_image_filepath = os.path.join(app.config['YOLO_TRAIN_IMAGE_FOLDER'], filename)
        shutil.copy(image_path, train_image_filepath)
        train_image_url = url_for('static', filename=f'yolo/dataset_yolo/train/images/{filename}')
        # Save the visualization
        result_filename = f'segmented_{filename}'
        result_path = os.path.join(app.config['SAM_RESULT_FOLDER'], result_filename)
        plt.imsave(result_path, combined_image)
        logger.info("Mask generation completed successfully")
        return jsonify({
            'status': 'success',
            'train_image_url': train_image_url,
            'result_path': url_for('static', filename=f'sam/sam_results/{result_filename}')
        })
    except Exception as e:
        logger.error(f"Mask generation error: {str(e)}")
        return jsonify({'error': str(e)}), 500
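
# A minimal sketch of a request to this endpoint (the '/generate_mask' path
# matches the assumed decorator above; the filename and coordinates are
# hypothetical, normalized to the [0, 1] range the handler expects):
#
#     import requests
#
#     payload = {
#         'filename': 'board.png',                       # must already be uploaded
#         'void_points': [[0.42, 0.37], [0.58, 0.61]],   # normalized [x, y]
#         'component_boxes': [[0.10, 0.15, 0.90, 0.85]]  # normalized [x1, y1, x2, y2]
#     }
#     resp = requests.post('http://localhost:5001/generate_mask', json=payload)
#     print(resp.json())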
@app.route('/classify', methods=['POST'])  # assumed route path
def classify():
    """
    Classify an image and return the classification result, area data, and the annotated image.

    The request body should be a JSON object with a single key, 'filename',
    specifying the image file to classify.

    Returns a JSON object with the following keys:
    - status: 'success' if the classification succeeds.
    - result_path: URL of the annotated image.
    - area_data: a list of dictionaries with the area and void-overlap statistics for each component.
    - area_data_path: URL of the JSON file containing the area data.

    On error, returns a JSON object with a single 'error' key containing the error message.
    """
    try:
        data = request.json
        filename = data.get('filename', '')
        if not filename:
            return jsonify({'error': 'No filename provided'}), 400
        image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        if not os.path.exists(image_path):
            return jsonify({'error': 'Image file not found'}), 404
        # Read image
        image = cv2.imread(image_path)
        if image is None:
            return jsonify({'error': 'Failed to load image'}), 500
        results = model(image)
        result = results[0]
        component_masks = []
        void_masks = []
        # Extract masks and labels from the results (result.masks is None when
        # nothing was detected)
        if result.masks is not None:
            for mask, label in zip(result.masks.data, result.boxes.cls):
                mask_array = mask.cpu().numpy().astype(bool)  # Convert to a binary (boolean) mask
                if label == 1:  # Label '1' represents a void
                    void_masks.append(mask_array)
                elif label == 0:  # Label '0' represents a component
                    component_masks.append(mask_array)
        # Calculate area and overlap statistics
        area_data = []
        for i, component_mask in enumerate(component_masks):
            component_area = np.sum(component_mask).item()  # Total component area in pixels
            void_area_within_component = 0
            max_void_area_percentage = 0
            # Calculate the overlap of each void mask with the component mask
            for void_mask in void_masks:
                overlap_area = np.sum(void_mask & component_mask).item()  # Overlapping area in pixels
                void_area_within_component += overlap_area
                void_area_percentage = (overlap_area / component_area) * 100 if component_area > 0 else 0
                max_void_area_percentage = max(max_void_area_percentage, void_area_percentage)
            # Append data for this component
            area_data.append({
                'Image': filename,
                'Component': f'Component {i + 1}',
                'Area': component_area,
                'Void Area (pixels)': void_area_within_component,
                'Void Area %': void_area_within_component / component_area * 100 if component_area > 0 else 0,
                'Max Void Area %': max_void_area_percentage
            })
        area_data_filename = f'area_data_{os.path.basename(filename)}.json'  # Unique per image
        area_data_path = os.path.join(app.config['AREA_DATA_FOLDER'], area_data_filename)
        with open(area_data_path, 'w') as json_file:
            json.dump(area_data, json_file, indent=4)
        # result.plot() returns a BGR image; convert to RGB before saving with matplotlib
        annotated_image = cv2.cvtColor(result.plot(), cv2.COLOR_BGR2RGB)
        output_filename = f'output_{filename}'
        output_image_path = os.path.join(app.config['YOLO_RESULT_FOLDER'], output_filename)
        plt.imsave(output_image_path, annotated_image)
        logger.info("Classification completed successfully")
        return jsonify({
            'status': 'success',
            'result_path': url_for('static', filename=f'yolo/yolo_results/{output_filename}'),
            'area_data': area_data,
            'area_data_path': url_for('static', filename=f'yolo/area_data/{area_data_filename}')
        })
    except Exception as e:
        logger.error(f"Classification error: {str(e)}")
        return jsonify({'error': str(e)}), 500
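
# The void-percentage arithmetic above reduces to boolean mask algebra. A tiny
# self-contained check with toy 4x4 masks (illustrative values only):
#
#     import numpy as np
#
#     component = np.zeros((4, 4), dtype=bool)
#     component[1:4, 1:4] = True          # 9-pixel component
#     void = np.zeros((4, 4), dtype=bool)
#     void[2:4, 2:4] = True               # 4-pixel void, fully inside
#
#     overlap = np.sum(void & component).item()        # 4
#     pct = overlap / np.sum(component).item() * 100   # ~44.4%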
# In-memory status for the retraining endpoint
retraining_status = {
    'status': 'idle',
    'progress': None,
    'message': None
}
@app.route('/retrain', methods=['GET', 'POST'])  # assumed route path
def start_retraining():
    """
    Start the model retraining process.

    On POST, start retraining in a separate thread.
    On GET, render the retraining page.

    Returns:
        A JSON response with the status of the retraining process, or a rendered HTML page.
    """
    if request.method == 'POST':
        # Reset status
        global retraining_status
        retraining_status['status'] = 'in_progress'
        retraining_status['progress'] = 'Initializing'
        # Start retraining in a separate thread so the request returns immediately
        threading.Thread(target=retrain_model_fn).start()
        return jsonify({'status': 'started'})
    else:
        # GET request: render the retraining page
        return render_template('retrain.html')
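
# Kicking off retraining from a client is a plain POST (the '/retrain' path
# matches the assumed decorator above):
#
#     import requests
#     requests.post('http://localhost:5001/retrain').json()   # {'status': 'started'}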
# Event handler for client connection
@socketio.on('connect')
def handle_connect():
    print('Client connected')
if __name__ == '__main__':
    # Run through Socket.IO's server wrapper (rather than app.run) so
    # WebSocket events are served correctly
    socketio.run(app, port=5001, debug=True)