Spaces:
Runtime error
Runtime error
| # Gradio App Code (based on paste.txt) with Triton Integration and Fallback | |
| import psutil | |
| import gradio as gr | |
| import numpy as np | |
| import os | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| from huggingface_hub import snapshot_download | |
| import rasterio | |
| from rasterio.enums import Resampling | |
| from rasterio.plot import reshape_as_image | |
| import sys | |
| import time # For potential timeouts/delays | |
# --- Triton Client Imports ---
# tritonclient is optional. When it is missing, the flag below is False and the
# module aliases are set to None so later references do not raise NameError.
try:
    import tritonclient.http as httpclient
    import tritonclient.utils as triton_utils  # provides InferenceServerException
except ImportError:
    print("WARNING: tritonclient is not installed. Triton inference will not be available.")
    print("Install using: pip install tritonclient[all]")
    TRITON_CLIENT_AVAILABLE = False
    httpclient = None
    triton_utils = None
else:
    TRITON_CLIENT_AVAILABLE = True
# --- Configuration ---
# Fetch the whole Hugging Face repository into the working directory. The
# local fallback model package and the example JP2 files are read from it below.
repo_id = "truthdotphd/cloud-detection"
repo_subdir = "."
print(f"Downloading/Checking Hugging Face repo '{repo_id}'...")
repo_dir = snapshot_download(
    repo_id=repo_id,
    local_dir=repo_subdir,
    # Materialise real files instead of symlinks into the HF cache.
    # NOTE(review): newer huggingface_hub releases deprecate this argument;
    # confirm against the pinned version.
    local_dir_use_symlinks=False,
)
print(f"Repo downloaded/cached at: {repo_dir}")
# Make top-level modules shipped in the repository importable.
sys.path.append(repo_dir)
# Import predict_from_array for the LOCAL fallback path. When the repository
# ships an "omnicloudmask" folder, that folder is added to sys.path as well.
omnicloudmask_path = os.path.join(repo_dir, "omnicloudmask")
try:
    if os.path.isdir(omnicloudmask_path):
        sys.path.append(omnicloudmask_path)
    from omnicloudmask import predict_from_array
except ImportError as e:
    print(f"ERROR: Could not import local 'predict_from_array' from omnicloudmask: {e}")
    print("Local fallback will not be available.")
    LOCAL_MODEL_AVAILABLE = False
    predict_from_array = None  # placeholder so references stay valid
else:
    LOCAL_MODEL_AVAILABLE = True
    print("Local omnicloudmask module loaded successfully.")
# --- Triton Server Configuration ---
# The server address defaults to the public demo host. It can be overridden with
# the TRITON_IP environment variable, so the app can target another deployment
# without code edits. A gRPC endpoint would be on port 8001 if needed later.
TRITON_IP = os.environ.get("TRITON_IP", "206.123.129.87")
HTTP_TRITON_URL = f"{TRITON_IP}:8000"
TRITON_MODEL_NAME = "cloud-detection"  # must match the deployed model name
TRITON_INPUT_NAME = "input_jp2_bytes"  # must match the model's config.pbtxt
TRITON_OUTPUT_NAME = "output_mask"  # must match the model's config.pbtxt
TRITON_TIMEOUT_SECONDS = 300  # 5 minutes, used for connection/network timeouts
| # --- Utility Functions (mostly from paste.txt) --- | |
def visualize_rgb(red_file, green_file, blue_file):
    """
    Build a gamma-enhanced true-colour preview from separate R, G, B band files.

    Each band is divided by its 98th percentile of positive values (1.0 when it
    has no positive pixels), clipped to [0, 1], gamma-corrected with gamma 1.8
    and converted to uint8.

    Args:
        red_file, green_file, blue_file: Band files, given either as filesystem
            path strings or as objects exposing a ``.name`` path.

    Returns:
        np.ndarray | None: (H, W, 3) uint8 image, or None when a band is missing,
        the bands differ in shape, or loading fails.
    """
    if not all([red_file, green_file, blue_file]):
        return None
    try:
        # Accept path strings or file wrappers, matching process_satellite_images.
        paths = [f if isinstance(f, str) else f.name for f in (red_file, green_file, blue_file)]
        bands = [load_band(path) for path in paths]  # R, G, B order

        shapes = [band.shape for band in bands]
        if len(set(shapes)) != 1:
            print(f"Error generating RGB preview: band shapes differ (R, G, B) = {shapes}")
            return None

        epsilon = 1e-10  # avoids division by zero for an all-zero band
        channels = []
        for band in bands:
            positive = band[band > 0]
            band_max = np.percentile(positive, 98) if positive.size else 1.0
            channels.append(np.clip(band / (band_max + epsilon), 0, 1))
        rgb_image = np.stack(channels, axis=-1).astype(np.float32)

        gamma = 1.8  # brightens mid-tones of typically dark reflectance data
        rgb_image_enhanced = np.power(rgb_image, 1 / gamma)
        return (rgb_image_enhanced * 255).astype(np.uint8)
    except Exception as e:
        print(f"Error generating RGB preview: {e}")
        import traceback
        traceback.print_exc()
        return None
