"""Streamlit prototype: skin-lesion classification with Grad-CAM++ and VLM explanations.

Flow of the script (top to bottom, re-executed by Streamlit on every rerun):

1. Show a loader while the heavy ML libraries are imported.
2. Register the custom Keras objects the saved CNN needs to deserialize.
3. Define helpers for preprocessing, model loading, Grad-CAM++ and VLM generation.
4. Render the page: upload widget, selected image, attention overlay, metrics,
   generated explanation and an optional clickable gallery of example images.

NOTE(review): the reviewed copy of this file had its whitespace collapsed and
the HTML/CSS inside the ``unsafe_allow_html`` strings appears to have been
stripped. String contents are kept as they appeared; spots where markup had to
be reconstructed are flagged with ``NOTE(review)`` comments.
"""
import base64
import io
import os
import traceback
from pathlib import Path

import streamlit as st

# Page configuration is issued before any other Streamlit command (including the
# loader placeholder below) so it is valid on every Streamlit version.
st.set_page_config(page_title="Skin Cancer Classifier", layout="wide", initial_sidebar_state="expanded")

# The heavy imports are deliberately placed *after* this placeholder so the user
# sees feedback while TensorFlow / PyTorch / Transformers are being imported.
loader_placeholder = st.empty()
loader_placeholder.markdown("""
Loading necessary libraries...
""", unsafe_allow_html=True)

import numpy as np
from st_click_detector import click_detector
import cv2
from PIL import Image
import tensorflow as tf
from tensorflow.keras.models import load_model
from tf_keras_vis.gradcam_plus_plus import GradcamPlusPlus
from tf_keras_vis.utils.model_modifiers import ReplaceToLinear
from tf_keras_vis.utils.scores import CategoricalScore
import matplotlib.pyplot as plt
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
from peft import PeftModel
from tensorflow.keras.layers import (
    Layer, Conv2D, Dense, GlobalAveragePooling2D, GlobalMaxPooling2D,
    Reshape, Multiply, Add, Activation, Concatenate
)

loader_placeholder.empty()


