File size: 2,030 Bytes
1206896 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
"""
Simple example of using Sybil for lung cancer risk prediction
"""
import sys
import os
# Install required packages if needed
# !pip install torch torchvision transformers pydicom torchio sybil
# Import the model
from modeling_sybil_wrapper import SybilHFWrapper
from configuration_sybil import SybilConfig
def _to_numpy(values):
    """Convert the model's ``risk_scores`` to a NumPy array.

    ``Tensor.numpy()`` raises for tensors that require grad or that live on a
    non-CPU device, so ``detach()`` / ``cpu()`` are applied first when the
    object provides them.
    """
    if hasattr(values, "detach"):
        values = values.detach()
    if hasattr(values, "cpu"):
        values = values.cpu()
    return values.numpy()


def predict_cancer_risk(dicom_paths, model=None, model_name="Lab-Rasool/sybil"):
    """
    Predict lung cancer risk from DICOM files.

    Args:
        dicom_paths: Iterable of paths to the DICOM files (slices) of one CT
            scan. It is materialized into a list before being passed to the
            model.
        model: Optional already-loaded model (e.g. a ``SybilHFWrapper``). It
            must be callable as ``model(dicom_paths=..., return_attentions=...)``
            and return an object with a ``risk_scores`` attribute. When
            ``None``, the model is loaded from ``model_name`` on this call.
            Pass a loaded model to avoid reloading it for every scan.
        model_name: Hugging Face repo id given to
            ``SybilHFWrapper.from_pretrained`` when ``model`` is ``None``.

    Returns:
        The model's ``risk_scores`` as a NumPy array (risk scores for years
        1-6). NOTE(review): the exact shape is decided by the wrapper and is
        not visible here; it may carry a leading batch axis.

    Raises:
        ValueError: If ``dicom_paths`` is empty.
    """
    paths = list(dicom_paths)
    if not paths:
        raise ValueError("dicom_paths must contain at least one DICOM file path")

    # Load the model only when the caller did not supply one.
    if model is None:
        print("Loading Sybil model...")
        model = SybilHFWrapper.from_pretrained(model_name)

    # Run prediction
    print(f"Processing {len(paths)} DICOM files...")
    output = model(dicom_paths=paths, return_attentions=False)

    # Extract risk scores
    return _to_numpy(output.risk_scores)
def main(argv=None):
    """
    Command-line entry point for the Sybil risk-prediction example.

    Args:
        argv: Optional list of DICOM file paths for one LDCT scan. When
            ``None``, the paths are read from ``sys.argv[1:]``. When no paths
            are given, only the banner and usage notes are printed.

    Example:
        python <this_script> path/to/slice001.dcm path/to/slice002.dcm ...
    """
    dicom_paths = list(sys.argv[1:] if argv is None else argv)

    print("=" * 50)
    print("Sybil Lung Cancer Risk Prediction")
    print("=" * 50)

    if dicom_paths:
        risk_scores = predict_cancer_risk(dicom_paths)
        # NOTE(review): the output shape is not visible here. Flatten in case
        # the wrapper returns a leading batch axis of size 1 for a single scan.
        if hasattr(risk_scores, "ravel"):
            risk_scores = risk_scores.ravel()

        print("\nLung Cancer Risk Predictions:")
        print("-" * 30)
        for year, score in enumerate(risk_scores, 1):
            risk_pct = score * 100
            print(f"Year {year}: {risk_pct:.2f}% risk")
    else:
        # Demo DICOM data for testing can be downloaded from
        # https://github.com/reginabarzilaygroup/Sybil
        print("\nNote: This example requires actual DICOM files.")
        print("Please provide paths to LDCT scan DICOM files.")
        print("Pass them as command-line arguments, e.g. "
              "python <script> slice001.dcm slice002.dcm ...")

    print("\nFor more information:")
    print("- Original paper: https://doi.org/10.1200/JCO.22.01345")
    print("- GitHub: https://github.com/reginabarzilaygroup/Sybil")
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()