# medical-agent / dataset_optimiser_with_finetunning.py
import os
import re
import json
import numpy as np
from typing import List, Dict, Any, Optional, Tuple, Union
from dataclasses import dataclass
from pathlib import Path
# Core libraries
import torch
from transformers import (
AutoTokenizer, AutoModel, AutoModelForTokenClassification,
TrainingArguments, Trainer, pipeline, DataCollatorForTokenClassification
)
from torch.utils.data import Dataset
import torch.nn.functional as F
# Vector database
import chromadb
from chromadb.config import Settings
# Utilities
import logging
from tqdm import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class MedicalEntity:
"""Structure pour les entités médicales extraites par NER"""
exam_types: List[Tuple[str, float]] # (entity, confidence)
specialties: List[Tuple[str, float]]
anatomical_regions: List[Tuple[str, float]]
pathologies: List[Tuple[str, float]]
medical_procedures: List[Tuple[str, float]]
measurements: List[Tuple[str, float]]
medications: List[Tuple[str, float]]
symptoms: List[Tuple[str, float]]
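# Illustrative sketch (not part of the original pipeline): what a populated
# MedicalEntity could look like for a short French vascular report. All values
# below are hypothetical and only document the (entity, confidence) layout.
def _example_medical_entity() -> MedicalEntity:
    return MedicalEntity(
        exam_types=[("echo-doppler", 0.94)],
        specialties=[("angiologie", 0.88)],
        anatomical_regions=[("membre inférieur droit", 0.91)],
        pathologies=[("incontinence veineuse", 0.86)],
        medical_procedures=[],
        measurements=[("reflux > 1 s", 0.79)],
        medications=[],
        symptoms=[("œdème", 0.83)],
    )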
class MedicalNERDataset(Dataset):
"""Dataset personnalisé pour l'entraînement NER médical"""
def __init__(self, texts, labels, tokenizer, max_length=512):
self.texts = texts
self.labels = labels
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
text = self.texts[idx]
labels = self.labels[idx]
        # Tokenization
encoding = self.tokenizer(
text,
truncation=True,
padding='max_length',
max_length=self.max_length,
return_offsets_mapping=True,
return_tensors='pt'
)
        # Align the labels with the tokens
aligned_labels = self._align_labels_with_tokens(
labels, encoding.offset_mapping.squeeze().tolist()
)
return {
'input_ids': encoding.input_ids.flatten(),
'attention_mask': encoding.attention_mask.flatten(),
'labels': torch.tensor(aligned_labels, dtype=torch.long)
}
def _align_labels_with_tokens(self, labels, offset_mapping):
"""Aligne les labels BIO avec les tokens du tokenizer"""
aligned_labels = []
label_idx = 0
for start, end in offset_mapping:
            if start == 0 and end == 0:  # Special tokens [CLS], [SEP], [PAD]
                aligned_labels.append(-100)  # Ignored by the loss
else:
if label_idx < len(labels):
aligned_labels.append(labels[label_idx])
label_idx += 1
else:
aligned_labels.append(0) # O (Outside)
return aligned_labels
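# Minimal sketch of the label-alignment idea used above, on hand-written data
# (hypothetical offsets, no tokenizer involved): special tokens carrying a
# (0, 0) offset are mapped to -100 so the loss ignores them, while every real
# token consumes the next BIO label in order.
def _toy_label_alignment_demo():
    word_labels = [1, 2, 0]  # e.g. B-EXAM_TYPES, I-EXAM_TYPES, O
    offsets = [(0, 0), (0, 4), (5, 12), (13, 18), (0, 0)]  # [CLS], 3 tokens, [SEP]
    aligned, idx = [], 0
    for start, end in offsets:
        if start == 0 and end == 0:
            aligned.append(-100)
        elif idx < len(word_labels):
            aligned.append(word_labels[idx])
            idx += 1
        else:
            aligned.append(0)
    assert aligned == [-100, 1, 2, 0, -100]
    return aligned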
class AdvancedMedicalNER:
"""NER médical avancé basé sur CamemBERT-Bio fine-tuné"""
def __init__(self, model_name: str = "auto", cache_dir: str = "./models_cache"):
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(exist_ok=True)
        # Auto-detect the best available medical NER model
self.model_name = self._select_best_model(model_name)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # BIO labels for medical entities
        self.entity_labels = [
            "O",                                             # Outside
            "B-EXAM_TYPES", "I-EXAM_TYPES",                  # Exam types
            "B-SPECIALTIES", "I-SPECIALTIES",                # Medical specialties
            "B-ANATOMICAL_REGIONS", "I-ANATOMICAL_REGIONS",  # Anatomical regions
            "B-PATHOLOGIES", "I-PATHOLOGIES",                # Pathologies
            "B-PROCEDURES", "I-PROCEDURES",                  # Medical procedures
            "B-MEASUREMENTS", "I-MEASUREMENTS",              # Measurements/values
            "B-MEDICATIONS", "I-MEDICATIONS",                # Medications
            "B-SYMPTOMS", "I-SYMPTOMS"                       # Symptoms
        ]
self.id2label = {i: label for i, label in enumerate(self.entity_labels)}
self.label2id = {label: i for i, label in enumerate(self.entity_labels)}
        # Load the NER model
self._load_ner_model()
def _select_best_model(self, model_name: str) -> str:
"""Sélection automatique du meilleur modèle NER médical"""
if model_name != "auto":
return model_name
        # Candidate models in order of preference
        preferred_models = [
            "almanach/camembert-bio-base",      # French CamemBERT-Bio
            "Dr-BERT/DrBERT-7GB",               # DrBERT (French biomedical)
            "emilyalsentzer/Bio_ClinicalBERT",  # Bio Clinical BERT
            "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
            "dmis-lab/biobert-base-cased-v1.2", # BioBERT
            "camembert-base"                    # Fallback: standard CamemBERT
        ]
        for model in preferred_models:
            try:
                # Availability check: loading the tokenizer confirms the model is reachable
                AutoTokenizer.from_pretrained(model, cache_dir=self.cache_dir)
                logger.info(f"Selected model: {model}")
                return model
            except Exception:
                continue
        # Last-resort fallback
        logger.warning("Falling back to the base camembert-base model")
        return "camembert-base"
def _load_ner_model(self):
"""Charge ou crée le modèle NER fine-tuné"""
fine_tuned_path = self.cache_dir / "medical_ner_model"
if fine_tuned_path.exists():
logger.info("Chargement du modèle NER fine-tuné existant")
self.tokenizer = AutoTokenizer.from_pretrained(fine_tuned_path)
self.ner_model = AutoModelForTokenClassification.from_pretrained(fine_tuned_path)
else:
logger.info("Création d'un nouveau modèle NER médical")
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, cache_dir=self.cache_dir)
# Modèle pour classification de tokens (NER)
self.ner_model = AutoModelForTokenClassification.from_pretrained(
self.model_name,
num_labels=len(self.entity_labels),
id2label=self.id2label,
label2id=self.label2id,
cache_dir=self.cache_dir
)
self.ner_model.to(self.device)
        # NER pipeline
self.ner_pipeline = pipeline(
"token-classification",
model=self.ner_model,
tokenizer=self.tokenizer,
device=0 if torch.cuda.is_available() else -1,
aggregation_strategy="simple"
)
def extract_entities(self, text: str) -> MedicalEntity:
"""Extraction d'entités avec le modèle NER fine-tuné"""
# Prédiction NER
try:
ner_results = self.ner_pipeline(text)
except Exception as e:
logger.error(f"Erreur NER: {e}")
return MedicalEntity([], [], [], [], [], [], [], [])
# Groupement des entités par type
entities = {
"EXAM_TYPES": [],
"SPECIALTIES": [],
"ANATOMICAL_REGIONS": [],
"PATHOLOGIES": [],
"PROCEDURES": [],
"MEASUREMENTS": [],
"MEDICATIONS": [],
"SYMPTOMS": []
}
for result in ner_results:
entity_type = result['entity_group'].replace('B-', '').replace('I-', '')
entity_text = result['word']
confidence = result['score']
            if entity_type in entities and confidence > 0.7:  # Confidence threshold
entities[entity_type].append((entity_text, confidence))
return MedicalEntity(
exam_types=entities["EXAM_TYPES"],
specialties=entities["SPECIALTIES"],
anatomical_regions=entities["ANATOMICAL_REGIONS"],
pathologies=entities["PATHOLOGIES"],
medical_procedures=entities["PROCEDURES"],
measurements=entities["MEASUREMENTS"],
medications=entities["MEDICATIONS"],
symptoms=entities["SYMPTOMS"]
)
def load_dataset(self, dataset_path: str) -> List[Dict]:
"""Charge le dataset depuis le fichier JSON"""
try:
with open(dataset_path, 'r', encoding='utf-8') as f:
# Chaque ligne est un objet JSON séparé
data = []
for line in f:
if line.strip():
data.append(json.loads(line.strip()))
return data
except Exception as e:
logger.error(f"Erreur lors du chargement du dataset: {e}")
return []
"""
def _text_to_bio_labels(self, text: str, entities_dict: Dict[str, List[str]]) -> List[int]:
#Convertit le texte et les entités en labels BI en utilisant les offsets
# Tokenisation du texte
tokens = self.tokenizer.tokenize(text)
labels = [0] * len(tokens) # Initialisation avec "O" (Outside)
# Mapping des types d'entités vers les labels BIO
entity_type_mapping = {
"exam_types": ("B-EXAM_TYPES", "I-EXAM_TYPES"),
"specialties": ("B-SPECIALTIES", "I-SPECIALTIES"),
"anatomical_regions": ("B-ANATOMICAL_REGIONS", "I-ANATOMICAL_REGIONS"),
"pathologies": ("B-PATHOLOGIES", "I-PATHOLOGIES"),
"procedures": ("B-PROCEDURES", "I-PROCEDURES"),
"measurements": ("B-MEASUREMENTS", "I-MEASUREMENTS"),
"medications": ("B-MEDICATIONS", "I-MEDICATIONS"),
"symptoms": ("B-SYMPTOMS", "I-SYMPTOMS")
}
# Attribution des labels pour chaque type d'entité
for entity_type, entity_list in entities_dict.items():
if entity_type in entity_type_mapping and entity_list:
b_label, i_label = entity_type_mapping[entity_type]
b_label_id = self.label2id[b_label]
i_label_id = self.label2id[i_label]
for entity in entity_list:
# Recherche de l'entité dans le texte tokenizé
entity_tokens = self.tokenizer.tokenize(entity.lower())
text_lower = text.lower()
# Recherche de la position de l'entité
start_pos = text_lower.find(entity.lower())
if start_pos != -1:
# Approximation de la position dans les tokens
# (méthode simplifiée - pourrait être améliorée)
char_to_token_ratio = len(tokens) / len(text)
approx_token_start = int(start_pos * char_to_token_ratio)
approx_token_end = min(
len(tokens),
approx_token_start + len(entity_tokens)
)
# Attribution des labels BIO
for i in range(approx_token_start, approx_token_end):
if i < len(labels):
if i == approx_token_start:
labels[i] = b_label_id # B- pour le premier token
else:
labels[i] = i_label_id # I- pour les tokens suivants
return labels
"""
def _text_to_bio_labels(self, text: str, entities_dict: Dict[str, List[str]]) -> List[int]:
"""Convertit le texte et les entités en labels BIO en utilisant les offsets (robuste)"""
# Encodage avec offsets
encoding = self.tokenizer(
text,
return_offsets_mapping=True,
add_special_tokens=False
)
tokens = encoding.tokens()
offsets = encoding["offset_mapping"]
        labels = [self.label2id["O"]] * len(tokens)  # Initialize every token as "O"
        # Map entity types to their BIO labels
entity_type_mapping = {
"exam_types": ("B-EXAM_TYPES", "I-EXAM_TYPES"),
"specialties": ("B-SPECIALTIES", "I-SPECIALTIES"),
"anatomical_regions": ("B-ANATOMICAL_REGIONS", "I-ANATOMICAL_REGIONS"),
"pathologies": ("B-PATHOLOGIES", "I-PATHOLOGIES"),
"procedures": ("B-PROCEDURES", "I-PROCEDURES"),
"measurements": ("B-MEASUREMENTS", "I-MEASUREMENTS"),
"medications": ("B-MEDICATIONS", "I-MEDICATIONS"),
"symptoms": ("B-SYMPTOMS", "I-SYMPTOMS")
}
        # Assign the labels
for entity_type, entity_list in entities_dict.items():
if entity_type in entity_type_mapping and entity_list:
b_label, i_label = entity_type_mapping[entity_type]
b_label_id = self.label2id[b_label]
i_label_id = self.label2id[i_label]
for entity in entity_list:
start_char = text.lower().find(entity.lower())
if start_char == -1:
continue
end_char = start_char + len(entity)
                    # Find every token that overlaps the entity span
entity_token_idxs = [
i for i, (tok_start, tok_end) in enumerate(offsets)
if tok_start < end_char and tok_end > start_char
]
if not entity_token_idxs:
continue
                    # BIO assignment: first overlapping token gets B-, the rest get I-
for j, tok_idx in enumerate(entity_token_idxs):
if j == 0:
labels[tok_idx] = b_label_id
else:
labels[tok_idx] = i_label_id
return labels
def _prepare_training_data(self, templates_data: List[Dict]) -> Dict:
"""Prépare les données d'entraînement pour le NER à partir du dataset"""
if not templates_data:
logger.warning("Aucune donnée de template fournie")
return {'train': MedicalNERDataset([], [], self.tokenizer)}
texts = []
labels = []
logger.info(f"Préparation de {len(templates_data)} échantillons pour l'entraînement")
for sample in tqdm(templates_data, desc="Conversion en format BIO"):
try:
text = sample['text']
entities_dict = sample['labels']
                # Convert to BIO labels
bio_labels = self._text_to_bio_labels(text, entities_dict)
texts.append(text)
labels.append(bio_labels)
except Exception as e:
logger.error(f"Erreur lors du traitement d'un échantillon: {e}")
continue
if not texts:
logger.error("Aucun échantillon valide trouvé pour l'entraînement")
return {'train': MedicalNERDataset([], [], self.tokenizer)}
        # Train/validation split when there is enough data
if len(texts) > 10:
train_texts, val_texts, train_labels, val_labels = train_test_split(
texts, labels, test_size=0.2, random_state=42
)
train_dataset = MedicalNERDataset(train_texts, train_labels, self.tokenizer)
val_dataset = MedicalNERDataset(val_texts, val_labels, self.tokenizer)
logger.info(f"Dataset divisé: {len(train_texts)} train, {len(val_texts)} validation")
return {'train': train_dataset, 'eval': val_dataset}
else:
train_dataset = MedicalNERDataset(texts, labels, self.tokenizer)
logger.info(f"Dataset d'entraînement: {len(texts)} échantillons")
return {'train': train_dataset}
def fine_tune_on_templates(self, templates_data: List[Dict] = None,
dataset_path: str = "dataset.json",
output_dir: str = None,
epochs: int = 3):
"""Fine-tuning du modèle NER sur des templates médicaux"""
if output_dir is None:
output_dir = self.cache_dir / "medical_ner_model"
        # Load the data
        if templates_data is None:
            logger.info(f"Loading dataset from {dataset_path}")
            templates_data = self.load_dataset(dataset_path)
        if not templates_data:
            logger.error("No data available for training")
            return
        logger.info("Starting NER fine-tuning on medical templates")
        # Prepare the training data
datasets = self._prepare_training_data(templates_data)
if len(datasets['train']) == 0:
logger.error("Dataset d'entraînement vide")
return
        # Data collator to handle padding
data_collator = DataCollatorForTokenClassification(
tokenizer=self.tokenizer,
padding=True
)
        # Training configuration
training_args = TrainingArguments(
output_dir=str(output_dir),
num_train_epochs=epochs,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
warmup_steps=500,
weight_decay=0.01,
logging_dir=f"{output_dir}/logs",
logging_steps=50,
save_strategy="epoch",
evaluation_strategy="epoch" if 'eval' in datasets else "no",
load_best_model_at_end=True if 'eval' in datasets else False,
metric_for_best_model="eval_loss" if 'eval' in datasets else None,
greater_is_better=False,
remove_unused_columns=False,
)
        # Metric computation
        def compute_metrics(eval_pred):
            predictions, labels = eval_pred
            predictions = np.argmax(predictions, axis=2)
            # Accuracy computed while ignoring the -100 padding labels
mask = labels != -100
accuracy = (predictions[mask] == labels[mask]).mean()
return {"accuracy": accuracy}
# Trainer
trainer = Trainer(
model=self.ner_model,
args=training_args,
train_dataset=datasets['train'],
eval_dataset=datasets.get('eval'),
tokenizer=self.tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if 'eval' in datasets else None,
)
        # Training
        logger.info("Starting training...")
trainer.train()
        # Save the model and tokenizer
trainer.save_model()
self.tokenizer.save_pretrained(output_dir)
        # Reload the model and the pipeline
self._load_ner_model()
logger.info(f"Fine-tuning terminé, modèle sauvé dans {output_dir}")
# Affichage des métriques finales si évaluation disponible
if 'eval' in datasets:
eval_results = trainer.evaluate()
logger.info(f"Métriques finales: {eval_results}")
class AdvancedMedicalEmbedding:
"""Générateur d'embeddings médicaux avancés avec cross-encoder reranking"""
def __init__(self,
base_model: str = "almanach/camembert-bio-base",
cross_encoder_model: str = "auto"):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.base_model_name = base_model
        # Main embedding model
        self._load_base_model()
        # Cross-encoder for reranking
        self._load_cross_encoder(cross_encoder_model)
def _load_base_model(self):
"""Charge le modèle de base pour les embeddings"""
try:
self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
self.base_model = AutoModel.from_pretrained(self.base_model_name)
self.base_model.to(self.device)
logger.info(f"Modèle de base chargé: {self.base_model_name}")
except Exception as e:
logger.error(f"Erreur chargement modèle de base: {e}")
raise
def _load_cross_encoder(self, model_name: str):
"""Charge le cross-encoder pour reranking"""
if model_name == "auto":
# Sélection automatique du meilleur cross-encoder médical
cross_encoders = [
"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
"emilyalsentzer/Bio_ClinicalBERT",
self.base_model_name # Fallback
]
            for model in cross_encoders:
                try:
                    self.cross_tokenizer = AutoTokenizer.from_pretrained(model)
                    self.cross_model = AutoModel.from_pretrained(model)
                    self.cross_model.to(self.device)
                    logger.info(f"Cross-encoder loaded: {model}")
                    break
                except Exception:
                    continue
else:
self.cross_tokenizer = AutoTokenizer.from_pretrained(model_name)
self.cross_model = AutoModel.from_pretrained(model_name)
self.cross_model.to(self.device)
def generate_embedding(self, text: str, entities: MedicalEntity = None) -> np.ndarray:
"""Génère un embedding enrichi pour un texte médical"""
# Tokenisation
inputs = self.tokenizer(
text,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
).to(self.device)
        # Generate the embedding
with torch.no_grad():
outputs = self.base_model(**inputs)
# Mean pooling
attention_mask = inputs['attention_mask']
token_embeddings = outputs.last_hidden_state
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        # Enrich with NER entities
if entities:
embedding = self._enrich_with_ner_entities(embedding, entities)
return embedding.cpu().numpy().flatten().astype(np.float32)
def _enrich_with_ner_entities(self, base_embedding: torch.Tensor, entities: MedicalEntity) -> torch.Tensor:
"""Enrichit l'embedding avec les entités extraites par NER"""
# Concaténer les entités importantes avec leurs scores de confiance
entity_texts = []
confidence_weights = []
for entity_list in [entities.exam_types, entities.specialties,
entities.anatomical_regions, entities.pathologies]:
for entity_text, confidence in entity_list:
entity_texts.append(entity_text)
confidence_weights.append(confidence)
if not entity_texts:
return base_embedding
        # Generate an embedding for the concatenated entities
entity_text_combined = " [SEP] ".join(entity_texts)
entity_inputs = self.tokenizer(
entity_text_combined,
padding=True,
truncation=True,
max_length=256,
return_tensors="pt"
).to(self.device)
with torch.no_grad():
entity_outputs = self.base_model(**entity_inputs)
entity_embedding = torch.mean(entity_outputs.last_hidden_state, dim=1)
        # Weighted fusion driven by the confidence scores
        avg_confidence = np.mean(confidence_weights) if confidence_weights else 0.5
        fusion_weight = min(0.4, avg_confidence)  # Entities contribute at most 40%
enriched_embedding = (1 - fusion_weight) * base_embedding + fusion_weight * entity_embedding
return enriched_embedding
def cross_encoder_rerank(self,
query: str,
candidates: List[Dict],
top_k: int = 3) -> List[Dict]:
"""Reranking avec cross-encoder pour affiner la sélection"""
if len(candidates) <= top_k:
return candidates
reranked_candidates = []
for candidate in candidates:
            # Build the query-candidate pair
            pair_text = f"{query} [SEP] {candidate['document']}"
            # Tokenize the pair
inputs = self.cross_tokenizer(
pair_text,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
).to(self.device)
            # Cross-encoder similarity score
with torch.no_grad():
outputs = self.cross_model(**inputs)
                # Use the [CLS] token embedding to derive the similarity score
cls_embedding = outputs.last_hidden_state[:, 0, :]
similarity_score = torch.sigmoid(torch.mean(cls_embedding)).item()
candidate_copy = candidate.copy()
candidate_copy['cross_encoder_score'] = similarity_score
candidate_copy['final_score'] = (
0.6 * candidate['similarity_score'] +
0.4 * similarity_score
)
reranked_candidates.append(candidate_copy)
        # Sort by final score
reranked_candidates.sort(key=lambda x: x['final_score'], reverse=True)
return reranked_candidates[:top_k]
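# Small sketch of the score fusion applied in cross_encoder_rerank above: the
# final ranking score blends the initial vector similarity (60%) with the
# cross-encoder score (40%), matching the constants used in the method.
def _fuse_scores(vector_similarity: float, cross_encoder_score: float) -> float:
    # e.g. _fuse_scores(0.82, 0.65) == 0.752
    return 0.6 * vector_similarity + 0.4 * cross_encoder_score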
class MedicalTemplateVectorDB:
"""Base de données vectorielle optimisée pour templates médicaux"""
def __init__(self, db_path: str = "./medical_vector_db", collection_name: str = "medical_templates"):
self.db_path = db_path
self.collection_name = collection_name
        # ChromaDB with an optimised configuration
self.client = chromadb.PersistentClient(
path=db_path,
settings=Settings(
anonymized_telemetry=False,
allow_reset=True
)
)
        # Collection with an optimised distance metric
        try:
            self.collection = self.client.get_collection(collection_name)
            logger.info(f"Collection '{collection_name}' loaded")
        except Exception:
            self.collection = self.client.create_collection(
                name=collection_name,
                metadata={
                    "hnsw:space": "cosine",
                    "hnsw:M": 32,                 # Graph connectivity
                    "hnsw:ef_construction": 200,  # Build quality vs. speed
                    "hnsw:ef_search": 50          # Search quality vs. speed
                }
            )
            logger.info(f"Collection '{collection_name}' created with HNSW optimisations")
def add_template(self,
template_id: str,
template_text: str,
embedding: np.ndarray,
entities: MedicalEntity,
metadata: Dict[str, Any] = None):
"""Ajoute un template avec métadonnées enrichies par NER"""
# Métadonnées automatiques basées sur NER
auto_metadata = {
"exam_types": [entity[0] for entity in entities.exam_types],
"specialties": [entity[0] for entity in entities.specialties],
"anatomical_regions": [entity[0] for entity in entities.anatomical_regions],
"pathologies": [entity[0] for entity in entities.pathologies],
"procedures": [entity[0] for entity in entities.medical_procedures],
"text_length": len(template_text),
"entity_confidence_avg": np.mean([
entity[1] for entity_list in [
entities.exam_types, entities.specialties,
entities.anatomical_regions, entities.pathologies
] for entity in entity_list
]) if any([entities.exam_types, entities.specialties,
entities.anatomical_regions, entities.pathologies]) else 0.0
}
if metadata:
auto_metadata.update(metadata)
self.collection.add(
embeddings=[embedding.tolist()],
documents=[template_text],
metadatas=[auto_metadata],
ids=[template_id]
)
logger.info(f"Template {template_id} ajouté avec métadonnées NER automatiques")
def advanced_search(self,
query_embedding: np.ndarray,
n_results: int = 10,
entity_filters: Dict[str, List[str]] = None,
confidence_threshold: float = 0.0) -> List[Dict]:
"""Recherche avancée avec filtres basés sur entités NER"""
where_clause = {}
# Filtres basés sur entités NER extraites
if entity_filters:
for entity_type, entity_values in entity_filters.items():
if entity_values:
where_clause[entity_type] = {"$in": entity_values}
# Filtre par confiance moyenne des entités
if confidence_threshold > 0:
where_clause["entity_confidence_avg"] = {"$gte": confidence_threshold}
results = self.collection.query(
query_embeddings=[query_embedding.tolist()],
n_results=n_results,
where=where_clause if where_clause else None,
include=["documents", "metadatas", "distances"]
)
        # Format the results
formatted_results = []
for i in range(len(results['ids'][0])):
formatted_results.append({
'id': results['ids'][0][i],
'document': results['documents'][0][i],
'metadata': results['metadatas'][0][i],
'similarity_score': 1 - results['distances'][0][i],
'distance': results['distances'][0][i]
})
return formatted_results
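# Usage sketch with hypothetical values: querying the vector DB with
# NER-derived filters. The embedding would normally come from
# AdvancedMedicalEmbedding; a random 768-dimensional vector (CamemBERT hidden
# size) only illustrates the call signature here.
def _example_vector_db_query():
    db = MedicalTemplateVectorDB(db_path="./medical_vector_db")
    fake_embedding = np.random.rand(768).astype(np.float32)
    return db.advanced_search(
        query_embedding=fake_embedding,
        n_results=5,
        entity_filters={"exam_types": ["echo-doppler"]},
        confidence_threshold=0.6,
    )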
class AdvancedMedicalTemplateProcessor:
"""Processeur avancé avec NER fine-tuné et reranking cross-encoder"""
def __init__(self,
base_model: str = "almanach/camembert-bio-base",
db_path: str = "./advanced_medical_vector_db"):
self.ner_extractor = AdvancedMedicalNER()
self.embedding_generator = AdvancedMedicalEmbedding(base_model)
self.vector_db = MedicalTemplateVectorDB(db_path)
logger.info("Processeur médical avancé initialisé avec NER fine-tuné et cross-encoder reranking")
def process_templates_batch(self,
templates: List[Dict[str, str]] = None,
dataset_path: str = "dataset.json",
batch_size: int = 8,
fine_tune_ner: bool = False) -> None:
"""Traitement avancé avec option de fine-tuning NER"""
# Chargement des données si templates non fournis
if templates is None:
logger.info(f"Chargement des templates depuis {dataset_path}")
templates = self.ner_extractor.load_dataset(dataset_path)
# Conversion du format dataset vers le format attendu
templates = [
{
'id': f"template_{i:04d}",
'text': template['text'],
'metadata': {'labels': template.get('labels', {})}
}
for i, template in enumerate(templates)
]
if fine_tune_ner:
logger.info("Fine-tuning du modèle NER sur les templates...")
# Reconversion pour le fine-tuning
training_data = [
{
'text': template['text'],
'labels': template['metadata'].get('labels', {})
}
for template in templates
]
self.ner_extractor.fine_tune_on_templates(training_data)
logger.info(f"Traitement avancé de {len(templates)} templates")
for i in tqdm(range(0, len(templates), batch_size), desc="Traitement avancé"):
batch = templates[i:i+batch_size]
for template in batch:
try:
template_id = template['id']
template_text = template['text']
metadata = template.get('metadata', {})
                    # Advanced NER
                    entities = self.ner_extractor.extract_entities(template_text)
                    # Enriched embedding
                    embedding = self.embedding_generator.generate_embedding(template_text, entities)
                    # Store with NER metadata
self.vector_db.add_template(
template_id=template_id,
template_text=template_text,
embedding=embedding,
entities=entities,
metadata=metadata
)
except Exception as e:
logger.error(f"Erreur traitement template {template.get('id', 'unknown')}: {e}")
continue
def find_best_template_with_reranking(self,
transcription: str,
initial_candidates: int = 10,
final_results: int = 3) -> List[Dict]:
"""Recherche optimale avec reranking cross-encoder"""
# 1. Extraction NER de la transcription
query_entities = self.ner_extractor.extract_entities(transcription)
# 2. Génération embedding enrichi
query_embedding = self.embedding_generator.generate_embedding(transcription, query_entities)
# 3. Filtres automatiques basés sur entités extraites
entity_filters = {}
if query_entities.exam_types:
entity_filters['exam_types'] = [entity[0] for entity in query_entities.exam_types]
if query_entities.specialties:
entity_filters['specialties'] = [entity[0] for entity in query_entities.specialties]
if query_entities.anatomical_regions:
entity_filters['anatomical_regions'] = [entity[0] for entity in query_entities.anatomical_regions]
        # 4. Initial vector search
initial_candidates_results = self.vector_db.advanced_search(
query_embedding=query_embedding,
n_results=initial_candidates,
entity_filters=entity_filters,
confidence_threshold=0.6
)
        # 5. Cross-encoder reranking
if len(initial_candidates_results) > final_results:
final_results_reranked = self.embedding_generator.cross_encoder_rerank(
query=transcription,
candidates=initial_candidates_results,
top_k=final_results
)
else:
final_results_reranked = initial_candidates_results
        # 6. Enrich the results with the query's NER details
for result in final_results_reranked:
result['query_entities'] = {
'exam_types': query_entities.exam_types,
'specialties': query_entities.specialties,
'anatomical_regions': query_entities.anatomical_regions,
'pathologies': query_entities.pathologies
}
return final_results_reranked
def evaluate_ner_performance(self, test_dataset_path: str = None) -> Dict[str, float]:
"""Évalue les performances du modèle NER fine-tuné"""
if test_dataset_path is None:
logger.warning("Aucun dataset de test fourni pour l'évaluation")
return {}
test_data = self.ner_extractor.load_dataset(test_dataset_path)
if not test_data:
logger.error("Dataset de test vide")
return {}
correct_predictions = 0
total_entities = 0
entity_type_stats = {}
        for sample in tqdm(test_data, desc="NER evaluation"):
text = sample['text']
true_entities = sample['labels']
            # Prediction
            predicted_entities = self.ner_extractor.extract_entities(text)
            # Convert into a comparable format
predicted_dict = {
'exam_types': [entity[0].lower() for entity in predicted_entities.exam_types],
'specialties': [entity[0].lower() for entity in predicted_entities.specialties],
'anatomical_regions': [entity[0].lower() for entity in predicted_entities.anatomical_regions],
'pathologies': [entity[0].lower() for entity in predicted_entities.pathologies],
'procedures': [entity[0].lower() for entity in predicted_entities.medical_procedures],
'measurements': [entity[0].lower() for entity in predicted_entities.measurements],
'medications': [entity[0].lower() for entity in predicted_entities.medications],
'symptoms': [entity[0].lower() for entity in predicted_entities.symptoms]
}
            # Comparison
for entity_type, true_entities_list in true_entities.items():
if entity_type in predicted_dict:
predicted_entities_list = predicted_dict[entity_type]
                    # Per-entity-type statistics
if entity_type not in entity_type_stats:
entity_type_stats[entity_type] = {'correct': 0, 'total': 0}
true_entities_lower = [entity.lower() for entity in true_entities_list]
for true_entity in true_entities_lower:
total_entities += 1
entity_type_stats[entity_type]['total'] += 1
if true_entity in predicted_entities_list:
correct_predictions += 1
entity_type_stats[entity_type]['correct'] += 1
        # Compute the metrics
overall_accuracy = correct_predictions / total_entities if total_entities > 0 else 0
metrics = {
'overall_accuracy': overall_accuracy,
'total_entities': total_entities,
'correct_predictions': correct_predictions
}
        # Per-entity-type metrics
for entity_type, stats in entity_type_stats.items():
if stats['total'] > 0:
accuracy = stats['correct'] / stats['total']
metrics[f'{entity_type}_accuracy'] = accuracy
metrics[f'{entity_type}_total'] = stats['total']
logger.info(f"Évaluation NER terminée - Accuracy globale: {overall_accuracy:.4f}")
return metrics
def export_processed_templates(self, output_path: str = "processed_templates.json"):
"""Exporte les templates traités avec leurs embeddings et entités"""
try:
# Récupération de tous les templates de la base vectorielle
all_results = self.vector_db.collection.get(
include=["documents", "metadatas", "embeddings"]
)
processed_templates = []
for i in range(len(all_results['ids'])):
template_data = {
'id': all_results['ids'][i],
'text': all_results['documents'][i],
'metadata': all_results['metadatas'][i],
                    'embedding': (np.asarray(all_results['embeddings'][i]).tolist()
                                  if all_results.get('embeddings') is not None else None)
}
processed_templates.append(template_data)
            # Save to disk
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(processed_templates, f, ensure_ascii=False, indent=2)
logger.info(f"Templates traités exportés vers {output_path}")
logger.info(f"Nombre de templates exportés: {len(processed_templates)}")
except Exception as e:
logger.error(f"Erreur lors de l'export: {e}")
# Utilities for analysis and debugging
class MedicalNERAnalyzer:
"""Outils d'analyse et de debugging pour le système NER médical"""
def __init__(self, processor: AdvancedMedicalTemplateProcessor):
self.processor = processor
def analyze_text(self, text: str) -> Dict:
"""Analyse complète d'un texte médical"""
# Extraction NER
entities = self.processor.ner_extractor.extract_entities(text)
# Génération d'embedding
embedding = self.processor.embedding_generator.generate_embedding(text, entities)
        # Statistics
analysis = {
'text': text,
'text_length': len(text),
'entities': {
'exam_types': entities.exam_types,
'specialties': entities.specialties,
'anatomical_regions': entities.anatomical_regions,
'pathologies': entities.pathologies,
'procedures': entities.medical_procedures,
'measurements': entities.measurements,
'medications': entities.medications,
'symptoms': entities.symptoms
},
'embedding_shape': embedding.shape,
'entity_count_total': sum([
len(entities.exam_types),
len(entities.specialties),
len(entities.anatomical_regions),
len(entities.pathologies),
len(entities.medical_procedures),
len(entities.measurements),
len(entities.medications),
len(entities.symptoms)
]),
'confidence_scores': {
'exam_types': [conf for _, conf in entities.exam_types],
'specialties': [conf for _, conf in entities.specialties],
'anatomical_regions': [conf for _, conf in entities.anatomical_regions],
'pathologies': [conf for _, conf in entities.pathologies]
}
}
return analysis
def compare_entities(self, text1: str, text2: str) -> Dict:
"""Compare les entités extraites de deux textes"""
entities1 = self.processor.ner_extractor.extract_entities(text1)
entities2 = self.processor.ner_extractor.extract_entities(text2)
def entities_to_set(entities):
all_entities = set()
for entity_list in [entities.exam_types, entities.specialties,
entities.anatomical_regions, entities.pathologies]:
for entity, _ in entity_list:
all_entities.add(entity.lower())
return all_entities
set1 = entities_to_set(entities1)
set2 = entities_to_set(entities2)
return {
'text1_entities': list(set1),
'text2_entities': list(set2),
'common_entities': list(set1.intersection(set2)),
'unique_to_text1': list(set1.difference(set2)),
'unique_to_text2': list(set2.difference(set1)),
'similarity_ratio': len(set1.intersection(set2)) / len(set1.union(set2)) if set1.union(set2) else 0
}
def generate_entity_report(self, dataset_path: str) -> Dict:
"""Génère un rapport statistique sur les entités du dataset"""
dataset = self.processor.ner_extractor.load_dataset(dataset_path)
entity_stats = {
'exam_types': {},
'specialties': {},
'anatomical_regions': {},
'pathologies': {},
'procedures': {},
'measurements': {},
'medications': {},
'symptoms': {}
}
total_samples = len(dataset)
        for sample in tqdm(dataset, desc="Dataset analysis"):
labels = sample.get('labels', {})
for entity_type, entities in labels.items():
if entity_type in entity_stats:
for entity in entities:
entity_lower = entity.lower()
if entity_lower not in entity_stats[entity_type]:
entity_stats[entity_type][entity_lower] = 0
entity_stats[entity_type][entity_lower] += 1
        # Build the report
report = {
'total_samples': total_samples,
'entity_statistics': {}
}
for entity_type, entity_counts in entity_stats.items():
if entity_counts:
sorted_entities = sorted(entity_counts.items(), key=lambda x: x[1], reverse=True)
report['entity_statistics'][entity_type] = {
'unique_count': len(entity_counts),
'total_occurrences': sum(entity_counts.values()),
'top_10': sorted_entities[:10],
'average_occurrences': sum(entity_counts.values()) / len(entity_counts)
}
return report
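# Sketch of the overlap measure used by compare_entities above: the
# similarity_ratio is a plain Jaccard index over the two entity sets.
def _jaccard(entities_a: set, entities_b: set) -> float:
    # e.g. _jaccard({"oedeme", "gonalgie"}, {"oedeme"}) == 0.5
    union = entities_a | entities_b
    return len(entities_a & entities_b) / len(union) if union else 0.0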
# Advanced usage example
def main():
"""Exemple d'utilisation du système avancé avec fine-tuning"""
# Initialisation du processeur avancé
processor = AdvancedMedicalTemplateProcessor()
    # 1. Process the templates with NER fine-tuning
    print("=== STEP 1: Processing and fine-tuning ===")
    processor.process_templates_batch(
        dataset_path="dataset.json",
        fine_tune_ner=True,  # Enable NER fine-tuning
batch_size=8
)
    # 2. Evaluate NER performance (optional, requires a test dataset)
    print("\n=== STEP 2: Performance evaluation ===")
    # metrics = processor.evaluate_ner_performance("test_dataset.json")
    # print(f"Evaluation metrics: {metrics}")
    # 3. Analyse a medical text
    print("\n=== STEP 3: Text analysis ===")
analyzer = MedicalNERAnalyzer(processor)
test_text = """madame bacon nicole bilan œdème droit gonalgies ostéophytes
incontinence veineuse modérée portions surale droite crurale gauche saphéniennes"""
analysis = analyzer.analyze_text(test_text)
print(f"Analyse du texte:")
print(f"- Nombre total d'entités: {analysis['entity_count_total']}")
print(f"- Types d'examens détectés: {analysis['entities']['exam_types']}")
print(f"- Régions anatomiques: {analysis['entities']['anatomical_regions']}")
print(f"- Pathologies: {analysis['entities']['pathologies']}")
    # 4. Search with reranking
    print("\n=== STEP 4: Search with reranking ===")
best_matches = processor.find_best_template_with_reranking(
transcription=test_text,
initial_candidates=15,
final_results=3
)
    # Display the results
    for i, match in enumerate(best_matches):
        print(f"\n--- Match {i+1} ---")
        print(f"Template ID: {match['id']}")
        print(f"Final score: {match.get('final_score', match['similarity_score']):.4f}")
        print(f"Cross-encoder score: {match.get('cross_encoder_score', 'N/A')}")
        print(f"Text excerpt: {match['document'][:200]}...")
        # Display the entities detected in the query
query_entities = match.get('query_entities', {})
for entity_type, entities in query_entities.items():
if entities:
print(f" - {entity_type}: {[f'{e[0]} ({e[1]:.2f})' for e in entities[:3]]}")
    # 5. Export the processed templates
    print("\n=== STEP 5: Export results ===")
processor.export_processed_templates("processed_medical_templates.json")
    # 6. Generate a dataset report
    print("\n=== STEP 6: Dataset report ===")
    report = analyzer.generate_entity_report("dataset.json")
    print(f"Report generated for {report['total_samples']} samples")
for entity_type, stats in report['entity_statistics'].items():
if stats['unique_count'] > 0:
print(f"\n{entity_type.upper()}:")
print(f" - Entités uniques: {stats['unique_count']}")
print(f" - Occurrences totales: {stats['total_occurrences']}")
print(f" - Top 3: {stats['top_10'][:3]}")
if __name__ == "__main__":
main()