# --------------------------------------------------------------------------------------------------
# Custom Keras objects. The app never trains, but these classes must be registered so that
# `load_model` can deserialize the saved CNN.
# --------------------------------------------------------------------------------------------------
@tf.keras.utils.register_keras_serializable(package="Custom", name="F1Score")
class F1Score(tf.keras.metrics.Metric):
    """F1 metric computed from an internal Precision and Recall metric."""

    def __init__(self, name='f1_score', **kwargs):
        super().__init__(name=name, **kwargs)
        self.precision = tf.keras.metrics.Precision()
        self.recall = tf.keras.metrics.Recall()

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Forward the batch to both wrapped metrics."""
        self.precision.update_state(y_true, y_pred, sample_weight)
        self.recall.update_state(y_true, y_pred, sample_weight)

    def result(self):
        """Return 2*p*r / (p + r + epsilon); epsilon avoids division by zero."""
        p = self.precision.result()
        r = self.recall.result()
        return 2 * (p * r) / (p + r + tf.keras.backend.epsilon())

    def reset_state(self):
        """Reset both wrapped metrics (the hook name current Keras calls)."""
        self.precision.reset_state()
        self.recall.reset_state()

    def reset_states(self):
        """Backward-compatible alias for the legacy hook name."""
        self.reset_state()


@tf.keras.utils.register_keras_serializable(package="Custom", name="ChannelAttention")
class ChannelAttention(Layer):
    """Channel attention: a shared two-layer MLP applied to avg- and max-pooled features."""

    def __init__(self, reduction=16, **kwargs):
        super(ChannelAttention, self).__init__(**kwargs)
        self.reduction = reduction

    def build(self, input_shape):
        channel = input_shape[-1]
        # Bottleneck MLP shared between the average-pooled and max-pooled branches.
        self.shared_dense_one = Dense(channel // self.reduction, activation='relu',
                                      kernel_initializer='he_normal', use_bias=True)
        self.shared_dense_two = Dense(channel, kernel_initializer='he_normal', use_bias=True)

    def call(self, inputs):
        avg_pool = GlobalAveragePooling2D()(inputs)
        max_pool = GlobalMaxPooling2D()(inputs)
        avg_pool = self.shared_dense_one(avg_pool)
        avg_pool = self.shared_dense_two(avg_pool)
        max_pool = self.shared_dense_one(max_pool)
        max_pool = self.shared_dense_two(max_pool)
        attention = Add()([avg_pool, max_pool])
        attention = Activation('sigmoid')(attention)
        # Reshape to (1, 1, C) so the per-channel weights broadcast over the spatial dims.
        attention = Reshape((1, 1, -1))(attention)
        return Multiply()([inputs, attention])

    def get_config(self):
        """Include `reduction` so a non-default value survives save/load."""
        config = super().get_config()
        config.update({"reduction": self.reduction})
        return config


@tf.keras.utils.register_keras_serializable(package="Custom", name="SpatialAttention")
class SpatialAttention(Layer):
    """Spatial attention: a 7x7 sigmoid conv over channel-wise mean and max maps."""

    def __init__(self, **kwargs):
        super(SpatialAttention, self).__init__(**kwargs)
        self.conv2d = Conv2D(filters=1, kernel_size=7, strides=1, padding='same', activation='sigmoid')

    def call(self, inputs):
        avg_pool = tf.reduce_mean(inputs, axis=-1, keepdims=True)
        max_pool = tf.reduce_max(inputs, axis=-1, keepdims=True)
        concat = Concatenate(axis=-1)([avg_pool, max_pool])
        attention = self.conv2d(concat)
        return Multiply()([inputs, attention])


def cbam_block(inputs, reduction=16):
    """Apply channel attention followed by spatial attention to `inputs`."""
    x = ChannelAttention(reduction)(inputs)
    x = SpatialAttention()(x)
    return x


# -------------------------
# Helpers & small utilities
# -------------------------
def bytes_from_path(path):
    """Return the raw bytes of the file at `path`."""
    with open(path, "rb") as f:
        return f.read()


def image_to_data_uri(path: str, max_width=224, jpeg_quality=70):
    """Encode the image at `path` as a base64 JPEG data URI.

    Images wider than `max_width` are downscaled, keeping the aspect ratio.
    Returns None when the file does not exist.
    """
    p = Path(path)
    if not p.exists():
        return None
    # Context manager closes the underlying file; convert() returns a loaded copy.
    with Image.open(p) as src:
        img = src.convert("RGB")
    if img.width > max_width:
        # max(1, ...) guards against a zero height for extremely wide images.
        new_h = max(1, int(max_width * img.height / img.width))
        img = img.resize((max_width, new_h), Image.BILINEAR)
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=jpeg_quality, optimize=True)
    data64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{data64}"


# Class codes in the order used to index the CNN's prediction vector.
labels = ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']
full_names = {
    'akiec': 'Actinic keratoses',
    'bcc': 'Basal cell carcinoma',
    'bkl': 'Benign keratosis-like lesions',
    'df': 'Dermatofibroma',
    'mel': 'Melanoma',
    'nv': 'Melanocytic nevi',
    'vasc': 'Vascular lesions'
}


def preprocess_image(image):
    """Apply CLAHE to the L channel of an RGB array and min-max scale to [0, 1] float32.

    NOTE(review): no resizing happens here, so the array keeps the uploaded
    image's size - confirm the CNN accepts variable spatial dimensions.
    """
    if image.dtype != np.uint8:
        image = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
    clahe = cv2.createCLAHE(clipLimit=0.01, tileGridSize=(8, 8))
    lab[:, :, 0] = clahe.apply(lab[:, :, 0])
    image_clahe = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
    image_clahe = image_clahe.astype(np.float32)
    # 1e-8 keeps the division finite for a constant image.
    image_clahe = (image_clahe - np.min(image_clahe)) / (np.ptp(image_clahe) + 1e-8)
    return image_clahe


@st.cache_resource(show_spinner=False)
def load_cnn_model(model_path="Proposed CBAM-Xception-DermNet.keras"):
    """Load the Keras classifier from `model_path`.

    `st.cache_resource` is the only cache: the model is loaded once per
    `model_path` until `load_cnn_model.clear()` is called. Load errors are
    shown in the UI and re-raised.
    """
    try:
        return load_model(model_path)
    except Exception as e:
        st.error(f"Failed to load CNN model from '{model_path}': {e}")
        st.exception(e)
        raise


@st.cache_resource(show_spinner=False)
def load_vlm_model():
    """Load the processor, base VLM and LoRA adapter.

    Returns a dict with keys "model", "processor", "device" and "dtype".
    Cached by `st.cache_resource` until `load_vlm_model.clear()` is called.
    Load errors are shown in the UI and re-raised.
    """
    USE_4BIT = True
    HF_MODEL_ID = "google/medgemma-4b-it"  # Hugging Face repo ID
    LORA_OUTPUT_DIR = "./medgemma_lora_adapter"  # local directory holding the LoRA adapter
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    hf_token = os.getenv("HF_TOKEN")  # read from the environment; never hard-code the token

    # bfloat16 only on GPUs with compute capability major >= 8; float32 otherwise.
    capability = torch.cuda.get_device_capability(0)[0] if torch.cuda.is_available() else 0
    dtype = torch.bfloat16 if torch.cuda.is_available() and capability >= 8 else torch.float32

    # 4-bit quantization config
    bnb_config = None
    if USE_4BIT:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=dtype,
        )

    # The processor is loaded from the LoRA adapter folder, not from the hub repo.
    try:
        processor = AutoProcessor.from_pretrained(LORA_OUTPUT_DIR, trust_remote_code=True)
        processor.tokenizer.padding_side = "right"
    except Exception as e:
        st.error(f"Failed to load processor from '{LORA_OUTPUT_DIR}': {e}")
        st.exception(e)
        raise

    # Load base model from Hugging Face hub
    try:
        base_model = AutoModelForImageTextToText.from_pretrained(
            HF_MODEL_ID,
            quantization_config=bnb_config if USE_4BIT else None,
            dtype=dtype,
            device_map="auto",
            trust_remote_code=True,
            token=hf_token,  # `token` replaces the deprecated `use_auth_token`
        )
    except Exception as e:
        st.error(f"Failed to load base model from Hugging Face hub: {e}")
        st.exception(e)
        raise

    # Attach LoRA adapter
    try:
        model = PeftModel.from_pretrained(base_model, LORA_OUTPUT_DIR, device_map="auto")
    except Exception as e:
        st.error(f"Failed to attach LoRA adapter: {e}")
        st.exception(e)
        raise

    model.eval()
    try:
        model.to(DEVICE)
    except Exception:
        # Deliberate best-effort: the model was loaded with device_map="auto", so a
        # failing .to() is ignored and the model stays where it was dispatched.
        pass

    return {"model": model, "processor": processor, "device": DEVICE, "dtype": dtype}


def generate_vlm_response(processor, vlm_model, device, gradcam_image: Image.Image, pred_label, max_new_tokens=128):
    """Ask the VLM for a short explanation of a Grad-CAM++ overlay.

    Args:
        processor: processor providing `apply_chat_template`, `decode` and `tokenizer`.
        vlm_model: model providing `generate` and `dtype`.
        device: device the input tensors are moved to.
        gradcam_image: overlay image passed to the processor.
        pred_label: predicted class inserted into the prompt.
        max_new_tokens: generation length cap.

    Returns:
        The stripped decoded text, or None if anything failed (the error is
        shown in the UI).
    """
    try:
        prompt_template = (
            "You are an AI assistant specialized in model interpretability. \n"
            "I am providing:\n- CNN model Grad-CAM++ heatmap image\n- Model predicted class: {predicted_class}\n\n"
            "Based on the Grad-CAM++ heatmap, write a clear and concise 20–30 word explanation "
            "of which features the model focused on and why. Output only the explanation (no headings)."
        )
        user_prompt = prompt_template.format(predicted_class=pred_label)

        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": user_prompt}
                ],
            }
        ]

        formatted_prompt = processor.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
        inputs = processor(text=formatted_prompt, images=gradcam_image, return_tensors="pt", padding=True)

        # Move everything that can be moved onto the target device.
        inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}

        if "pixel_values" in inputs:
            # Match the model's dtype; fall back to float16, otherwise leave unchanged.
            try:
                inputs["pixel_values"] = inputs["pixel_values"].to(dtype=vlm_model.dtype)
            except Exception:
                try:
                    inputs["pixel_values"] = inputs["pixel_values"].to(dtype=torch.float16)
                except Exception:
                    pass

        with torch.inference_mode():
            output_ids = vlm_model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=processor.tokenizer.eos_token_id,
            )

        # Some generate wrappers return an object with .sequences
        seqs = output_ids.sequences if hasattr(output_ids, "sequences") else output_ids

        # Decode only the tokens generated after the prompt.
        input_len = inputs["input_ids"].shape[-1]
        response = processor.decode(seqs[0, input_len:], skip_special_tokens=True)
        return response.strip()

    except Exception as e:
        st.error(f"VLM generation failed: {e}")
        st.exception(e)
        return None


def classify_and_gradcam(image_bytes):
    """Classify an encoded image and build its Grad-CAM++ overlay.

    Returns:
        (pred_label, conf, overlay_pil): class code from `labels`, the score of
        that class, and a PIL image blending the heatmap (25%) with the
        preprocessed input (75%).
    """
    pil_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    preprocessed = preprocess_image(np.array(pil_img))
    input_tensor = np.expand_dims(preprocessed, axis=0)

    with st.spinner("Loading Classifier Model..."):
        cnn = load_cnn_model("Proposed CBAM-Xception-DermNet.keras")

    with st.spinner("Classifying..."):
        preds = cnn.predict(input_tensor)[0]
        pred_idx = int(np.argmax(preds))
        pred_label = labels[pred_idx]
        conf = float(preds[pred_idx])

    with st.spinner("Generating Attention Map..."):
        target_layer = "block14_sepconv2"
        score = CategoricalScore([pred_idx])
        gradcam_vis = GradcamPlusPlus(cnn, model_modifier=ReplaceToLinear(), clone=True)
        cam = gradcam_vis(score, input_tensor, penultimate_layer=target_layer)[0]
        # Normalise to [0, 1]; 1e-8 guards against a flat map.
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        heatmap = plt.cm.jet(cam)[..., :3]  # drop the alpha channel
        overlay = 0.25 * heatmap + 0.75 * preprocessed
        overlay = np.uint8(255 * np.clip(overlay, 0, 1))
        overlay_pil = Image.fromarray(overlay)

    return pred_label, conf, overlay_pil


# -------------------------
# Main display & styling
# -------------------------
# NOTE(review): this string presumably held the page's <style> block; the markup
# is missing from the reviewed copy, so it is kept as it appeared.
st.markdown(""" """, unsafe_allow_html=True)

with st.sidebar:
    st.header("Important Notice")
    st.markdown("""
