---
license: apache-2.0
datasets:
- marmal88/skin_cancer
base_model:
- google/vit-base-patch16-224-in21k
pipeline_tag: image-classification
tags:
- medical
---

## Installation

First, clone the repository:

```bash
git clone https://github.com/ethicalabs-ai/SkinCancerViT.git
cd SkinCancerViT
```

Then, install the package in editable mode using uv (or pip):

```bash
uv sync  # Recommended if you use uv
# Or, if using pip:
# pip install -e .
```
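
To confirm that the install landed in the environment you are actually running, you can check whether Python can locate the package. This is a minimal standard-library sketch: the module name `skincancer_vit` is taken from the import used in the usage example, and `is_importable` is a hypothetical helper, not part of the package.

```python
import importlib.util


def is_importable(module_name: str) -> bool:
    """Return True if Python can locate a top-level module with this name."""
    return importlib.util.find_spec(module_name) is not None


# Report whether the package is visible to the current interpreter.
if is_importable("skincancer_vit"):
    print("skincancer_vit is importable")
else:
    print("skincancer_vit not found; re-run the install step in this environment")
```

If you installed with uv, run the check through `uv run python` so that it uses the project's virtual environment.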

## Quick Start / Usage

This package loads a pre-trained SkinCancerViT model from the Hugging Face Hub and predicts a diagnosis from an image plus two tabular inputs: the patient's age and the lesion localization.

```python
import torch
from skincancer_vit.model import SkinCancerViTModel
from PIL import Image
from datasets import load_dataset  # To get a random sample

# Load the model from Hugging Face Hub
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SkinCancerViTModel.from_pretrained("ethicalabs/SkinCancerViT")
model.to(device)  # Move model to the desired device
model.eval()  # Set model to evaluation mode

# Example Prediction from a Specific Image File
image_file_path = "images/patient-001.jpg"  # Specify your image file path here
specific_image = Image.open(image_file_path).convert("RGB")

# Example tabular data for this prediction
specific_age = 42
specific_localization = "face"  # Ensure this matches one of your trained localization categories

predicted_dx, confidence = model.full_predict(
    raw_image=specific_image,
    raw_age=specific_age,
    raw_localization=specific_localization,
    device=device
)

print(f"Predicted Diagnosis: {predicted_dx}")
print(f"Confidence: {confidence:.4f}")

# Example Prediction from a Random Test Sample from the Dataset
dataset = load_dataset("marmal88/skin_cancer", split="test")
random_sample = dataset.shuffle(seed=42).select(range(1))[0]  # Get the first shuffled sample

sample_image = random_sample["image"]
sample_age = random_sample["age"]
sample_localization = random_sample["localization"]
sample_true_dx = random_sample["dx"]

predicted_dx_sample, confidence_sample = model.full_predict(
    raw_image=sample_image,
    raw_age=sample_age,
    raw_localization=sample_localization,
    device=device
)

print(f"Predicted Diagnosis: {predicted_dx_sample}")
print(f"Confidence: {confidence_sample:.4f}")
print(f"Correct Prediction: {predicted_dx_sample == sample_true_dx}")
```
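
To go beyond a single sample, the same call can be looped over a handful of test rows to get a rough accuracy figure. The sketch below assumes only what the example above shows: a `full_predict`-style callable that accepts `raw_image`, `raw_age`, `raw_localization`, and `device` keyword arguments and returns a `(label, confidence)` pair, and rows with `image`, `age`, `localization`, and `dx` fields. The helper name `accuracy_on_samples` is hypothetical and not part of the package.

```python
from typing import Any, Callable, Iterable, Mapping, Tuple


def accuracy_on_samples(
    predict: Callable[..., Tuple[str, float]],
    samples: Iterable[Mapping[str, Any]],
    device: Any = None,
) -> float:
    """Return the fraction of samples whose predicted label equals sample["dx"]."""
    correct = 0
    total = 0
    for sample in samples:
        # Mirror the keyword arguments used in the single-sample example.
        predicted_dx, _confidence = predict(
            raw_image=sample["image"],
            raw_age=sample["age"],
            raw_localization=sample["localization"],
            device=device,
        )
        correct += int(predicted_dx == sample["dx"])
        total += 1
    if total == 0:
        raise ValueError("no samples given")
    return correct / total
```

With the objects from the example above, this could be called as `accuracy_on_samples(model.full_predict, dataset.shuffle(seed=42).select(range(20)), device=device)`. A sample this small gives only a noisy estimate, so treat the result as a smoke test rather than a benchmark.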