def visualize_jp2(file_path):
    """
    Render band 1 of a raster file as a viridis-coloured uint8 RGB image.

    Args:
        file_path: Path to a raster file readable by rasterio (e.g. a JP2 band).

    Returns:
        np.ndarray | None: (H, W, 3) uint8 image. The image is all black when the
        band is constant (nothing to normalise). Returns None if reading or
        rendering fails.
    """
    try:
        with rasterio.open(file_path) as src:
            # Convert to float64 before any arithmetic. In the native integer
            # dtype, (data - min) and (max - min) can overflow and wrap
            # (e.g. an int16 band spanning more than 32767).
            data = src.read(1).astype(np.float64)
            height, width = src.height, src.width

        data_min = float(np.min(data))
        data_max = float(np.max(data))
        if data_max == data_min:  # constant band, including all-zero
            print(f"Warning: Data in {file_path} is constant or zero. Cannot normalize.")
            return np.zeros((height, width, 3), dtype=np.uint8)

        data_norm = (data - data_min) / (data_max - data_min)
        colored_image = plt.get_cmap('viridis')(data_norm)  # RGBA floats in [0, 1]
        return (colored_image[:, :, :3] * 255).astype(np.uint8)
    except Exception as e:
        print(f"Error visualizing JP2 file {file_path}: {e}")
        return None
def load_band(file_path, resample=False, target_height=None, target_width=None):
    """
    Read band 1 of a raster file as a float32 array, optionally resampled.

    Args:
        file_path: Path to a raster file readable by rasterio.
        resample: When True AND both target dimensions are given, band 1 is
            resampled with bilinear interpolation to (target_height, target_width).
            Otherwise it is read at native resolution.
        target_height, target_width: Output size used when resampling.

    Returns:
        np.ndarray: 2-D float32 array.

    Raises:
        Exception: Any error from opening or reading the file. It is printed,
            then re-raised for the caller to handle.
    """
    try:
        with rasterio.open(file_path) as src:
            if resample and target_height is not None and target_width is not None:
                # Request band 1 explicitly. Reading *all* bands into a 1-band
                # out_shape does not match the band count of multi-band files.
                band_data = src.read(
                    indexes=[1],
                    out_shape=(1, target_height, target_width),
                    resampling=Resampling.bilinear,
                )[0]
            else:
                band_data = src.read(1)
        return band_data.astype(np.float32)
    except Exception as e:
        print(f"Error loading band {file_path}: {e}")
        raise  # callers decide how to report/recover
