import streamlit as st

loader_placeholder = st.empty()
loader_placeholder.markdown("""
<div style="
    display:flex;
    justify-content:center;
    align-items:center;
    height:50vh;
    font-size:40px;
    font-weight:bold;
    color:#00b4d8;
    animation: flash 1s infinite;
">
Loading necessary libraries...
</div>
<style>
@keyframes flash {
    0% { opacity: 0.2; }
    50% { opacity: 1; }
    100% { opacity: 0.2; }
}
</style>
""", unsafe_allow_html=True)
import numpy as np
from st_click_detector import click_detector
import cv2
from PIL import Image
import tensorflow as tf
from tensorflow.keras.models import load_model
from tf_keras_vis.gradcam_plus_plus import GradcamPlusPlus
from tf_keras_vis.utils.model_modifiers import ReplaceToLinear
from tf_keras_vis.utils.scores import CategoricalScore
import matplotlib.pyplot as plt
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
from peft import PeftModel
import base64
import os
import io
import traceback
from tensorflow.keras.layers import (
    Layer, Conv2D, Dense,
    GlobalAveragePooling2D, GlobalMaxPooling2D,
    Reshape, Multiply, Add, Activation, Concatenate
)
from pathlib import Path

loader_placeholder.empty()
#--------------------------------------------------------------------------------------------------
# Custom metric and CBAM layers: not used directly by the app logic, but they must be
# defined (and supplied as custom objects) so the saved CNN model can be deserialized.
#--------------------------------------------------------------------------------------------------
class F1Score(tf.keras.metrics.Metric):
    def __init__(self, name='f1_score', **kwargs):
        super().__init__(name=name, **kwargs)
        self.precision = tf.keras.metrics.Precision()
        self.recall = tf.keras.metrics.Recall()

    def update_state(self, y_true, y_pred, sample_weight=None):
        self.precision.update_state(y_true, y_pred, sample_weight)
        self.recall.update_state(y_true, y_pred, sample_weight)

    def result(self):
        p = self.precision.result()
        r = self.recall.result()
        return 2 * (p * r) / (p + r + tf.keras.backend.epsilon())

    def reset_states(self):
        self.precision.reset_states()
        self.recall.reset_states()
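
# CBAM (Convolutional Block Attention Module) building blocks. Channel attention pools the
# feature map with both global average and global max pooling, runs each vector through a
# shared two-layer MLP, sums the results, and gates the input channel-wise with a sigmoid;
# spatial attention then applies the analogous sigmoid gating across spatial positions.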
class ChannelAttention(Layer):
    def __init__(self, reduction=16, **kwargs):
        super(ChannelAttention, self).__init__(**kwargs)
        self.reduction = reduction

    def build(self, input_shape):
        channel = input_shape[-1]
        self.shared_dense_one = Dense(channel // self.reduction, activation='relu',
                                      kernel_initializer='he_normal', use_bias=True)
        self.shared_dense_two = Dense(channel, kernel_initializer='he_normal', use_bias=True)

    def call(self, inputs):
        avg_pool = GlobalAveragePooling2D()(inputs)
        max_pool = GlobalMaxPooling2D()(inputs)
        avg_pool = self.shared_dense_one(avg_pool)
        avg_pool = self.shared_dense_two(avg_pool)
        max_pool = self.shared_dense_one(max_pool)
        max_pool = self.shared_dense_two(max_pool)
        attention = Add()([avg_pool, max_pool])
        attention = Activation('sigmoid')(attention)
        attention = Reshape((1, 1, -1))(attention)
        return Multiply()([inputs, attention])
class SpatialAttention(Layer):
    def __init__(self, **kwargs):
        super(SpatialAttention, self).__init__(**kwargs)
        self.conv2d = Conv2D(filters=1, kernel_size=7, strides=1, padding='same', activation='sigmoid')

    def call(self, inputs):
        avg_pool = tf.reduce_mean(inputs, axis=-1, keepdims=True)
        max_pool = tf.reduce_max(inputs, axis=-1, keepdims=True)
        concat = Concatenate(axis=-1)([avg_pool, max_pool])
        attention = self.conv2d(concat)
        return Multiply()([inputs, attention])


def cbam_block(inputs, reduction=16):
    x = ChannelAttention(reduction)(inputs)
    x = SpatialAttention()(x)
    return x
#----------------------------------------------------------------------------------------------------------

# -------------------------
# Helpers & small utilities
# -------------------------
def bytes_from_path(path):
    with open(path, "rb") as f:
        return f.read()
def image_to_data_uri(path: str, max_width=224, jpeg_quality=70):
    p = Path(path)
    if not p.exists():
        return None
    img = Image.open(p).convert("RGB")
    # resize maintaining aspect ratio
    if img.width > max_width:
        new_h = int(max_width * img.height / img.width)
        img = img.resize((max_width, new_h), Image.BILINEAR)
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=jpeg_quality, optimize=True)
    b = buf.getvalue()
    data64 = base64.b64encode(b).decode("utf-8")
    return f"data:image/jpeg;base64,{data64}"
labels = ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']
full_names = {
    'akiec': 'Actinic keratoses',
    'bcc': 'Basal cell carcinoma',
    'bkl': 'Benign keratosis-like lesions',
    'df': 'Dermatofibroma',
    'mel': 'Melanoma',
    'nv': 'Melanocytic nevi',
    'vasc': 'Vascular lesions'
}
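
# Preprocessing: CLAHE (contrast-limited adaptive histogram equalization) is applied to the
# L channel in LAB colour space to even out illumination, then the image is min-max scaled
# to [0, 1] (presumably the same normalization used when the CNN was trained).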
def preprocess_image(image):
    if image.dtype != np.uint8:
        image = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
    clahe = cv2.createCLAHE(clipLimit=0.01, tileGridSize=(8, 8))
    lab[:, :, 0] = clahe.apply(lab[:, :, 0])
    image_clahe = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
    image_clahe = image_clahe.astype(np.float32)
    image_clahe = (image_clahe - np.min(image_clahe)) / (np.ptp(image_clahe) + 1e-8)
    return image_clahe
def load_cnn_model(model_path="Proposed CBAM-Xception-DermNet.keras"):
    if 'cnn_model' in st.session_state:
        return st.session_state.cnn_model
    try:
        # The saved model uses custom layers/metrics, so they must be supplied explicitly.
        model = load_model(
            model_path,
            custom_objects={
                "F1Score": F1Score,
                "ChannelAttention": ChannelAttention,
                "SpatialAttention": SpatialAttention,
            },
        )
        st.session_state.cnn_model = model
        return model
    except Exception as e:
        st.error(f"Failed to load CNN model from '{model_path}': {e}")
        st.code(traceback.format_exc())
        raise
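
# VLM loading strategy: the MedGemma base model is pulled from the Hugging Face hub,
# optionally quantized to 4-bit NF4 via bitsandbytes, and the locally fine-tuned LoRA
# adapter is attached with PEFT. The loaded objects are cached in st.session_state so
# Streamlit reruns reuse them instead of reloading the weights.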
def load_vlm_model():
    if st.session_state.get("vlm_loaded", False):
        return {
            "model": st.session_state.vlm_model,
            "processor": st.session_state.processor,
            "device": st.session_state.device,
            "dtype": st.session_state.dtype
        }
    USE_4BIT = True
    HF_MODEL_ID = "google/medgemma-4b-it"        # Hugging Face repo ID
    LORA_OUTPUT_DIR = "./medgemma_lora_adapter"  # local directory with the saved LoRA adapter
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    hf_token = os.getenv("HF_TOKEN")  # NOTE: read the token from the environment; never hard-code it
    # Determine dtype: bfloat16 requires compute capability >= 8 (Ampere or newer)
    capability = torch.cuda.get_device_capability(0)[0] if torch.cuda.is_available() else 0
    dtype = torch.bfloat16 if torch.cuda.is_available() and capability >= 8 else torch.float32
    # 4-bit quantization config
    bnb_config = None
    if USE_4BIT:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=dtype,
        )
    # Load processor from the LoRA adapter folder (it contains the tokenizer, etc.)
    try:
        processor = AutoProcessor.from_pretrained(
            LORA_OUTPUT_DIR,
            trust_remote_code=True
        )
        processor.tokenizer.padding_side = "right"
    except Exception as e:
        st.error(f"Failed to load processor from '{LORA_OUTPUT_DIR}': {e}")
        st.code(traceback.format_exc())
        raise
    # Load base model from the Hugging Face hub
    try:
        base_model = AutoModelForImageTextToText.from_pretrained(
            HF_MODEL_ID,
            quantization_config=bnb_config if USE_4BIT else None,
            torch_dtype=dtype,
            device_map="auto",
            trust_remote_code=True,
            token=hf_token  # only needed if the repo is gated or private
        )
    except Exception as e:
        st.error(f"Failed to load base model from Hugging Face hub: {e}")
        st.code(traceback.format_exc())
        raise
    # Attach LoRA adapter
    try:
        model = PeftModel.from_pretrained(
            base_model,
            LORA_OUTPUT_DIR,
            device_map="auto"
        )
    except Exception as e:
        st.error(f"Failed to attach LoRA adapter: {e}")
        st.code(traceback.format_exc())
        raise
    model.eval()
    try:
        model.to(DEVICE)
    except Exception:
        # Quantized / device-mapped models cannot be moved with .to(); they are already placed correctly.
        pass
    # Cache into session_state
    st.session_state.vlm_model = model
    st.session_state.processor = processor
    st.session_state.device = DEVICE
    st.session_state.dtype = dtype
    st.session_state.vlm_loaded = True
    return {"model": model, "processor": processor, "device": DEVICE, "dtype": dtype}
def generate_vlm_response(processor, vlm_model, device, gradcam_image: Image.Image, pred_label,
                          max_new_tokens=128):
    try:
        prompt_template = (
            "You are an AI assistant specialized in model interpretability. "
            "I am providing:\n- CNN model Grad-CAM++ heatmap image\n- Model predicted class: {predicted_class}\n\n"
            "Based on the Grad-CAM++ heatmap, write a clear and concise 20-30 word explanation "
            "of which features the model focused on and why. Output only the explanation (no headings)."
        )
        user_prompt = prompt_template.format(predicted_class=pred_label)
        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": user_prompt}
                ],
            }
        ]
        formatted_prompt = processor.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
        inputs = processor(text=formatted_prompt, images=gradcam_image, return_tensors="pt", padding=True)
        # Move every tensor in the batch onto the target device
        try:
            inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
        except Exception:
            for k, v in inputs.items():
                if isinstance(v, torch.Tensor):
                    inputs[k] = v.to(device)
        # Match the image tensor dtype to the model dtype where possible
        if "pixel_values" in inputs:
            try:
                inputs["pixel_values"] = inputs["pixel_values"].to(dtype=vlm_model.dtype)
            except Exception:
                try:
                    inputs["pixel_values"] = inputs["pixel_values"].to(dtype=torch.float16)
                except Exception:
                    pass
        with torch.inference_mode():
            output_ids = vlm_model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=processor.tokenizer.eos_token_id,
            )
        # Some generate wrappers return an object with .sequences
        seqs = output_ids.sequences if hasattr(output_ids, "sequences") else output_ids
        # Drop the prompt tokens and decode only the newly generated part
        input_len = inputs["input_ids"].shape[-1]
        response = processor.decode(seqs[0, input_len:], skip_special_tokens=True)
        return response.strip()
    except Exception as e:
        st.error(f"VLM generation failed: {e}")
        st.code(traceback.format_exc())
        return None
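
# Classification + attention pipeline: preprocess the image, run the CNN, then compute a
# Grad-CAM++ map for the predicted class at "block14_sepconv2" (the last separable
# convolution of the Xception backbone) and blend the jet-coloured heatmap with the
# preprocessed image to form the overlay shown in the UI.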
def classify_and_gradcam(image_bytes):
    pil_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    preprocessed = preprocess_image(np.array(pil_img))
    input_tensor = np.expand_dims(preprocessed, axis=0)
    with st.spinner("Loading Classifier Model..."):
        cnn = load_cnn_model("Proposed CBAM-Xception-DermNet.keras")
    with st.spinner("Classifying..."):
        preds = cnn.predict(input_tensor)[0]
        pred_idx = int(np.argmax(preds))
        pred_label = labels[pred_idx]
        conf = float(preds[pred_idx])
    with st.spinner("Generating Attention Map..."):
        target_layer = "block14_sepconv2"
        score = CategoricalScore([pred_idx])
        gradcam_vis = GradcamPlusPlus(cnn, model_modifier=ReplaceToLinear(), clone=True)
        cam = gradcam_vis(score, input_tensor, penultimate_layer=target_layer)[0]
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        heatmap = plt.cm.jet(cam)[..., :3]
        overlay = 0.25 * heatmap + 0.75 * preprocessed
        overlay = np.uint8(255 * np.clip(overlay, 0, 1))
        overlay_pil = Image.fromarray(overlay)
    return pred_label, conf, overlay_pil
# -------------------------
# Main display config & styling
# -------------------------
st.set_page_config(page_title="Skin Cancer Classifier", layout="wide", initial_sidebar_state="expanded")

st.markdown("""
<style>
.stApp { background: linear-gradient(180deg, #f5f7fb 0%, #ffffff 100%); }
.card { background: white; border-radius: 12px; padding: 14px; box-shadow: 0 8px 22px rgba(14,30,37,0.06); }
.header-title { font-size:34px; font-weight:700; margin-bottom:4px; }
.header-sub { color:#6b7280; margin-bottom:6px; }
.small { font-size:13px; color:#6b7280; }
</style>
""", unsafe_allow_html=True)
with st.sidebar:
    st.header("Important Notice")
    st.markdown("""
- This app is a prototype, not for clinical use.
- Do not rely on classifications or explanations for medical decisions.
- This app's model is fine-tuned on only one small dataset.
- It might not recognize your actual condition.
- Always consult a qualified healthcare professional.
- Results may not be accurate; use at your own risk.
- Again, this is just a prototype!
""", unsafe_allow_html=True)
    st.markdown("---")
    if st.button("Clear Models Cache"):
        for k in ["cnn_model", "vlm_model", "processor", "device", "dtype", "vlm_loaded"]:
            if k in st.session_state:
                del st.session_state[k]
        st.success("Model cache cleared. Models will reload on next use.")
| st.markdown("<div class='header-title'>Skin Cancer Image Classifier</div>", unsafe_allow_html=True) | |
| st.markdown("<div class='header-sub'>Local CNN inference β’ Model Attention (Grad-CAM++) visualizations β’ optional VLM explanations</div>", unsafe_allow_html=True) | |
| uploaded_file = st.file_uploader("Upload a skin lesion image", type=["jpg","jpeg","png"], key="uploaded_file" ) | |
# --- Handle automatic reset if file is cleared ---
# if uploaded_file is None and "selected_image" in st.session_state:
#     # Only clear if user manually removed an uploaded file
#     if not st.session_state.get("example_selected", False):
#         for key in ["selected_image", "vlm_response"]:
#             st.session_state.pop(key, None)
#         st.rerun()
if uploaded_file is not None:
    st.session_state.selected_image = uploaded_file.read()
    st.session_state.example_selected = False
    st.session_state["vlm_response"] = None
if uploaded_file is None and not st.session_state.get("example_selected", False):
    keys_to_clear = ["vlm_response", "pred_label", "conf", "overlay_pil", "last_image_bytes", "selected_image"]
    for k in keys_to_clear:
        if k in st.session_state:
            del st.session_state[k]
# Main layout: image area and visualization
original_image_col, attention_column = st.columns([2, 2])

with original_image_col:
    st.markdown("<div class='card'>", unsafe_allow_html=True)
    st.subheader("Selected Image")
    if 'selected_image' in st.session_state:
        pil_img = Image.open(io.BytesIO(st.session_state.selected_image)).convert("RGB")
        st.image(pil_img, width=360, caption="Selected image", output_format="auto")
    else:
        st.info("No image selected. Upload or click an example below.")
    st.markdown("</div>", unsafe_allow_html=True)
# Full-width section: classification runs outside the columns; the VLM explanation is
# regenerated only when a new image is selected.
if 'selected_image' in st.session_state:
    img_bytes = st.session_state.selected_image
    if st.session_state.get("last_image_bytes") != img_bytes:
        pred_label, conf, overlay_pil = classify_and_gradcam(img_bytes)
        st.session_state["pred_label"] = pred_label
        st.session_state["conf"] = conf
        st.session_state["overlay_pil"] = overlay_pil
        st.session_state["last_image_bytes"] = img_bytes
        try:
            with st.spinner("Loading VLM Model. Please be patient..."):
                try:
                    vlm_info = load_vlm_model()
                except Exception:
                    st.error("VLM load failed. See logs above.")
                    vlm_info = None
            if vlm_info is not None:
                try:
                    img_for_vlm = overlay_pil.convert("RGB").resize((224, 224), Image.BILINEAR)
                except Exception:
                    st.warning("Overlay image not available for VLM input; using original image.")
                    img_for_vlm = pil_img.convert("RGB").resize((224, 224), Image.BILINEAR)
                with st.spinner("Generating Explanation..."):
                    response = generate_vlm_response(
                        vlm_info["processor"],
                        vlm_info["model"],
                        vlm_info["device"],
                        img_for_vlm,
                        pred_label,
                        max_new_tokens=128
                    )
                    # response = "Debugging VLM response."  # For debugging
                if response is None:
                    st.error("VLM did not return a response.")
                else:
                    st.session_state["vlm_response"] = response
        except Exception as e:
            st.error(f"Error in VLM generation flow: {e}")
            st.code(traceback.format_exc())
with attention_column:
    st.markdown("<div class='card'>", unsafe_allow_html=True)
    st.subheader("Model Attention Visualization")
    if st.session_state.get("overlay_pil") is not None:
        st.image(st.session_state["overlay_pil"], caption="Model Attention Overlay", width=360, output_format="auto")
    else:
        st.info("Model Attention will appear here after selecting an image and running classification.")
    st.markdown("</div>", unsafe_allow_html=True)
# Metrics placeholder
c1, c2 = st.columns([3, 1])
if st.session_state.get("selected_image") and st.session_state.get("pred_label"):
    c1.metric("Predicted", full_names[st.session_state["pred_label"]])
    c2.metric("Confidence", f"{st.session_state['conf']:.2f}")
else:
    c1.metric("Predicted", "–")
    c2.metric("Confidence", "–")

# VLM Response placeholder
st.subheader("Generated Explanation")
if st.session_state.get("vlm_response"):
    st.info(st.session_state["vlm_response"])
else:
    st.info("VLM explanation will appear here after selecting an image and running classification.")
example_paths = [
    "images/ISIC_0025314.jpg",
    "images/ISIC_0025586.jpg",
    "images/ISIC_0025680.jpg",
    "images/ISIC_0026163.jpg"
]

# Container div for toggle + gallery
st.markdown("""
<div style='background-color:#f9fafb; padding:15px; border-radius:12px; margin-bottom:20px;'>
""", unsafe_allow_html=True)

toggle = st.toggle("Show Example Images", value=False)
if toggle:
    # Toggle ON: show gallery
    st.markdown("<div class='header-sub'>Click on any image to analyze it instantly</div>", unsafe_allow_html=True)
    html = """
    <style>
    .example-img {
        border-radius:10px;
        width:100%;
        display:block;
        box-shadow: 0 4px 12px rgba(14,30,37,0.06);
        transition: transform .12s ease, box-shadow .12s ease;
        cursor: pointer;
    }
    .example-img:hover {
        transform: scale(1.03);
        box-shadow: 0 14px 30px rgba(14,30,37,0.10);
    }
    .gallery-row { display:flex; gap:20px; }
    .gallery-item { flex:1; }
    </style>
    <div class="gallery-row">
    """
    for i, path in enumerate(example_paths):
        src = image_to_data_uri(path, max_width=480, jpeg_quality=70)
        if src is None:
            placeholder_svg = """
            <svg xmlns='http://www.w3.org/2000/svg' width='400' height='300'>
                <rect width='100%' height='100%' fill='#f3f4f6'/>
                <text x='50%' y='50%' dominant-baseline='middle' text-anchor='middle'
                      fill='#9ca3af' font-size='20'>missing</text>
            </svg>
            """
            src = "data:image/svg+xml;base64," + base64.b64encode(placeholder_svg.encode()).decode()
        html += f"""
        <a href='#' id='img_{i}' class='gallery-item'>
            <img src='{src}' class='example-img' />
        </a>
        """
    html += "</div>"

    if "example_click_key" not in st.session_state:
        st.session_state.example_click_key = 0
    clicked = click_detector(html, key=f"clicking_examples_{st.session_state.example_click_key}")
    if clicked:
        if uploaded_file is not None:
            st.warning("Please remove the uploaded file first by clicking the cross next to its name.")
        else:
            idx = int(clicked.split("_")[1])
            selected_path = example_paths[idx]
            img_bytes = open(selected_path, "rb").read()
            if st.session_state.get("last_image_bytes") != img_bytes:
                st.session_state.selected_image = img_bytes
                st.session_state.example_selected = True
                st.session_state["vlm_response"] = None
                st.session_state.example_click_key += 1
                try:
                    st.toast(f"Selected image: {selected_path}", icon="📸")
                except Exception:
                    st.success(f"Selected image: {selected_path}")
                st.rerun()

st.markdown("</div>", unsafe_allow_html=True)
| st.markdown(""" | |
| <div style='margin-top:12px; color:#6b7280; font-size:13px;'> | |
| Β© 2025 Faysal Ahmmed, Ajmy Alaly, Samanta Mehnaj, Asef Rahman, F.M. Mridha. All rights reserved. | |
| </div> | |
| """, unsafe_allow_html=True) | |