Faysal4200 committed on
Commit fd356db · verified · 1 Parent(s): 1692219

Update app.py

Files changed (1)
  1. app.py +573 -573
app.py CHANGED
@@ -1,573 +1,573 @@
- st.markdown("<div class='header-sub'>Local CNN inference • Model Attention (Grad-CAM++) visualizations • optional VLM explanations</div>", unsafe_allow_html=True)
 
1
+ import streamlit as st
2
+ loader_placeholder = st.empty()
3
+ loader_placeholder.markdown("""
4
+ <div style="
5
+ display:flex;
6
+ justify-content:center;
7
+ align-items:center;
8
+ height:50vh;
9
+ font-size:40px;
10
+ font-weight:bold;
11
+ color:#00b4d8;
12
+ animation: flash 1s infinite;
13
+ ">
14
+ Loading necessary libraries...
15
+ </div>
16
+
17
+ <style>
18
+ @keyframes flash {
19
+ 0% { opacity: 0.2; }
20
+ 50% { opacity: 1; }
21
+ 100% { opacity: 0.2; }
22
+ }
23
+ </style>
24
+ """, unsafe_allow_html=True)
25
+ import numpy as np
26
+ from st_click_detector import click_detector
27
+ import cv2
28
+ from PIL import Image
29
+ import tensorflow as tf
30
+ from tensorflow.keras.models import load_model
31
+ from tf_keras_vis.gradcam_plus_plus import GradcamPlusPlus
32
+ from tf_keras_vis.utils.model_modifiers import ReplaceToLinear
33
+ from tf_keras_vis.utils.scores import CategoricalScore
34
+ import matplotlib.pyplot as plt
35
+ import torch
36
+ from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
37
+ from peft import PeftModel
38
+ import base64
39
+ import os
40
+ import io
41
+ import traceback
42
+ from tensorflow.keras.layers import (
43
+ Layer, Conv2D, Dense,
44
+ GlobalAveragePooling2D, GlobalMaxPooling2D,
45
+ Reshape, Multiply, Add, Activation, Concatenate
46
+ )
47
+ from pathlib import Path
48
+
49
+ loader_placeholder.empty()
50
+ #--------------------------------------------------------------------------------------------------
51
+ # Not called directly by the app, but these custom objects must be registered so load_model can deserialize the CNN
52
+ #--------------------------------------------------------------------------------------------------
53
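+ # F1 is derived from running Precision/Recall metric states (harmonic mean with an epsilon for stability)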
+ @tf.keras.utils.register_keras_serializable(package="Custom", name="F1Score")
54
+ class F1Score(tf.keras.metrics.Metric):
55
+ def __init__(self, name='f1_score', **kwargs):
56
+ super().__init__(name=name, **kwargs)
57
+ self.precision = tf.keras.metrics.Precision()
58
+ self.recall = tf.keras.metrics.Recall()
59
+
60
+ def update_state(self, y_true, y_pred, sample_weight=None):
61
+ self.precision.update_state(y_true, y_pred, sample_weight)
62
+ self.recall.update_state(y_true, y_pred, sample_weight)
63
+
64
+ def result(self):
65
+ p = self.precision.result()
66
+ r = self.recall.result()
67
+ return 2 * (p * r) / (p + r + tf.keras.backend.epsilon())
68
+
69
+ def reset_states(self):
70
+ self.precision.reset_states()
71
+ self.recall.reset_states()
72
+
73
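+ # CBAM channel attention: average- and max-pooled descriptors go through a shared two-layer
+ # MLP, are summed, squashed with a sigmoid, and used to rescale the input channels.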
+ @tf.keras.utils.register_keras_serializable(package="Custom", name="ChannelAttention")
74
+ class ChannelAttention(Layer):
75
+ def __init__(self, reduction=16, **kwargs):
76
+ super(ChannelAttention, self).__init__(**kwargs)
77
+ self.reduction = reduction
78
+
79
+ def build(self, input_shape):
80
+ channel = input_shape[-1]
81
+ self.shared_dense_one = Dense(channel // self.reduction, activation='relu', kernel_initializer='he_normal', use_bias=True)
82
+ self.shared_dense_two = Dense(channel, kernel_initializer='he_normal', use_bias=True)
83
+
84
+ def call(self, inputs):
85
+ avg_pool = GlobalAveragePooling2D()(inputs)
86
+ max_pool = GlobalMaxPooling2D()(inputs)
87
+
88
+ avg_pool = self.shared_dense_one(avg_pool)
89
+ avg_pool = self.shared_dense_two(avg_pool)
90
+
91
+ max_pool = self.shared_dense_one(max_pool)
92
+ max_pool = self.shared_dense_two(max_pool)
93
+
94
+ attention = Add()([avg_pool, max_pool])
95
+ attention = Activation('sigmoid')(attention)
96
+
97
+ attention = Reshape((1, 1, -1))(attention)
98
+ return Multiply()([inputs, attention])
99
+
100
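+ # CBAM spatial attention: channel-wise mean and max maps are concatenated and passed
+ # through a 7x7 sigmoid convolution to produce a per-pixel attention mask.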
+ @tf.keras.utils.register_keras_serializable(package="Custom", name="SpatialAttention")
101
+ class SpatialAttention(Layer):
102
+ def __init__(self, **kwargs):
103
+ super(SpatialAttention, self).__init__(**kwargs)
104
+ self.conv2d = Conv2D(filters=1, kernel_size=7, strides=1, padding='same', activation='sigmoid')
105
+ def call(self, inputs):
106
+ avg_pool = tf.reduce_mean(inputs, axis=-1, keepdims=True)
107
+ max_pool = tf.reduce_max(inputs, axis=-1, keepdims=True)
108
+ concat = Concatenate(axis=-1)([avg_pool, max_pool])
109
+ attention = self.conv2d(concat)
110
+ return Multiply()([inputs, attention])
111
+
112
+ def cbam_block(inputs, reduction=16):
113
+ x = ChannelAttention(reduction)(inputs)
114
+ x = SpatialAttention()(x)
115
+ return x
116
+ #----------------------------------------------------------------------------------------------------------
117
+ #---------------------------------------------------------------------------------------------------------
118
+
119
+ # -------------------------
120
+ # Helpers & small utilities
121
+ # -------------------------
122
+ def bytes_from_path(path):
123
+ with open(path, "rb") as f:
124
+ return f.read()
125
+
126
+ def image_to_data_uri(path: str, max_width=224, jpeg_quality=70):
127
+ p = Path(path)
128
+ if not p.exists():
129
+ return None
130
+ img = Image.open(p).convert("RGB")
131
+ # resize maintaining aspect ratio
132
+ if img.width > max_width:
133
+ new_h = int(max_width * img.height / img.width)
134
+ img = img.resize((max_width, new_h), Image.BILINEAR)
135
+ buf = io.BytesIO()
136
+ img.save(buf, format="JPEG", quality=jpeg_quality, optimize=True)
137
+ b = buf.getvalue()
138
+ data64 = base64.b64encode(b).decode("utf-8")
139
+ return f"data:image/jpeg;base64,{data64}"
140
+
141
+
142
+ labels = ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']
143
+ full_names = {
144
+ 'akiec': 'Actinic keratoses',
145
+ 'bcc': 'Basal cell carcinoma',
146
+ 'bkl': 'Benign keratosis-like lesions',
147
+ 'df': 'Dermatofibroma',
148
+ 'mel': 'Melanoma',
149
+ 'nv': 'Melanocytic nevi',
150
+ 'vasc': 'Vascular lesions'
151
+ }
152
+
153
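+ # CLAHE on the L channel in LAB space boosts local contrast; the result is min-max
+ # scaled to [0, 1], presumably matching the CNN's training-time preprocessing.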
+ def preprocess_image(image):
154
+ if image.dtype != np.uint8:
155
+ image = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
156
+ lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
157
+ clahe = cv2.createCLAHE(clipLimit=0.01, tileGridSize=(8, 8))
158
+ lab[:, :, 0] = clahe.apply(lab[:, :, 0])
159
+ image_clahe = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
160
+ image_clahe = image_clahe.astype(np.float32)
161
+ image_clahe = (image_clahe - np.min(image_clahe)) / (np.ptp(image_clahe) + 1e-8)
162
+ return image_clahe
163
+
164
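+ # st.cache_resource (plus a session_state mirror) keeps the Keras model loaded once per process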
+ @st.cache_resource(show_spinner=False)
165
+ def load_cnn_model(model_path="Proposed CBAM-Xception-DermNet.keras"):
166
+ if 'cnn_model' in st.session_state:
167
+ return st.session_state.cnn_model
168
+ try:
169
+ model = load_model(model_path)
170
+ st.session_state.cnn_model = model
171
+ return model
172
+ except Exception as e:
173
+ st.error(f"Failed to load CNN model from '{model_path}': {e}")
174
+ st.exception(traceback.format_exc())
175
+ raise
176
+
177
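+ # Loads the MedGemma base model (optionally 4-bit quantized), attaches the local LoRA adapter,
+ # and caches everything in session_state so Streamlit reruns reuse the same weights.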
+ @st.cache_resource(show_spinner=False)
178
+ def load_vlm_model():
179
+ if st.session_state.get("vlm_loaded", False):
180
+ return {
181
+ "model": st.session_state.vlm_model,
182
+ "processor": st.session_state.processor,
183
+ "device": st.session_state.device,
184
+ "dtype": st.session_state.dtype
185
+ }
186
+
187
+ USE_4BIT = True
188
+ HF_MODEL_ID = "google/medgemma-4b-it" # Hugging Face repo ID
189
+ LORA_OUTPUT_DIR = "./medgemma_lora_adapter" #local lora saved dir
190
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
191
+ hf_token = os.getenv("HF_TOKEN")  # NOTE: must be provided via an environment variable/secret; never hard-code the token
192
+
193
+ # Determine dtype
194
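+ # compute capability >= 8 (Ampere or newer) supports bfloat16; otherwise fall back to float32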
+ capability = torch.cuda.get_device_capability(0)[0] if torch.cuda.is_available() else 0
195
+ dtype = torch.bfloat16 if torch.cuda.is_available() and capability >= 8 else torch.float32
196
+
197
+ # 4-bit quantization config
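+ # NF4 with double quantization trades a little precision for a much smaller GPU memory footprint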
198
+ bnb_config = None
199
+ if USE_4BIT:
200
+ bnb_config = BitsAndBytesConfig(
201
+ load_in_4bit=True,
202
+ bnb_4bit_use_double_quant=True,
203
+ bnb_4bit_quant_type="nf4",
204
+ bnb_4bit_compute_dtype=dtype,
205
+ )
206
+
207
+ # Load processor from LoRA adapter folder (it contains tokenizer, etc.)
208
+ try:
209
+ processor = AutoProcessor.from_pretrained(
210
+ LORA_OUTPUT_DIR,
211
+ trust_remote_code=True
212
+ )
213
+ processor.tokenizer.padding_side = "right"
214
+ except Exception as e:
215
+ st.error(f"Failed to load processor from '{LORA_OUTPUT_DIR}': {e}")
216
+ st.exception(traceback.format_exc())
217
+ raise
218
+
219
+ # Load base model from Hugging Face hub
220
+ try:
221
+ base_model = AutoModelForImageTextToText.from_pretrained(
222
+ HF_MODEL_ID,
223
+ quantization_config=bnb_config if USE_4BIT else None,
224
+ dtype=dtype,
225
+ device_map="auto",
226
+ trust_remote_code=True,
227
+ token=hf_token  # needed for gated or private repos (use_auth_token is deprecated)
228
+ )
229
+ except Exception as e:
230
+ st.error(f"Failed to load base model from Hugging Face hub: {e}")
231
+ st.exception(traceback.format_exc())
232
+ raise
233
+
234
+ # Attach LoRA adapter
235
+ try:
236
+ model = PeftModel.from_pretrained(
237
+ base_model,
238
+ LORA_OUTPUT_DIR,
239
+ device_map="auto"
240
+ )
241
+ except Exception as e:
242
+ st.error(f"Failed to attach LoRA adapter: {e}")
243
+ st.exception(traceback.format_exc())
244
+ raise
245
+ model.eval()
246
+ try:
247
+ model.to(DEVICE)
248
+ except Exception:
249
+ # quantized / device_map="auto" models are already placed; .to() may be a no-op or raise, so ignore failures
250
+ pass
251
+ # Cache into session_state
252
+ st.session_state.vlm_model = model
253
+ st.session_state.processor = processor
254
+ st.session_state.device = DEVICE
255
+ st.session_state.dtype = dtype
256
+ st.session_state.vlm_loaded = True
257
+
258
+ return {"model": model, "processor": processor, "device": DEVICE, "dtype": dtype}
259
+
260
+
261
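+ # Builds a single-turn chat (Grad-CAM++ overlay + instruction), runs greedy decoding
+ # (do_sample=False), and returns only the newly generated tokens after the prompt.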
+ def generate_vlm_response(processor, vlm_model, device, gradcam_image: Image.Image, pred_label,
262
+ max_new_tokens=128):
263
+ try:
264
+ prompt_template = (
265
+ "You are an AI assistant specialized in model interpretability. "
266
+ "I am providing:\n- CNN model Grad-CAM++ heatmap image\n- Model predicted class: {predicted_class}\n\n"
267
+ "Based on the Grad-CAM++ heatmap, write a clear and concise 20–30 word explanation "
268
+ "of which features the model focused on and why. Output only the explanation (no headings)."
269
+ )
270
+ user_prompt = prompt_template.format(predicted_class=pred_label)
271
+
272
+ chat = [
273
+ {
274
+ "role": "user",
275
+ "content": [
276
+ {"type": "image"},
277
+ {"type": "text", "text": user_prompt}
278
+ ],
279
+ }
280
+ ]
281
+ formatted_prompt = processor.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
282
+ inputs = processor(text=formatted_prompt, images=gradcam_image, return_tensors="pt", padding=True)
283
+
284
+ try:
285
+ inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
286
+ except Exception:
287
+ for k, v in inputs.items():
288
+ if isinstance(v, torch.Tensor):
289
+ inputs[k] = v.to(device)
290
+
291
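+ # Cast pixel_values to the model's dtype (falling back to float16) so the vision tower sees the expected precision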
+ if "pixel_values" in inputs:
292
+ try:
293
+ inputs["pixel_values"] = inputs["pixel_values"].to(dtype=vlm_model.dtype)
294
+ except Exception:
295
+ try:
296
+ inputs["pixel_values"] = inputs["pixel_values"].to(dtype=torch.float16)
297
+ except Exception:
298
+ pass
299
+
300
+ with torch.inference_mode():
301
+ output_ids = vlm_model.generate(
302
+ **inputs,
303
+ max_new_tokens=max_new_tokens,
304
+ do_sample=False,
305
+ pad_token_id=processor.tokenizer.eos_token_id,
306
+ )
307
+
308
+ # Some generate wrappers return object with .sequences
309
+ if hasattr(output_ids, "sequences"):
310
+ seqs = output_ids.sequences
311
+ else:
312
+ seqs = output_ids
313
+
314
+ input_len = inputs["input_ids"].shape[-1]
315
+ response = processor.decode(seqs[0, input_len:], skip_special_tokens=True)
316
+ return response.strip()
317
+
318
+ except Exception as e:
319
+ st.error(f"VLM generation failed: {e}")
320
+ st.exception(traceback.format_exc())
321
+ return None
322
+
323
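+ # Runs the CNN on the preprocessed image, then computes a Grad-CAM++ map for the predicted
+ # class at Xception's last separable conv block and blends it 25/75 with the input image.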
+ def classify_and_gradcam(image_bytes):
324
+ pil_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
325
+ preprocessed = preprocess_image(np.array(pil_img))
326
+ input_tensor = np.expand_dims(preprocessed, axis=0)
327
+ with st.spinner("Loading Classifier Model..."):
328
+ cnn = load_cnn_model("Proposed CBAM-Xception-DermNet.keras")
329
+ with st.spinner("Classifying..."):
330
+ preds = cnn.predict(input_tensor)[0]
331
+ pred_idx = int(np.argmax(preds))
332
+ pred_label = labels[pred_idx]
333
+ conf = float(preds[pred_idx])
334
+ with st.spinner("Generating Attention Map..."):
335
+ target_layer = "block14_sepconv2"
336
+ score = CategoricalScore([pred_idx])
337
+ gradcam_vis = GradcamPlusPlus(cnn, model_modifier=ReplaceToLinear(), clone=True)
338
+ cam = gradcam_vis(score, input_tensor, penultimate_layer=target_layer)[0]
339
+ cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
340
+ heatmap = plt.cm.jet(cam)[..., :3]
341
+ overlay = 0.25 * heatmap + 0.75 * preprocessed
342
+ overlay = np.uint8(255 * np.clip(overlay, 0, 1))
343
+ overlay_pil = Image.fromarray(overlay)
344
+
345
+ return pred_label, conf, overlay_pil
346
+
347
+ # -------------------------
348
+ # Main display config & styling
349
+ # -------------------------
350
+ st.set_page_config(page_title="Skin Cancer Classifier", layout="wide", initial_sidebar_state="expanded")
351
+
352
+ st.markdown("""
353
+ <style>
354
+ .stApp { background: linear-gradient(180deg, #f5f7fb 0%, #ffffff 100%); }
355
+ .card { background: white; border-radius: 12px; padding: 14px; box-shadow: 0 8px 22px rgba(14,30,37,0.06); }
356
+ .header-title { font-size:34px; font-weight:700; margin-bottom:4px; }
357
+ .header-sub { color:#6b7280; margin-bottom:6px; }
358
+ .small { font-size:13px; color:#6b7280; }
359
+ </style>
360
+ """, unsafe_allow_html=True)
361
+
362
+ with st.sidebar:
363
+ st.header("Important Notice")
364
+ st.markdown("""
365
+ - This app is a prototype, not for clinical use.
366
+ - Do not rely on classifications or explanations for medical decisions.
367
+ - This app's model is fine-tuned on only one small dataset.
368
+ - It may not correctly recognize your actual condition.
369
+ - Always consult a qualified healthcare professional.
370
+ - Results may not be accurate; use at your own risk.
371
+ - Again, this is just a prototype!
372
+ """, unsafe_allow_html=True)
373
+
374
+ st.markdown("---")
375
+ if st.button("Clear Models Cache"):
376
+ for k in ["cnn_model", "vlm_model", "processor", "device", "dtype", "vlm_loaded"]:
377
+ if k in st.session_state:
378
+ del st.session_state[k]
379
+ st.success("Model cache cleared. Models will reload on next use.")
380
+
381
+
382
+ st.markdown("<div class='header-title'>Skin Cancer Image Classifier</div>", unsafe_allow_html=True)
383
+ st.markdown("<div class='header-sub'>CNN Classifier • Model Attention (Grad-CAM++) visualizations • VLM explanations</div>", unsafe_allow_html=True)
384
+
385
+ uploaded_file = st.file_uploader("Upload a skin lesion image", type=["jpg","jpeg","png"], key="uploaded_file" )
386
+
387
+ # --- Handle automatic reset if file is cleared ---
388
+ #if uploaded_file is None and "selected_image" in st.session_state:
389
+ # # Only clear if user manually removed an uploaded file
390
+ # if not st.session_state.get("example_selected", False):
391
+ # for key in ["selected_image", "vlm_response"]:
392
+ # st.session_state.pop(key, None)
393
+ # st.rerun()
394
+
395
+ if uploaded_file is not None:
396
+ st.session_state.selected_image = uploaded_file.read()
397
+ st.session_state.example_selected = False
398
+ st.session_state["vlm_response"] = None
399
+
400
+ if uploaded_file is None and not st.session_state.get("example_selected", False):
401
+ keys_to_clear = ["vlm_response", "pred_label", "conf", "overlay_pil", "last_image_bytes", "selected_image"]
402
+ for k in keys_to_clear:
403
+ if k in st.session_state:
404
+ del st.session_state[k]
405
+
406
+ # Main layout: image area and visualization
407
+ original_image_col, attention_column = st.columns([2,2])
408
+
409
+ with original_image_col:
410
+ st.markdown("<div class='card'>", unsafe_allow_html=True)
411
+ st.subheader("Selected Image")
412
+ if 'selected_image' in st.session_state:
413
+ pil_img = Image.open(io.BytesIO(st.session_state.selected_image)).convert("RGB")
414
+ st.image(pil_img, width=360, caption="Selected image", output_format="auto")
415
+ else:
416
+ st.info("No image selected. Upload or click an example below.")
417
+ st.markdown("</div>", unsafe_allow_html=True)
418
+
419
+ # full column
420
+ if 'selected_image' in st.session_state:
421
+ img_bytes = st.session_state.selected_image
422
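+ # Re-run classification, Grad-CAM++ and the VLM only when the selected image actually changed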
+ if st.session_state.get("last_image_bytes") != img_bytes:
423
+ pred_label, conf, overlay_pil = classify_and_gradcam(img_bytes)
424
+ st.session_state["pred_label"] = pred_label
425
+ st.session_state["conf"] = conf
426
+ st.session_state["overlay_pil"] = overlay_pil
427
+ st.session_state["last_image_bytes"] = img_bytes
428
+ try:
429
+ with st.spinner("Loading VLM Model. Please be patient..."):
430
+ try:
431
+ vlm_info = load_vlm_model()
432
+ except Exception as e:
433
+ st.error("VLM load failed. See logs above.")
434
+ vlm_info = None
435
+
436
+ if vlm_info is not None:
437
+ try:
438
+ img_for_vlm = overlay_pil.convert("RGB").resize((224, 224), Image.BILINEAR)
439
+ except Exception:
440
+ st.warning("Overlay image not available for VLM input; using original image.")
441
+ img_for_vlm = pil_img.convert("RGB").resize((224, 224), Image.BILINEAR)
442
+
443
+ with st.spinner("Generating Explanation...."):
444
+ response = generate_vlm_response(
445
+ vlm_info["processor"],
446
+ vlm_info["model"],
447
+ vlm_info["device"],
448
+ img_for_vlm,
449
+ pred_label,
450
+ max_new_tokens=128
451
+ )
452
+ #response = "Debugging VLM response." # For debugging
453
+ if response is None:
454
+ st.error("VLM did not return a response.")
455
+ else:
456
+ st.session_state["vlm_response"] = response
457
+ except Exception as e:
458
+ st.error(f"Error in VLM generation flow: {e}")
459
+ st.exception(traceback.format_exc())
460
+
461
+
462
+ with attention_column:
463
+ st.markdown("<div class='card'>", unsafe_allow_html=True)
464
+ st.subheader("Model Attention Visualization")
465
+ if 'selected_image' in st.session_state:
466
+ st.image(st.session_state["overlay_pil"], caption="Model Attention Overlay", width=360, output_format="auto")
467
+ else:
468
+ st.info("Model Attention will appear here after selecting an image and running classification.")
469
+ st.markdown("</div>", unsafe_allow_html=True)
470
+
471
+
472
+ # Metrics placeholder
473
+ c1, c2 = st.columns([3,1])
474
+ if st.session_state.get("selected_image") and st.session_state.get("pred_label"):
475
+ c1.metric("Predicted", full_names[st.session_state["pred_label"]])
476
+ c2.metric("Confidence", f"{st.session_state['conf']:.2f}")
477
+ else:
478
+ c1.metric("Predicted", "—")
479
+ c2.metric("Confidence", "—")
480
+
481
+ # VLM Response placeholder
482
+ st.subheader("Generated Explanation")
483
+ if st.session_state.get("vlm_response"):
484
+ st.info(st.session_state["vlm_response"])
485
+ else:
486
+ st.info("VLM explanation will appear here after selecting an image and running classification.")
487
+
488
+ example_paths = [
489
+ "images/ISIC_0025314.jpg",
490
+ "images/ISIC_0025586.jpg",
491
+ "images/ISIC_0025680.jpg",
492
+ "images/ISIC_0026163.jpg"
493
+ ]
494
+
495
+ # Container div for toggle + gallery
496
+ st.markdown("""
497
+ <div style='background-color:#f9fafb; padding:15px; border-radius:12px; margin-bottom:20px;'>
498
+ """, unsafe_allow_html=True)
499
+
500
+ toggle = st.toggle("Show Example Images", value=False)
501
+
502
+ if toggle:
503
+ # Toggle ON → show gallery
504
+ st.markdown("<div class='header-sub'>Click on any image to analyze it instantly</div>", unsafe_allow_html=True)
505
+ html = """
506
+ <style>
507
+ .example-img {
508
+ border-radius:10px;
509
+ width:100%;
510
+ display:block;
511
+ box-shadow: 0 4px 12px rgba(14,30,37,0.06);
512
+ transition: transform .12s ease, box-shadow .12s ease;
513
+ cursor: pointer;
514
+ }
515
+ .example-img:hover {
516
+ transform: scale(1.03);
517
+ box-shadow: 0 14px 30px rgba(14,30,37,0.10);
518
+ }
519
+ .gallery-row { display:flex; gap:20px; }
520
+ .gallery-item { flex:1; }
521
+ </style>
522
+ <div class="gallery-row">
523
+ """
524
+
525
+ for i, path in enumerate(example_paths):
526
+ src = image_to_data_uri(path, max_width=480, jpeg_quality=70)
527
+ if src is None:
528
+ placeholder_svg = """
529
+ <svg xmlns='http://www.w3.org/2000/svg' width='400' height='300'>
530
+ <rect width='100%' height='100%' fill='#f3f4f6'/>
531
+ <text x='50%' y='50%' dominant-baseline='middle' text-anchor='middle'
532
+ fill='#9ca3af' font-size='20'>missing</text>
533
+ </svg>
534
+ """
535
+ src = "data:image/svg+xml;base64," + base64.b64encode(placeholder_svg.encode()).decode()
536
+
537
+ html += f"""
538
+ <a href='#' id='img_{i}' class='gallery-item'>
539
+ <img src='{src}' class='example-img' />
540
+ </a>
541
+ """
542
+
543
+ html += "</div>"
544
+
545
+ if "example_click_key" not in st.session_state:
546
+ st.session_state.example_click_key = 0
547
+
548
+ clicked = click_detector(html, key=f"clicking_examples_{st.session_state.example_click_key}")
549
+
550
+ if clicked:
551
+ if uploaded_file is not None:
552
+ st.warning("Please remove the uploaded file first by clicking the ✕ next to its name.")
553
+ else:
554
+ idx = int(clicked.split("_")[1])
555
+ selected_path = example_paths[idx]
556
+ img_bytes = bytes_from_path(selected_path)  # use the helper so the file handle is closed
557
+ if st.session_state.get("last_image_bytes") != img_bytes:
558
+ st.session_state.selected_image = img_bytes
559
+ st.session_state.example_selected = True
560
+ st.session_state["vlm_response"] = None
561
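+ # A fresh click_detector key resets the component so the previous click is not replayed on rerun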
+ st.session_state.example_click_key += 1
562
+ try:
563
+ st.toast(f"✅ Selected image: {selected_path}", icon="📸")
564
+ except Exception:
565
+ st.success(f"Selected image: {selected_path}")
566
+ st.rerun()
567
+ st.markdown("</div>", unsafe_allow_html=True)
568
+
569
+ st.markdown("""
570
+ <div style='margin-top:12px; color:#6b7280; font-size:13px;'>
571
+ © 2025 Faysal Ahmmed, Ajmy Alaly, Samanta Mehnaj, Asef Rahman, F.M. Mridha. All rights reserved.
572
+ </div>
573
+ """, unsafe_allow_html=True)