def prepare_input_array(red_file, green_file, blue_file, nir_file):
    """
    Load the four band files and build the local-model input and an RGB preview.

    The NIR band is resampled to the red band's grid. Red, Green and Blue are
    read at native resolution and must already match the red band's size.

    Args:
        red_file, green_file, blue_file, nir_file: Paths to raster band files.

    Returns:
        tuple:
            prediction_array (np.ndarray | None): (3, H, W) float32 stack of
                Red, Green, NIR for the local model, or None on error.
            rgb_image_enhanced (np.ndarray | None): (H, W, 3) float32 RGB image
                in [0, 1] (gamma-corrected) for visualisation, or None on error.
    """
    try:
        # The red band's grid defines the target size for all other bands.
        with rasterio.open(red_file) as src:
            target_height = src.height
            target_width = src.width

        blue_data = load_band(blue_file)  # only used for the RGB preview
        green_data = load_band(green_file)
        red_data = load_band(red_file)
        nir_data = load_band(
            nir_file,
            resample=True,
            target_height=target_height,
            target_width=target_width,
        )

        # Validate every band before any arithmetic, so a mismatch yields a clear
        # message instead of a numpy broadcasting error.
        expected_shape = (target_height, target_width)
        shapes = {
            "Red": red_data.shape,
            "Green": green_data.shape,
            "Blue": blue_data.shape,
            "NIR": nir_data.shape,
        }
        if any(shape != expected_shape for shape in shapes.values()):
            print("ERROR: Band shapes mismatch after loading/resampling!")
            print("Shapes - " + ", ".join(f"{name}: {shape}" for name, shape in shapes.items()))
            return None, None

        # RGB preview: scale each band by its 98th percentile of positive values.
        epsilon = 1e-10
        rgb_image = np.zeros((target_height, target_width, 3), dtype=np.float32)
        for channel, band in enumerate((red_data, green_data, blue_data)):
            positive = band[band > 0]
            band_max = np.percentile(positive, 98) if positive.size else 1.0
            rgb_image[:, :, channel] = np.clip(band / (band_max + epsilon), 0, 1)
        gamma = 1.8
        rgb_image_enhanced = np.power(rgb_image, 1 / gamma)

        # Local model input in CHW order: red, green, nir.
        prediction_array = np.stack([red_data, green_data, nir_data], axis=0)
        print(f"Local prediction array shape: {prediction_array.shape}")
        print(f"RGB visualization image shape: {rgb_image_enhanced.shape}")
        return prediction_array, rgb_image_enhanced
    except Exception as e:
        print(f"Error during input preparation: {e}")
        import traceback
        traceback.print_exc()
        return None, None