- This app is a prototype, not for clinical use.
- Do not rely on classifications or explanations for medical decisions.
- This apps model is fine tuned on only one small dataset.
- It might not capture your original disease.
- Always consult a qualified healthcare professional.
- Results may not be accurate; use at your own risk.
- Again, this is just a prototype!
""", unsafe_allow_html=True)
    st.markdown("---")
    if st.button("Clear Models Cache"):
        # The models live in st.cache_resource, so the caches themselves must be
        # cleared for the models to really reload on next use.
        load_cnn_model.clear()
        load_vlm_model.clear()
        # Drop any leftover per-session copies as well.
        for k in ["cnn_model", "vlm_model", "processor", "device", "dtype", "vlm_loaded"]:
            st.session_state.pop(k, None)
        st.success("Model cache cleared. Models will reload on next use.")

st.markdown("\nSkin Cancer Image Classifier\n", unsafe_allow_html=True)
st.markdown(
    "\nLocal CNN inference • Model Attention (Grad-CAM++) visualizations • optional VLM explanations\n",
    unsafe_allow_html=True,
)

uploaded_file = st.file_uploader(
    "Upload a skin lesion image",
    type=["jpg", "jpeg", "png"],
    key="uploaded_file"
)

if uploaded_file is not None:
    # getvalue() returns the full content regardless of the stream position.
    uploaded_bytes = uploaded_file.getvalue()
    # Reset the explanation only when the image actually changed; resetting on every
    # rerun would erase it as soon as any other widget is touched.
    if st.session_state.get("selected_image") != uploaded_bytes:
        st.session_state.selected_image = uploaded_bytes
        st.session_state["vlm_response"] = None
    st.session_state.example_selected = False

# No upload and no example selected: drop all per-image state.
if uploaded_file is None and not st.session_state.get("example_selected", False):
    keys_to_clear = ["vlm_response", "pred_label", "conf", "overlay_pil", "last_image_bytes", "selected_image"]
    for k in keys_to_clear:
        st.session_state.pop(k, None)

# Main layout: image area and visualization
original_image_col, attention_column = st.columns([2, 2])

with original_image_col:
    st.markdown("\n", unsafe_allow_html=True)
    st.subheader("Selected Image")
    if 'selected_image' in st.session_state:
        pil_img = Image.open(io.BytesIO(st.session_state.selected_image)).convert("RGB")
        st.image(pil_img, width=360, caption="Selected image", output_format="auto")
    else:
        st.info("No image selected. Upload or click an example below.")
    st.markdown("\n", unsafe_allow_html=True)

# Run the pipeline once per distinct image; `last_image_bytes` remembers what was processed.
if 'selected_image' in st.session_state:
    img_bytes = st.session_state.selected_image
    if st.session_state.get("last_image_bytes") != img_bytes:
        pred_label, conf, overlay_pil = classify_and_gradcam(img_bytes)
        st.session_state["pred_label"] = pred_label
        st.session_state["conf"] = conf
        st.session_state["overlay_pil"] = overlay_pil
        st.session_state["last_image_bytes"] = img_bytes

        try:
            with st.spinner("Loading VLM Model. Please be patient..."):
                try:
                    vlm_info = load_vlm_model()
                except Exception:
                    # The loader has already rendered the detailed error.
                    st.error("VLM load failed. See logs above.")
                    vlm_info = None

            if vlm_info is not None:
                try:
                    img_for_vlm = overlay_pil.convert("RGB").resize((224, 224), Image.BILINEAR)
                except Exception:
                    st.warning("Overlay image not available for VLM input; using original image.")
                    img_for_vlm = pil_img.convert("RGB").resize((224, 224), Image.BILINEAR)

                with st.spinner("Generating Explanation...."):
                    response = generate_vlm_response(
                        vlm_info["processor"],
                        vlm_info["model"],
                        vlm_info["device"],
                        img_for_vlm,
                        pred_label,
                        max_new_tokens=128
                    )
                if response is None:
                    st.error("VLM did not return a response.")
                else:
                    st.session_state["vlm_response"] = response
        except Exception as e:
            st.error(f"Error in VLM generation flow: {e}")
            st.exception(e)

with attention_column:
    st.markdown("\n", unsafe_allow_html=True)
    st.subheader("Model Attention Visualization")
    # Also require the overlay itself so a missing key cannot raise here.
    if 'selected_image' in st.session_state and st.session_state.get("overlay_pil") is not None:
        st.image(st.session_state["overlay_pil"], caption="Model Attention Overlay", width=360, output_format="auto")
    else:
        st.info("Model Attention will appear here after selecting an image and running classification.")
    st.markdown("\n", unsafe_allow_html=True)

# Metrics placeholder
c1, c2 = st.columns([3, 1])
if st.session_state.get("selected_image") and st.session_state.get("pred_label"):
    c1.metric("Predicted", full_names[st.session_state["pred_label"]])
    c2.metric("Confidence", f"{st.session_state['conf']:.2f}")
else:
    c1.metric("Predicted", "—")
    c2.metric("Confidence", "—")

# VLM Response placeholder
st.subheader("Generated Explanation")
if st.session_state.get("vlm_response"):
    st.info(st.session_state["vlm_response"])
else:
    st.info("VLM explanation will appear here after selecting an image and running classification.")

example_paths = [
    "images/ISIC_0025314.jpg",
    "images/ISIC_0025586.jpg",
    "images/ISIC_0025680.jpg",
    "images/ISIC_0026163.jpg"
]

# Container div for toggle + gallery
st.markdown("""
""", unsafe_allow_html=True)

toggle = st.toggle("Show Example Images", value=False)

if toggle:
    # Toggle ON -> show gallery
    st.markdown("\nClick on any image to analyze it instantly\n", unsafe_allow_html=True)

    # NOTE(review): the gallery markup was missing from the reviewed copy. It is
    # reconstructed here as one <a id="img_<index>"> per example, which matches the
    # `clicked.split("_")[1]` parsing below - confirm ids/styling against the original.
    tiles = []
    for i, example_path in enumerate(example_paths):
        data_uri = image_to_data_uri(example_path)
        if data_uri is None:
            continue  # example file missing on disk: leave it out of the gallery
        tiles.append(
            f"<a href='#' id='img_{i}'><img src='{data_uri}' width='160' style='margin:6px;'></a>"
        )
    html = "<div style='display:flex;flex-wrap:wrap;justify-content:center;'>" + "".join(tiles) + "</div>"

    # Changing the component key after each accepted click resets the detector.
    if "example_click_key" not in st.session_state:
        st.session_state.example_click_key = 0

    clicked = click_detector(html, key=f"clicking_examples_{st.session_state.example_click_key}")

    if clicked:
        if uploaded_file is not None:
            st.warning("Please remove the uploaded file by clickng cross in the uploaded file name")
        else:
            try:
                idx = int(clicked.split("_")[1])
            except (IndexError, ValueError):
                idx = -1
            if 0 <= idx < len(example_paths):
                selected_path = example_paths[idx]
                img_bytes = bytes_from_path(selected_path)  # closes the file handle
                if st.session_state.get("last_image_bytes") != img_bytes:
                    st.session_state.selected_image = img_bytes
                    st.session_state.example_selected = True
                    st.session_state["vlm_response"] = None
                    st.session_state.example_click_key += 1
                    try:
                        st.toast(f"✅ Selected image: {selected_path}", icon="📸")
                    except Exception:
                        st.success(f"Selected image: {selected_path}")
                    st.rerun()

st.markdown("\n", unsafe_allow_html=True)

st.markdown("""
© 2025 Faysal Ahmmed, Ajmy Alaly, Samanta Mehnaj, Asef Rahman, F.M. Mridha. All rights reserved.
""", unsafe_allow_html=True)