"""
Collection of various utils
"""
import numpy as np
import imageio.v3 as iio
from PIL import Image
# we may have very large images (e.g. panoramic SEM images), allow to read them w/o warnings
Image.MAX_IMAGE_PIXELS = 933120000
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.lines import Line2D
import logging
from typing import Any, List, Tuple, Union
###
### load SEM images (Note: Not directly used with Gradio gr.Image(type="pil"))
###
def load_image(filename: str) -> np.ndarray:
    """Load an SEM image.
    Args:
        filename (str): full path and name of the image file to be loaded
    Returns:
        np.ndarray: the image as a 32-bit float numpy ndarray
    """
    image = iio.imread(filename, mode='F')
    return image
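# Hedged usage sketch for load_image; 'panorama.tif' is a placeholder file name,
# not a path shipped with this project.
def _example_load_image():
    """Minimal load_image demo (assumes a readable image file on disk)."""
    image = load_image("panorama.tif")  # placeholder path
    print(f"loaded image: shape={image.shape}, dtype={image.dtype}")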
###
### show SEM image with boxes in various colours around each damage site
###
def show_boxes(image: np.ndarray, damage_sites: dict, box_size=(250, 250),
               save_image=False, image_path: str = None):
    """
    Show an SEM image with colored boxes around identified damage sites.
    Args:
        image (np.ndarray): SEM image to be shown.
        damage_sites (dict): dictionary keyed by centroid coordinates (y, x), with the damage-class label as value.
        box_size (tuple, optional): size (height, width) of the rectangle drawn around each centroid. Defaults to (250, 250).
        save_image (bool, optional): whether to save the annotated image. Defaults to False.
        image_path (str, optional): full path and name of the output file to be saved.
    Returns:
        np.ndarray: the rendered figure as an RGB array.
    """
logging.debug(f"show_boxes: Input image type: {type(image)}")
# Ensure image is a NumPy array of appropriate type for matplotlib
if isinstance(image, Image.Image):
image_to_plot = np.array(image.convert('L')) # Convert to grayscale NumPy array
logging.debug("show_boxes: Converted PIL Image to grayscale NumPy array for plotting.")
elif isinstance(image, np.ndarray):
if image.ndim == 3 and image.shape[2] in [3,4]: # RGB or RGBA NumPy array
image_to_plot = np.mean(image, axis=2).astype(image.dtype) # Convert to grayscale
logging.debug("show_boxes: Converted multi-channel NumPy array to grayscale for plotting.")
else: # Assume grayscale already
image_to_plot = image
logging.debug("show_boxes: Image is already a grayscale NumPy array.")
else:
logging.error("show_boxes: Unsupported image format received.")
image_to_plot = np.zeros((100,100), dtype=np.uint8) # Fallback to black image
_, ax = plt.subplots(1)
ax.imshow(image_to_plot, cmap='gray') # show image on correct axis
ax.set_xticks([])
ax.set_yticks([])
for key, label in damage_sites.items():
        position = [key[0], key[1]]  # key is (y, x): row, column
edgecolor = {
'Inclusion': 'b',
'Interface': 'g',
'Martensite': 'r',
'Notch': 'y',
'Shadowing': 'm',
            'Not Classified': 'k'
}.get(label, 'k') # default: black
# Ensure box_size elements are floats for division
half_box_w = box_size[1] / 2.0
half_box_h = box_size[0] / 2.0
        # x-coordinate of the anchor corner (left edge of the box)
        rect_x = position[1] - half_box_w
        # y-coordinate of the anchor corner (top edge; imshow puts the image origin at the top left)
        rect_y = position[0] - half_box_h
rect = patches.Rectangle((rect_x, rect_y),
box_size[1], box_size[0],
linewidth=1, edgecolor=edgecolor, facecolor='none')
ax.add_patch(rect)
legend_elements = [
Line2D([0], [0], color='b', lw=4, label='Inclusion'),
Line2D([0], [0], color='g', lw=4, label='Interface'),
Line2D([0], [0], color='r', lw=4, label='Martensite'),
Line2D([0], [0], color='y', lw=4, label='Notch'),
        Line2D([0], [0], color='m', lw=4, label='Shadowing'),
Line2D([0], [0], color='k', lw=4, label='Not Classified')
]
ax.legend(handles=legend_elements, bbox_to_anchor=(1.04, 1), loc="upper left")
fig = ax.figure
fig.tight_layout(pad=0)
if save_image and image_path:
fig.savefig(image_path, dpi=1200, bbox_inches='tight')
canvas = fig.canvas
canvas.draw()
data = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8).reshape(
canvas.get_width_height()[::-1] + (4,))
data = data[:, :, :3] # RGB only
plt.close(fig)
return data
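# Hedged usage sketch for show_boxes; the coordinates and labels below are
# made-up placeholders that only illustrate the expected (y, x) -> label layout.
def _example_show_boxes():
    """Minimal show_boxes demo on a synthetic image."""
    image = np.zeros((1000, 1000), dtype=np.uint8)  # synthetic stand-in for an SEM panorama
    damage_sites = {
        (300, 400): 'Inclusion',   # key is (y, x), value is the damage-class label
        (700, 600): 'Martensite',
    }
    rendered = show_boxes(image, damage_sites, box_size=(250, 250))
    print(f"rendered annotated figure: shape={rendered.shape}")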
###
### cut out small images from panorama, append colour information
###
def prepare_classifier_input(
    panorama: Union[Image.Image, np.ndarray],
    centroids: List[Tuple[int, int]],
    window_size: Tuple[int, int] = (250, 250)
) -> List[np.ndarray]:
"""
    Extracts square image patches centered at each given centroid from a grayscale panoramic SEM image.
    Each patch is extracted at the specified window size and converted into a 3-channel (RGB-like)
    normalized image suitable for use with classification neural networks that expect color input.
Parameters
----------
panorama : PIL.Image.Image or np.ndarray
Input SEM image. Should be a 2D array (H, W) or a 3D array (H, W, 1) representing grayscale data,
or a PIL Image object.
    centroids : list of (int, int)
        List of (y, x) coordinates marking the centers of regions of interest. These are typically damage sites
        identified in preprocessing (e.g., clustering).
    window_size : tuple of int, optional
        Size (height, width) of each extracted image patch. Defaults to (250, 250).
Returns
-------
list of np.ndarray
List of extracted and normalized 3-channel image patches, each with shape (height, width, 3). Only
centroids that allow full window extraction within image bounds are used.
"""
logging.debug(f"prepare_classifier_input: Input panorama type: {type(panorama)}")
# Convert input to standardized NumPy array format
panorama_array = _convert_to_grayscale_array(panorama)
# Ensure we have the correct dimensions
if panorama_array.ndim == 2:
H, W = panorama_array.shape
logging.debug("prepare_classifier_input: Working with 2D grayscale array.")
elif panorama_array.ndim == 3:
H, W, C = panorama_array.shape
if C == 1:
# Squeeze the single channel dimension for easier processing
panorama_array = panorama_array.squeeze(axis=2)
H, W = panorama_array.shape
logging.debug("prepare_classifier_input: Squeezed single channel dimension.")
else:
logging.error(f"prepare_classifier_input: Unexpected number of channels: {C}")
raise ValueError(f"Expected 1 channel, got {C}")
else:
logging.error(f"prepare_classifier_input: Unexpected array dimensions: {panorama_array.ndim}")
raise ValueError(f"Expected 2D or 3D array, got {panorama_array.ndim}D")
win_h, win_w = window_size
images = []
logging.info(f"prepare_classifier_input: Image dimensions: {H}x{W}, Window size: {win_h}x{win_w}")
logging.info(f"prepare_classifier_input: Processing {len(centroids)} centroids")
for i, (cy, cx) in enumerate(centroids):
# Ensure coordinates are integers
cy, cx = int(round(cy)), int(round(cx))
# Calculate patch boundaries
half_h, half_w = win_h // 2, win_w // 2
y1 = cy - half_h
y2 = y1 + win_h
x1 = cx - half_w
x2 = x1 + win_w
# Check bounds more explicitly
if y1 < 0 or x1 < 0 or y2 > H or x2 > W:
logging.warning(
f"prepare_classifier_input: Skipping centroid {i+1}/{len(centroids)} "
f"at ({cy},{cx}) - patch bounds ({y1}:{y2}, {x1}:{x2}) exceed image bounds (0:{H}, 0:{W})"
)
continue
try:
# Extract patch with explicit bounds checking
patch = panorama_array[y1:y2, x1:x2].astype(np.float32)
# Verify patch dimensions
if patch.shape != (win_h, win_w):
logging.warning(
f"prepare_classifier_input: Patch {i+1} has unexpected shape {patch.shape}, "
f"expected ({win_h}, {win_w}). Skipping."
)
continue
# Normalize patch: [0, 255] -> [-1, 1]
patch_normalized = (patch * 2.0 / 255.0) - 1.0
# Convert to 3-channel RGB-like format
patch_rgb = np.stack([patch_normalized] * 3, axis=2)
images.append(patch_rgb)
logging.debug(f"prepare_classifier_input: Successfully processed centroid {i+1} at ({cy},{cx})")
except Exception as e:
logging.error(
f"prepare_classifier_input: Error processing centroid {i+1} at ({cy},{cx}): {e}"
)
continue
logging.info(f"prepare_classifier_input: Successfully extracted {len(images)} patches from {len(centroids)} centroids")
# Add diagnostic information about the output
if images:
sample_shape = images[0].shape
sample_dtype = images[0].dtype
sample_min = images[0].min()
sample_max = images[0].max()
logging.info(f"prepare_classifier_input: Output patches - Shape: {sample_shape}, Dtype: {sample_dtype}, Range: [{sample_min:.3f}, {sample_max:.3f}]")
# Verify all patches have consistent shapes
shapes = [img.shape for img in images]
if not all(shape == sample_shape for shape in shapes):
logging.warning("prepare_classifier_input: Inconsistent patch shapes detected!")
for i, shape in enumerate(shapes):
if shape != sample_shape:
logging.warning(f" Patch {i}: {shape} (expected {sample_shape})")
else:
logging.warning("prepare_classifier_input: No valid patches were extracted!")
return images
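# Hedged usage sketch for prepare_classifier_input; the synthetic panorama and
# centroids are placeholders, chosen so that one patch fits and one is skipped.
def _example_prepare_classifier_input():
    """Minimal prepare_classifier_input demo on a synthetic panorama."""
    panorama = np.random.randint(0, 256, size=(1000, 1000), dtype=np.uint8)
    centroids = [(500, 500), (10, 10)]  # the second centroid is too close to the border
    patches = prepare_classifier_input(panorama, centroids, window_size=(250, 250))
    if patches:
        # Stack into the (N, 250, 250, 3) batch shape a classifier would typically consume
        batch = np.stack(patches, axis=0)
        print(f"batch shape: {batch.shape}, value range: [{batch.min():.2f}, {batch.max():.2f}]")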
def _convert_to_grayscale_array(panorama: Union[Image.Image, np.ndarray]) -> np.ndarray:
"""
Helper function to convert various input formats to a standardized grayscale NumPy array.
Parameters
----------
panorama : PIL.Image.Image or np.ndarray
Input image in various formats
Returns
-------
np.ndarray
Standardized grayscale array
"""
if isinstance(panorama, Image.Image):
if panorama.mode in ['RGB', 'RGBA']:
# Convert to grayscale
panorama_array = np.array(panorama.convert('L'))
logging.debug("_convert_to_grayscale_array: Converted RGB/RGBA PIL Image to grayscale.")
elif panorama.mode == 'L':
panorama_array = np.array(panorama)
logging.debug("_convert_to_grayscale_array: Converted grayscale PIL Image to NumPy array.")
else:
# Handle other modes by converting to grayscale
panorama_array = np.array(panorama.convert('L'))
logging.debug(f"_convert_to_grayscale_array: Converted PIL Image mode '{panorama.mode}' to grayscale.")
elif isinstance(panorama, np.ndarray):
if panorama.ndim == 2:
# Already grayscale
panorama_array = panorama.copy()
logging.debug("_convert_to_grayscale_array: Using existing 2D grayscale array.")
elif panorama.ndim == 3:
if panorama.shape[2] in [3, 4]: # RGB or RGBA
# Convert to grayscale using luminance weights
if panorama.shape[2] == 3: # RGB
panorama_array = np.dot(panorama, [0.299, 0.587, 0.114]).astype(panorama.dtype)
else: # RGBA
panorama_array = np.dot(panorama[:, :, :3], [0.299, 0.587, 0.114]).astype(panorama.dtype)
logging.debug("_convert_to_grayscale_array: Converted multi-channel NumPy array to grayscale using luminance weights.")
elif panorama.shape[2] == 1:
# Already single channel
panorama_array = panorama.copy()
logging.debug("_convert_to_grayscale_array: Using existing single-channel array.")
else:
raise ValueError(f"Unsupported number of channels: {panorama.shape[2]}")
else:
raise ValueError(f"Unsupported array dimensions: {panorama.ndim}")
else:
raise ValueError(f"Unsupported panorama type: {type(panorama)}")
return panorama_array
##
## debug
##
def debug_classification_input(patches: List[np.ndarray], model: Any = None) -> None:
"""
Debug function to help identify issues in the classification pipeline.
Call this right before your classification step.
Parameters
----------
patches : List[np.ndarray]
List of image patches from prepare_classifier_input
model : Any, optional
Your classification model (for additional debugging)
"""
logging.info("=== CLASSIFICATION DEBUG INFO ===")
logging.info(f"Number of patches: {len(patches)}")
if not patches:
logging.error("No patches provided for classification!")
return
for i, patch in enumerate(patches):
logging.info(f"Patch {i}:")
logging.info(f" Shape: {patch.shape}")
logging.info(f" Dtype: {patch.dtype}")
logging.info(f" Range: [{patch.min():.3f}, {patch.max():.3f}]")
logging.info(f" Memory layout: {patch.flags}")
# Check for common issues
if np.isnan(patch).any():
logging.warning(f" Contains NaN values: {np.isnan(patch).sum()}")
if np.isinf(patch).any():
logging.warning(f" Contains infinite values: {np.isinf(patch).sum()}")
# Check if patch is contiguous (some models require this)
if not patch.flags.c_contiguous:
logging.warning(f" Patch {i} is not C-contiguous")
# Test conversion to common formats
try:
patches_array = np.array(patches)
logging.info(f"Stacked array shape: {patches_array.shape}")
logging.info(f"Stacked array dtype: {patches_array.dtype}")
except Exception as e:
logging.error(f"Failed to stack patches into array: {e}")
# Test batch preparation (common source of slice errors)
try:
if len(patches) > 0:
# Common preprocessing steps that might cause issues
test_batch = np.stack(patches, axis=0) # Shape: (batch_size, height, width, channels)
logging.info(f"Test batch shape: {test_batch.shape}")
# Test various indexing operations that might cause slice errors
test_slice = test_batch[0] # Should work
logging.info(f"Single item slice shape: {test_slice.shape}")
test_batch_slice = test_batch[:] # Should work
logging.info(f"Full batch slice shape: {test_batch_slice.shape}")
except Exception as e:
logging.error(f"Error during batch preparation testing: {e}")
logging.error(f"Error type: {type(e)}")
import traceback
logging.error(f"Traceback: {traceback.format_exc()}")
logging.info("=== END CLASSIFICATION DEBUG ===")
def safe_classify_patches(patches: List[np.ndarray], classify_func, **kwargs) -> Any:
"""
Wrapper function to safely run classification with better error handling.
Parameters
----------
patches : List[np.ndarray]
List of image patches
classify_func : callable
Your classification function
**kwargs
Additional arguments for classify_func
Returns
-------
Any
Classification results or None if error occurred
"""
try:
logging.debug("Starting safe classification...")
# Debug the input
debug_classification_input(patches)
# Ensure patches are properly formatted
if not patches:
logging.error("No patches to classify")
return None
# Make sure all patches are contiguous arrays
patches_clean = []
for i, patch in enumerate(patches):
if not patch.flags.c_contiguous:
patch_clean = np.ascontiguousarray(patch)
logging.debug(f"Made patch {i} contiguous")
else:
patch_clean = patch
patches_clean.append(patch_clean)
# Call the actual classification function
logging.debug("Calling classification function...")
result = classify_func(patches_clean, **kwargs)
logging.debug("Classification completed successfully")
return result
except Exception as e:
logging.error(f"Error in safe_classify_patches: {e}")
logging.error(f"Error type: {type(e)}")
import traceback
logging.error(f"Full traceback: {traceback.format_exc()}")
return None
# Example usage function
def example_usage():
    """
    Example of how to use the debug functions in your pipeline.
    `panorama`, `centroids`, `your_classify_function` and `your_model` are
    placeholders for objects supplied by the surrounding application.
    """
    # Your existing code that calls prepare_classifier_input:
    # patches = prepare_classifier_input(panorama, centroids, window_size)
    # Add debugging before classification:
    # debug_classification_input(patches)
    # Use the safe wrapper for classification:
    # results = safe_classify_patches(patches, your_classify_function, model=your_model)
    pass
########################################
##
## extract predictions from TFSMLayer output
##
########################################
def extract_predictions_from_tfsm(model_output):
"""
Helper function to extract predictions from TFSMLayer output.
TFSMLayer often returns a dictionary with multiple outputs.
"""
logging.debug(f"Model output type: {type(model_output)}")
logging.debug(f"Model output keys: {model_output.keys() if isinstance(model_output, dict) else 'Not a dict'}")
if isinstance(model_output, dict):
# Try common output key names
possible_keys = ['output', 'predictions', 'dense', 'logits', 'probabilities']
# First, log all available keys
available_keys = list(model_output.keys())
logging.debug(f"Available output keys: {available_keys}")
# Try to find the right output
        for key in possible_keys:
            if key in model_output:
                logging.debug(f"Using output key: {key}")
                value = model_output[key]
                return value.numpy() if hasattr(value, 'numpy') else np.array(value)
        # If no standard key found, use the first available key
        if available_keys:
            first_key = available_keys[0]
            logging.debug(f"Using first available key: {first_key}")
            value = model_output[first_key]
            return value.numpy() if hasattr(value, 'numpy') else np.array(value)
else:
raise ValueError("No output keys found in model response")
else:
# If it's not a dictionary, assume it's already the tensor we need
logging.debug("Model output is not a dictionary, using directly")
return model_output.numpy() if hasattr(model_output, 'numpy') else np.array(model_output)
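# Hedged usage sketch for extract_predictions_from_tfsm with a Keras 3 TFSMLayer;
# the SavedModel path and endpoint name are placeholders, and TensorFlow/Keras 3
# are assumed to be installed.
def _example_extract_predictions(patches):
    """Minimal TFSMLayer inference demo; `patches` comes from prepare_classifier_input."""
    import keras
    import tensorflow as tf
    layer = keras.layers.TFSMLayer("path/to/saved_model", call_endpoint="serving_default")  # placeholder path
    batch = np.stack(patches, axis=0).astype(np.float32)  # (N, H, W, 3)
    raw_output = layer(tf.convert_to_tensor(batch))  # often a dict of named tensors
    predictions = extract_predictions_from_tfsm(raw_output)
    print(f"predictions shape: {predictions.shape}")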