def visualize_cloud_mask(rgb_image, pred_mask):
    """
    Overlay a colour-coded class mask on an RGB image and append a legend strip.

    Args:
        rgb_image: (H, W, 3) float image. It is clipped to [0, 1] and scaled by
            255 before blending.
        pred_mask: Class-label mask of shape (H, W) or (1, H, W). Labels 0-3 are
            drawn as Clear/Thick Cloud/Thin Cloud/Cloud Shadow. Any other label
            is left black in the overlay. The mask is resized with
            nearest-neighbour interpolation when its size differs from the image.

    Returns:
        np.ndarray | None: uint8 image of shape (H + 100, W, 3), i.e. the blend
        followed by the legend. None if an input is missing or any step fails.
    """
    if rgb_image is None or pred_mask is None:
        print("Cannot visualize cloud mask: Missing RGB image or prediction mask.")
        return None
    try:
        # Reduce the mask to (H, W).
        if pred_mask.ndim == 3 and pred_mask.shape[0] == 1:
            pred_mask = np.squeeze(pred_mask, axis=0)
        elif pred_mask.ndim != 2:
            print(f"ERROR: Unexpected prediction mask dimension: {pred_mask.ndim}, shape: {pred_mask.shape}")
            pred_mask = np.squeeze(pred_mask)
            if pred_mask.ndim != 2:
                print("Could not convert mask to 2D: Still not 2D after squeeze")
                return None

        print(f"Visualization - RGB image shape: {rgb_image.shape}, Pred mask shape: {pred_mask.shape}")

        if pred_mask.shape != rgb_image.shape[:2]:
            print(f"Warning: Resizing prediction mask from {pred_mask.shape} to {rgb_image.shape[:2]} for visualization.")
            # cv2.resize rejects int64 arrays even though they are integers, so
            # convert every non-uint8 mask. Integer labels 0-3 fit in uint8.
            if pred_mask.dtype != np.uint8:
                print("Warning: Prediction mask is not uint8, casting to uint8 for resize.")
                pred_mask = pred_mask.astype(np.uint8)
            pred_mask = cv2.resize(
                pred_mask,
                (rgb_image.shape[1], rgb_image.shape[0]),  # cv2 expects (width, height)
                interpolation=cv2.INTER_NEAREST,  # nearest keeps class labels intact
            )
            print(f"Resized mask shape: {pred_mask.shape}")

        # Class colours, in the RGB channel order of rgb_image.
        colors = {
            0: [0, 255, 0],    # Clear - Green
            1: [255, 0, 0],    # Thick Cloud - Red
            2: [255, 255, 0],  # Thin Cloud - Yellow
            3: [0, 0, 255],    # Cloud Shadow - Blue
        }

        # Pixels whose label is not a key of `colors` stay black.
        mask_vis = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8)
        for class_idx, color in colors.items():
            mask_vis[pred_mask == class_idx] = color

        alpha = 0.4  # opacity of the mask overlay
        rgb_uint8 = (np.clip(rgb_image, 0, 1) * 255).astype(np.uint8)
        blended = cv2.addWeighted(rgb_uint8, 1 - alpha, mask_vis, alpha, 0)

        # --- Legend strip (white background, same width as the image) ---
        legend_height = 100
        legend_width = blended.shape[1]
        legend = np.ones((legend_height, legend_width, 3), dtype=np.uint8) * 255
        legend_text = ["Clear", "Thick Cloud", "Thin Cloud", "Cloud Shadow"]
        legend_colors = [colors.get(i, [0, 0, 0]) for i in range(4)]
        box_size = 15
        text_offset_x = 40
        start_y = 15
        padding_y = 20
        for i, (text, color) in enumerate(zip(legend_text, legend_colors)):
            row_y = start_y + i * padding_y
            cv2.rectangle(legend,
                          (10, row_y - box_size // 2),
                          (10 + box_size, row_y + box_size // 2),
                          color, -1)
            cv2.putText(legend, text,
                        (text_offset_x, row_y + box_size // 4),  # vertically centre text on the box
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)

        return np.vstack([blended, legend])
    except Exception as e:
        print(f"Error during visualization: {e}")
        import traceback
        traceback.print_exc()
        return None
| # --- Triton Client Functions (Adapted from paste-2.txt) --- | |
def is_triton_server_healthy(url=HTTP_TRITON_URL):
    """
    Report whether the Triton Inference Server at *url* answers as live.

    Only server liveness is checked. Readiness of the server or the model is
    not, so a later inference can still fail; callers fall back on such errors.

    Args:
        url: "host:port" of the Triton HTTP endpoint.

    Returns:
        bool: True if the server reports live. False if the client library is
        missing, the server is not live, or the connection fails.
    """
    if not TRITON_CLIENT_AVAILABLE:
        return False
    triton_client = None
    try:
        # A short connection timeout keeps the health probe from stalling the UI.
        triton_client = httpclient.InferenceServerClient(url=url, connection_timeout=10.0)
        server_live = triton_client.is_server_live()
        if server_live:
            print(f"Triton server at {url} is live.")
        else:
            print(f"Triton server at {url} is not live.")
        return server_live
    except Exception as e:
        print(f"Could not connect to Triton server at {url}: {e}")
        return False
    finally:
        # Release the client's connection pool. Closing is best-effort and must
        # never override the health result.
        if triton_client is not None:
            try:
                triton_client.close()
            except Exception as close_err:
                print(f"Warning: failed to close Triton health-check client: {close_err}")
def get_jp2_bytes_for_triton(red_file_path, green_file_path, nir_file_path):
    """
    Read the raw (still encoded) bytes of the Red, Green and NIR band files.

    The order Red, Green, NIR must match what the deployed Triton model expects
    for its BYTES input.

    Args:
        red_file_path, green_file_path, nir_file_path: Paths to the band files.

    Returns:
        np.ndarray: 1-D object array of shape (3,), holding one ``bytes`` per band.

    Raises:
        FileNotFoundError: If a path does not exist (printed, then re-raised).
        Exception: Any other read error (printed, then re-raised).
    """
    bands = (
        ("Red", red_file_path),
        ("Green", green_file_path),
        ("NIR", nir_file_path),
    )
    payloads = []
    for band_name, file_path in bands:
        try:
            with open(file_path, "rb") as handle:
                data = handle.read()
        except FileNotFoundError:
            print(f"ERROR: File not found: {file_path}")
            raise
        except Exception as e:
            print(f"ERROR: Could not read file {file_path}: {e}")
            raise
        payloads.append(data)
        print(f"Read {len(data)} bytes for {band_name} band from {os.path.basename(file_path)}")

    # dtype=object keeps each bytes blob as one element, giving shape (3,).
    input_byte_array = np.array(payloads, dtype=object)
    print(f"Prepared Triton input byte array with shape: {input_byte_array.shape} and dtype: {input_byte_array.dtype}")
    return input_byte_array
def run_inference_triton_http(input_byte_array):
    """
    Run the Triton cloud-detection model over HTTP on raw JP2 bytes.

    Args:
        input_byte_array: 1-D object array of encoded band files, as built by
            get_jp2_bytes_for_triton (Red, Green, NIR).

    Returns:
        np.ndarray: The model's TRITON_OUTPUT_NAME tensor.

    Raises:
        RuntimeError: If tritonclient is not installed.
        Exception: Any client-creation, server (InferenceServerException) or
            transport error. Each is printed, then re-raised so the caller can
            fall back to the local model.
    """
    if not TRITON_CLIENT_AVAILABLE:
        raise RuntimeError("Triton client library not available.")
    print("Attempting inference using Triton HTTP client...")
    try:
        client = httpclient.InferenceServerClient(
            url=HTTP_TRITON_URL,
            verbose=False,
            connection_timeout=TRITON_TIMEOUT_SECONDS,  # seconds
            network_timeout=TRITON_TIMEOUT_SECONDS,  # seconds
        )
    except Exception as e:
        print(f"ERROR: Couldn't create Triton HTTP client: {e}")
        raise

    try:
        # BYTES input whose shape mirrors the 1-D array of byte strings.
        inputs = [httpclient.InferInput(TRITON_INPUT_NAME, list(input_byte_array.shape), "BYTES")]
        inputs[0].set_data_from_numpy(input_byte_array, binary_data=True)
        outputs = [httpclient.InferRequestedOutput(TRITON_OUTPUT_NAME, binary_data=True)]

        print(f"Sending inference request to Triton model '{TRITON_MODEL_NAME}' at {HTTP_TRITON_URL}...")
        response = client.infer(
            model_name=TRITON_MODEL_NAME,
            inputs=inputs,
            outputs=outputs,
            request_id=str(os.getpid()),  # optional; not unique across requests
            # The per-request `timeout` of infer() is in MICROSECONDS, unlike the
            # client's connection/network timeouts. Convert so it means 5 minutes
            # rather than 300 µs.
            timeout=TRITON_TIMEOUT_SECONDS * 1_000_000,
        )
        print("Triton inference request successful.")
        mask = response.as_numpy(TRITON_OUTPUT_NAME)
        print(f"Received output mask from Triton with shape: {mask.shape}, dtype: {mask.dtype}")
        return mask
    except triton_utils.InferenceServerException as e:
        print(f"ERROR: Triton server failed inference: Status code {e.status()}, message: {e.message()}")
        print(f"Debug details: {e.debug_details()}")
        raise
    except Exception as e:
        print(f"ERROR: An unexpected error occurred during Triton HTTP inference: {e}")
        import traceback
        traceback.print_exc()
        raise
    finally:
        # Release the client's connections. A failure here must not mask the
        # original inference result or error.
        try:
            client.close()
        except Exception as close_err:
            print(f"Warning: failed to close Triton HTTP client: {close_err}")
| # --- Main Processing Function with Fallback Logic --- | |
def _gradio_file_path(file_value):
    """Return the filesystem path for a Gradio file value (a str path or an object with ``.name``)."""
    return file_value if isinstance(file_value, str) else file_value.name


def _to_2d_mask(mask):
    """Return *mask* as an (H, W) array by dropping leading singleton axes, or None if that is impossible."""
    mask = np.asarray(mask)
    while mask.ndim > 2 and mask.shape[0] == 1:
        mask = mask[0]
    return mask if mask.ndim == 2 else None


def _run_local_inference(local_input_array, batch_size, patch_size, patch_overlap):
    """Run the local OmniCloudMask model. Returns (mask or None, status text to append)."""
    if not LOCAL_MODEL_AVAILABLE:
        print("ERROR: Local model could not be loaded.")
        return None, "Local model not available. Cannot perform inference."
    if local_input_array is None:
        print("ERROR: Failed to prepare input array for local model.")
        return None, "Local input data preparation failed. Cannot perform local inference."

    print("Running local inference using omnicloudmask...")
    try:
        # Slider values may arrive as floats. int() is a no-op for ints and gives
        # the model integral sizes (presumably what it expects; confirm in omnicloudmask).
        pred_mask = predict_from_array(
            local_input_array,
            batch_size=int(batch_size),
            patch_size=int(patch_size),
            patch_overlap=int(patch_overlap),
        )
    except Exception as e:
        print(f"ERROR: Local inference failed: {e}")
        import traceback
        traceback.print_exc()
        return None, f"Local inference FAILED: {e}"
    if pred_mask is None:
        return None, "Local inference FAILED: model returned no mask"
    print(f"Local prediction successful. Output mask shape: {np.shape(pred_mask)}, dtype: {getattr(pred_mask, 'dtype', type(pred_mask))}")
    return pred_mask, "Local inference successful."


def _format_mask_stats(flat_mask, source):
    """Format per-class pixel counts and percentages for a 2-D mask (0=clear, 1=thick, 2=thin, 3=shadow)."""
    total_pixels = flat_mask.size
    clear, thick, thin, shadow = (int(np.sum(flat_mask == cls)) for cls in range(4))

    def pct(count):
        # An empty mask would otherwise produce nan percentages.
        return count / total_pixels * 100 if total_pixels else 0.0

    return f"""
Cloud Mask Statistics ({source}):
- Clear: {clear} pixels ({pct(clear):.2f}%)
- Thick Cloud: {thick} pixels ({pct(thick):.2f}%)
- Thin Cloud: {thin} pixels ({pct(thin):.2f}%)
- Cloud Shadow: {shadow} pixels ({pct(shadow):.2f}%)
- Total Cloud Cover (Thick+Thin): {pct(thick + thin):.2f}%
"""


def process_satellite_images(red_file, green_file, blue_file, nir_file, batch_size, patch_size, patch_overlap):
    """
    Run cloud detection on four band files, trying Triton first and then the local model.

    The local OmniCloudMask model is used when Triton is unavailable or raises,
    and also when Triton returns a mask that cannot be reduced to 2-D.

    Args:
        red_file, green_file, blue_file, nir_file: Gradio file values, either
            path strings or objects exposing ``.name``.
        batch_size, patch_size, patch_overlap: Local-model parameters from the
            UI sliders. They are cast to int before use.

    Returns:
        tuple: (rgb_display_image or None, visualization or None, status text).
    """
    if not all([red_file, green_file, blue_file, nir_file]):
        return None, None, "ERROR: Please upload all four channel files (Red, Green, Blue, NIR)"

    red_file_path = _gradio_file_path(red_file)
    green_file_path = _gradio_file_path(green_file)
    blue_file_path = _gradio_file_path(blue_file)
    nir_file_path = _gradio_file_path(nir_file)
    print("\n--- Starting Cloud Detection Process ---")
    print(f"Input files: R={os.path.basename(red_file_path)}, G={os.path.basename(green_file_path)}, B={os.path.basename(blue_file_path)}, N={os.path.basename(nir_file_path)}")

    status_message = ""
    flat_mask = None
    mask_source = None  # "Triton" or "Local" once a usable mask exists

    # 1. RGB preview (always needed) and the local model input (for fallback).
    print("Preparing visualization image and local model input array...")
    local_input_array, rgb_float_image = prepare_input_array(red_file_path, green_file_path, blue_file_path, nir_file_path)
    if rgb_float_image is None:
        # A failed preview means the band files could not be loaded/processed.
        print("ERROR: Failed to create RGB visualization image.")
        return None, None, "ERROR: Failed to load or process input band files."
    rgb_display_image = (np.clip(rgb_float_image, 0, 1) * 255).astype(np.uint8)

    # 2. Triton health check.
    use_triton = False
    if not TRITON_CLIENT_AVAILABLE:
        print("Triton client library not installed. Skipping Triton check.")
        status_message += "Triton client not installed. "
    else:
        print(f"Checking Triton server health at {HTTP_TRITON_URL}...")
        use_triton = bool(is_triton_server_healthy(HTTP_TRITON_URL))
        if not use_triton:
            print("Triton server is not healthy or unavailable.")
            status_message += "Triton server unavailable. "

    # 3. Triton inference.
    if use_triton:
        try:
            print("Preparing JP2 bytes for Triton...")
            triton_byte_input = get_jp2_bytes_for_triton(red_file_path, green_file_path, nir_file_path)
            triton_mask = run_inference_triton_http(triton_byte_input)
            flat_mask = _to_2d_mask(triton_mask)
            if flat_mask is None:
                # Treat an unusable shape like any other Triton failure so the
                # local model still gets a chance.
                raise ValueError(f"Unexpected Triton mask shape: {np.shape(triton_mask)}")
            mask_source = "Triton"
            status_message += "Inference performed using Triton Server. "
            print("Triton inference successful.")
        except Exception as e:
            print(f"Triton inference failed: {e}. Falling back to local model.")
            status_message += f"Triton inference failed ({type(e).__name__}). "
            flat_mask = None

    # 4. Local fallback.
    if flat_mask is None:
        status_message += "Falling back to local inference. "
        pred_mask, local_status = _run_local_inference(local_input_array, batch_size, patch_size, patch_overlap)
        status_message += local_status
        if pred_mask is None:
            print("--- Cloud Detection Process Failed ---")
            return rgb_display_image, None, status_message + "\nERROR: Could not generate cloud mask."
        flat_mask = _to_2d_mask(pred_mask)
        if flat_mask is None:
            print(f"ERROR: Unexpected mask shape after inference: {np.shape(pred_mask)}")
            status_message += " ERROR: Invalid mask shape received."
            return rgb_display_image, None, status_message + "\nERROR: Could not process prediction mask."
        mask_source = "Local"

    # 5. Statistics and overlay visualisation.
    stats = _format_mask_stats(flat_mask, mask_source)
    status_message += f"\nMask stats calculated. Total pixels: {flat_mask.size}."
    print("Generating final visualization...")
    visualization = visualize_cloud_mask(rgb_float_image, flat_mask)
    if visualization is None:
        status_message += " ERROR: Failed to generate visualization."
    print("--- Cloud Detection Process Finished ---")
    return rgb_display_image, visualization, status_message + "\n" + stats
| # --- Gradio Interface (from paste.txt) --- | |
def check_cpu_usage():
    """
    Return system-wide CPU utilisation as a display string, e.g. "CPU Usage: 12.5%".

    The sample is taken over a short blocking interval. Without an interval,
    psutil compares against its previous call, so the first call (the first
    button click) would always report a meaningless 0.0.
    """
    return f"CPU Usage: {psutil.cpu_percent(interval=0.5)}%"
# --- Build Gradio App ---
print("Building Gradio interface...")
with gr.Blocks(title="Satellite Cloud Detection (Triton/Local)") as demo:
    gr.Markdown("""
    # Satellite Cloud Detection (with Triton Fallback)
    Upload separate JP2 files for Red (e.g., B04), Green (e.g., B03), Blue (e.g., B02), and NIR (e.g., B8A) channels.
    The application will **first attempt** to use a remote Triton Inference Server. If the server is unavailable or inference fails,
    it will **fall back** to using the local OmniCloudMask model.
    **Pixel Classification:**
    - Clear (Green)
    - Thick Cloud (Red)
    - Thin Cloud (Yellow)
    - Cloud Shadow (Blue)
    The model works best with imagery at 10-50m resolution.
    """)

    with gr.Row():
        # Left column: band uploads, local-model parameters and the run button.
        with gr.Column(scale=1):
            gr.Markdown("### Input Bands (JP2)")
            # type="filepath" hands the handler path strings, which serve both
            # rasterio reads and the raw-byte Triton request.
            red_input = gr.File(label="Red Channel (e.g., B04)", type="filepath")
            green_input = gr.File(label="Green Channel (e.g., B03)", type="filepath")
            blue_input = gr.File(label="Blue Channel (e.g., B02)", type="filepath")
            nir_input = gr.File(label="NIR Channel (e.g., B8A)", type="filepath")

            gr.Markdown("### Local Model Parameters (Used for Fallback)")
            batch_size = gr.Slider(minimum=1, maximum=32, value=4, step=1,
                                   label="Batch Size",
                                   info="Memory usage/speed for local model")
            patch_size = gr.Slider(minimum=256, maximum=2048, value=1024, step=128,
                                   label="Patch Size",
                                   info="Patch size for local model processing")
            patch_overlap = gr.Slider(minimum=64, maximum=512, value=256, step=64,
                                      label="Patch Overlap",
                                      info="Overlap for local model processing")

            process_btn = gr.Button("Process Cloud Detection", variant="primary")

        # Right column: outputs.
        with gr.Column(scale=2):
            gr.Markdown("### Results")
            rgb_output = gr.Image(label="Original RGB Image (Approx. True Color)", type="numpy")
            cloud_output = gr.Image(label="Cloud Detection Visualization (Mask Overlay)", type="numpy")
            stats_output = gr.Textbox(label="Processing Status & Statistics", lines=10)

    # The same inputs/outputs drive both the examples and the main button.
    detection_inputs = [red_input, green_input, blue_input, nir_input, batch_size, patch_size, patch_overlap]
    detection_outputs = [rgb_output, cloud_output, stats_output]

    # Optional system monitoring.
    with gr.Accordion("System Monitoring", open=False):
        cpu_button = gr.Button("Check CPU Usage")
        cpu_output = gr.Textbox(label="Current CPU Usage")
        cpu_button.click(fn=check_cpu_usage, inputs=None, outputs=cpu_output)

    # Example bands ship in the downloaded repo's "jp2s" folder (R, G, B, NIR order).
    example_base = os.path.join(repo_dir, "jp2s")
    example_files = [
        os.path.join(example_base, band_file)
        for band_file in ("B04.jp2", "B03.jp2", "B02.jp2", "B8A.jp2")
    ]
    if all(os.path.exists(path) for path in example_files):
        print("Adding examples...")
        gr.Examples(
            examples=[example_files + [4, 1024, 256]],
            inputs=detection_inputs,
            outputs=detection_outputs,
            fn=process_satellite_images,
            cache_examples=False,  # run on demand instead of pre-computing outputs
        )
    else:
        print(f"WARN: Example JP2 files not found in '{example_base}'. Skipping examples.")

    process_btn.click(
        fn=process_satellite_images,
        inputs=detection_inputs,
        outputs=detection_outputs,
    )
# --- Launch the App ---
# Start the server only when run as a script (e.g. `python app.py`). Importing
# this module, for instance in tests, no longer launches a blocking server.
if __name__ == "__main__":
    print("Launching Gradio app...")
    # Enable the request queue with a default concurrency limit of 4 per event
    # listener. debug=True keeps the process attached and prints errors.
    # Set share=True for a temporary public link.
    demo.queue(default_concurrency_limit=4).launch(debug=True, share=False)