---
license: apache-2.0
datasets:
  - marmal88/skin_cancer
base_model:
  - google/vit-base-patch16-224-in21k
pipeline_tag: image-classification
tags:
  - medical
---

Installation

First, clone the repository:

```bash
git clone https://github.com/ethicalabs-ai/SkinCancerViT.git
cd SkinCancerViT
```

Then, install the package in editable mode using uv (or pip):

```bash
uv sync  # recommended if you use uv
# Or, with pip:
# pip install -e .
```
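
To confirm the environment is ready, you can check that the modules imported in the Quick Start resolve. This is a minimal sketch: `package_available` is a hypothetical helper, not part of the package, and the module list simply mirrors the Quick Start imports (`skincancer_vit`, `torch`, `PIL`, `datasets`).

```python
import importlib.util


def package_available(name: str) -> bool:
    """Return True if the top-level module `name` can be found in this environment."""
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    # Modules imported by the Quick Start example.
    for module in ("skincancer_vit", "torch", "PIL", "datasets"):
        status = "ok" if package_available(module) else "missing"
        print(f"{module}: {status}")
```

With uv, run the check inside the project environment, for example via `uv run python`.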

Quick Start / Usage

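This package lets you load the pre-trained SkinCancerViT model from the Hugging Face Hub and predict a diagnosis from a lesion image combined with tabular data (the patient's age and the lesion's localization).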

```python
import torch
from PIL import Image
from datasets import load_dataset  # used to fetch a random test sample
from skincancer_vit.model import SkinCancerViTModel

# Load the pre-trained model from the Hugging Face Hub
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SkinCancerViTModel.from_pretrained("ethicalabs/SkinCancerViT")
model.to(device)  # move the model to the selected device
model.eval()      # switch the model to evaluation mode

# Example 1: prediction from a specific image file
image_file_path = "images/patient-001.jpg"  # replace with the path to your image
specific_image = Image.open(image_file_path).convert("RGB")

# Tabular inputs for this prediction
specific_age = 42
specific_localization = "face"  # must be one of the localization categories seen during training

predicted_dx, confidence = model.full_predict(
    raw_image=specific_image,
    raw_age=specific_age,
    raw_localization=specific_localization,
    device=device,
)

print(f"Predicted Diagnosis: {predicted_dx}")
print(f"Confidence: {confidence:.4f}")
```
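
Because `full_predict` expects a localization string from the training categories, a small guard can catch typos or unseen values before the model is called. This is a sketch under stated assumptions, not part of the `skincancer_vit` package: it assumes the training categories are exactly the distinct lower-case `localization` values in the `train` split of `marmal88/skin_cancer`, and the helpers `known_localizations` and `normalize_localization` are hypothetical.

```python
def known_localizations(split: str = "train") -> set[str]:
    """Return the distinct `localization` labels in one split of marmal88/skin_cancer.

    Assumption: these match the categories the model saw during training.
    """
    # Imported here so normalize_localization() works without the datasets package.
    from datasets import load_dataset

    ds = load_dataset("marmal88/skin_cancer", split=split)
    return set(ds.unique("localization"))


def normalize_localization(value: str, known: set[str]) -> str:
    """Strip and lower-case `value`, then check it against the known categories.

    Hypothetical helper: raises ValueError instead of passing an unseen label to the model.
    """
    cleaned = value.strip().lower()
    if cleaned not in known:
        raise ValueError(
            f"Unknown localization {value!r}; expected one of: {', '.join(sorted(known))}"
        )
    return cleaned


# Usage with the Quick Start variables (downloads the dataset from the Hub):
# categories = known_localizations()
# specific_localization = normalize_localization(specific_localization, categories)
```

Failing early with the list of valid labels is easier to debug than an encoding error or a silently wrong prediction deeper inside the model.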

```python
# Example 2: prediction on a random sample from the test split
dataset = load_dataset("marmal88/skin_cancer", split="test")
random_sample = dataset.shuffle(seed=42)[0]  # first sample after a seeded shuffle

sample_image = random_sample["image"].convert("RGB")  # decoded as a PIL image
sample_age = random_sample["age"]
sample_localization = random_sample["localization"]
sample_true_dx = random_sample["dx"]

predicted_dx_sample, confidence_sample = model.full_predict(
    raw_image=sample_image,
    raw_age=sample_age,
    raw_localization=sample_localization,
    device=device,
)

print(f"True Diagnosis: {sample_true_dx}")
print(f"Predicted Diagnosis: {predicted_dx_sample}")
print(f"Confidence: {confidence_sample:.4f}")
print(f"Correct Prediction: {predicted_dx_sample == sample_true_dx}")
```
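
A single sample says little about overall performance. The single-sample check above extends naturally to a batch: the sketch below computes the fraction of samples whose predicted label equals the `dx` field. Like that comparison, it assumes `full_predict` returns labels in the same vocabulary as the dataset's `dx` column; `evaluate_accuracy` is a hypothetical helper, not part of the package.

```python
from typing import Any, Callable, Iterable, Mapping


def evaluate_accuracy(
    predict: Callable[[Mapping[str, Any]], tuple[str, float]],
    samples: Iterable[Mapping[str, Any]],
) -> float:
    """Return the fraction of samples whose predicted diagnosis equals sample["dx"].

    `predict` takes one sample and returns (predicted_dx, confidence),
    mirroring the tuple returned by model.full_predict in the Quick Start.
    """
    total = 0
    correct = 0
    for sample in samples:
        predicted_dx, _confidence = predict(sample)
        correct += int(predicted_dx == sample["dx"])
        total += 1
    if total == 0:
        raise ValueError("evaluate_accuracy needs at least one sample")
    return correct / total


# Usage with the Quick Start objects (model, device, dataset):
# subset = dataset.shuffle(seed=42).select(range(100))
# accuracy = evaluate_accuracy(
#     lambda s: model.full_predict(
#         raw_image=s["image"].convert("RGB"),
#         raw_age=s["age"],
#         raw_localization=s["localization"],
#         device=device,
#     ),
#     subset,
# )
# print(f"Accuracy on {len(subset)} test samples: {accuracy:.4f}")
```

Passing the prediction function as a callable keeps the loop independent of the model, so it can be tested with a stub. A subset of 100 samples gives only a rough estimate, and accuracy alone can be misleading when one diagnosis dominates the test split.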