import json
from typing import List, Dict, Any, Tuple
from dataclasses import dataclass
from pathlib import Path

import numpy as np

# Core libraries
import torch
from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer, AutoModel, AutoModelForTokenClassification,
    TrainingArguments, Trainer, pipeline, DataCollatorForTokenClassification
)

# Vector database
import chromadb
from chromadb.config import Settings

# Utilities
import logging
from tqdm import tqdm
from sklearn.model_selection import train_test_split

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass  # required: the class is instantiated with positional/keyword args below
class MedicalEntity:
    """Container for the medical entities extracted by NER."""
    exam_types: List[Tuple[str, float]]  # (entity, confidence)
    specialties: List[Tuple[str, float]]
    anatomical_regions: List[Tuple[str, float]]
    pathologies: List[Tuple[str, float]]
    medical_procedures: List[Tuple[str, float]]
    measurements: List[Tuple[str, float]]
    medications: List[Tuple[str, float]]
    symptoms: List[Tuple[str, float]]
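
# Illustrative only: a populated instance might look like this (entity strings
# and confidence values are hypothetical, not real model output).
#
#   MedicalEntity(
#       exam_types=[("échographie doppler", 0.93)],
#       specialties=[("angiologie", 0.88)],
#       anatomical_regions=[("membre inférieur droit", 0.91)],
#       pathologies=[("incontinence veineuse", 0.85)],
#       medical_procedures=[], measurements=[], medications=[], symptoms=[]
#   )
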
class MedicalNERDataset(Dataset):
    """Custom dataset for medical NER training."""
    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        labels = self.labels[idx]
        # Tokenization
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_offsets_mapping=True,
            return_tensors='pt'
        )
        # Align the labels with the tokens
        aligned_labels = self._align_labels_with_tokens(
            labels, encoding.offset_mapping.squeeze().tolist()
        )
        return {
            'input_ids': encoding.input_ids.flatten(),
            'attention_mask': encoding.attention_mask.flatten(),
            'labels': torch.tensor(aligned_labels, dtype=torch.long)
        }

    def _align_labels_with_tokens(self, labels, offset_mapping):
        """Aligns the BIO labels with the tokenizer's tokens."""
        aligned_labels = []
        label_idx = 0
        for start, end in offset_mapping:
            if start == 0 and end == 0:  # Special token: [CLS], [SEP], [PAD]
                aligned_labels.append(-100)  # Ignored by the loss
            else:
                if label_idx < len(labels):
                    aligned_labels.append(labels[label_idx])
                    label_idx += 1
                else:
                    aligned_labels.append(0)  # O (Outside)
        return aligned_labels
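
# A minimal sketch of the alignment above (offsets are assumed for
# illustration): special tokens get -100 so the cross-entropy loss skips them,
# and each remaining token consumes the next label in order.
#
#   offset_mapping = [(0, 0), (0, 4), (4, 8), (0, 0)]   # [CLS], tok1, tok2, [SEP]
#   labels         = [1, 2]                             # B-EXAM_TYPES, I-EXAM_TYPES
#   aligned        = [-100, 1, 2, -100]
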
class AdvancedMedicalNER:
    """Advanced medical NER based on a fine-tuned CamemBERT-Bio model."""
    def __init__(self, model_name: str = "auto", cache_dir: str = "./models_cache"):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True)
        # Auto-detect the best available medical NER model
        self.model_name = self._select_best_model(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # BIO labels for medical entities
        self.entity_labels = [
            "O",                                             # Outside
            "B-EXAM_TYPES", "I-EXAM_TYPES",                  # Exam types
            "B-SPECIALTIES", "I-SPECIALTIES",                # Medical specialties
            "B-ANATOMICAL_REGIONS", "I-ANATOMICAL_REGIONS",  # Anatomical regions
            "B-PATHOLOGIES", "I-PATHOLOGIES",                # Pathologies
            "B-PROCEDURES", "I-PROCEDURES",                  # Medical procedures
            "B-MEASUREMENTS", "I-MEASUREMENTS",              # Measurements/values
            "B-MEDICATIONS", "I-MEDICATIONS",                # Medications
            "B-SYMPTOMS", "I-SYMPTOMS"                       # Symptoms
        ]
        self.id2label = {i: label for i, label in enumerate(self.entity_labels)}
        self.label2id = {label: i for i, label in enumerate(self.entity_labels)}
        # Load the NER model
        self._load_ner_model()

    def _select_best_model(self, model_name: str) -> str:
        """Automatically selects the best available medical NER model."""
        if model_name != "auto":
            return model_name
        # Candidate models, in order of preference
        preferred_models = [
            "almanach/camembert-bio-base",       # French CamemBERT-Bio
            "Dr-BERT/DrBERT-7GB",                # Specialized DrBERT
            "emilyalsentzer/Bio_ClinicalBERT",   # Bio Clinical BERT
            "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
            "dmis-lab/biobert-base-cased-v1.2",  # BioBERT
            "camembert-base"                     # Fallback: standard CamemBERT
        ]
        for model in preferred_models:
            try:
                # Availability check
                AutoTokenizer.from_pretrained(model, cache_dir=self.cache_dir)
                logger.info(f"Selected model: {model}")
                return model
            except Exception:
                continue
        # Last-resort fallback
        logger.warning("Falling back to the base camembert-base model")
        return "camembert-base"

    def _load_ner_model(self):
        """Loads the fine-tuned NER model, or creates a fresh one."""
        fine_tuned_path = self.cache_dir / "medical_ner_model"
        if fine_tuned_path.exists():
            logger.info("Loading existing fine-tuned NER model")
            self.tokenizer = AutoTokenizer.from_pretrained(fine_tuned_path)
            self.ner_model = AutoModelForTokenClassification.from_pretrained(fine_tuned_path)
        else:
            logger.info("Creating a new medical NER model")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, cache_dir=self.cache_dir)
            # Token-classification (NER) head on top of the base model
            self.ner_model = AutoModelForTokenClassification.from_pretrained(
                self.model_name,
                num_labels=len(self.entity_labels),
                id2label=self.id2label,
                label2id=self.label2id,
                cache_dir=self.cache_dir
            )
        self.ner_model.to(self.device)
        # NER pipeline
        self.ner_pipeline = pipeline(
            "token-classification",
            model=self.ner_model,
            tokenizer=self.tokenizer,
            device=0 if torch.cuda.is_available() else -1,
            aggregation_strategy="simple"
        )
    def extract_entities(self, text: str) -> MedicalEntity:
        """Extracts entities with the fine-tuned NER model."""
        # NER prediction
        try:
            ner_results = self.ner_pipeline(text)
        except Exception as e:
            logger.error(f"NER error: {e}")
            return MedicalEntity([], [], [], [], [], [], [], [])
        # Group the entities by type
        entities = {
            "EXAM_TYPES": [],
            "SPECIALTIES": [],
            "ANATOMICAL_REGIONS": [],
            "PATHOLOGIES": [],
            "PROCEDURES": [],
            "MEASUREMENTS": [],
            "MEDICATIONS": [],
            "SYMPTOMS": []
        }
        for result in ner_results:
            entity_type = result['entity_group'].replace('B-', '').replace('I-', '')
            entity_text = result['word']
            confidence = result['score']
            if entity_type in entities and confidence > 0.7:  # Confidence threshold
                entities[entity_type].append((entity_text, confidence))
        return MedicalEntity(
            exam_types=entities["EXAM_TYPES"],
            specialties=entities["SPECIALTIES"],
            anatomical_regions=entities["ANATOMICAL_REGIONS"],
            pathologies=entities["PATHOLOGIES"],
            medical_procedures=entities["PROCEDURES"],
            measurements=entities["MEASUREMENTS"],
            medications=entities["MEDICATIONS"],
            symptoms=entities["SYMPTOMS"]
        )
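
    # Hedged usage sketch (entity values and scores are hypothetical, and the
    # first run downloads the selected model from the Hugging Face Hub):
    #
    #   ner = AdvancedMedicalNER()
    #   entities = ner.extract_entities(
    #       "échographie doppler veineux du membre inférieur droit"
    #   )
    #   # entities.exam_types         -> e.g. [("échographie doppler", 0.93)]
    #   # entities.anatomical_regions -> e.g. [("membre inférieur droit", 0.90)]
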
    def load_dataset(self, dataset_path: str) -> List[Dict]:
        """Loads the dataset from a JSON Lines file (one JSON object per line)."""
        try:
            with open(dataset_path, 'r', encoding='utf-8') as f:
                data = []
                for line in f:
                    if line.strip():
                        data.append(json.loads(line.strip()))
            return data
        except Exception as e:
            logger.error(f"Error while loading the dataset: {e}")
            return []
| """ | |
| def _text_to_bio_labels(self, text: str, entities_dict: Dict[str, List[str]]) -> List[int]: | |
| #Convertit le texte et les entités en labels BI en utilisant les offsets | |
| # Tokenisation du texte | |
| tokens = self.tokenizer.tokenize(text) | |
| labels = [0] * len(tokens) # Initialisation avec "O" (Outside) | |
| # Mapping des types d'entités vers les labels BIO | |
| entity_type_mapping = { | |
| "exam_types": ("B-EXAM_TYPES", "I-EXAM_TYPES"), | |
| "specialties": ("B-SPECIALTIES", "I-SPECIALTIES"), | |
| "anatomical_regions": ("B-ANATOMICAL_REGIONS", "I-ANATOMICAL_REGIONS"), | |
| "pathologies": ("B-PATHOLOGIES", "I-PATHOLOGIES"), | |
| "procedures": ("B-PROCEDURES", "I-PROCEDURES"), | |
| "measurements": ("B-MEASUREMENTS", "I-MEASUREMENTS"), | |
| "medications": ("B-MEDICATIONS", "I-MEDICATIONS"), | |
| "symptoms": ("B-SYMPTOMS", "I-SYMPTOMS") | |
| } | |
| # Attribution des labels pour chaque type d'entité | |
| for entity_type, entity_list in entities_dict.items(): | |
| if entity_type in entity_type_mapping and entity_list: | |
| b_label, i_label = entity_type_mapping[entity_type] | |
| b_label_id = self.label2id[b_label] | |
| i_label_id = self.label2id[i_label] | |
| for entity in entity_list: | |
| # Recherche de l'entité dans le texte tokenizé | |
| entity_tokens = self.tokenizer.tokenize(entity.lower()) | |
| text_lower = text.lower() | |
| # Recherche de la position de l'entité | |
| start_pos = text_lower.find(entity.lower()) | |
| if start_pos != -1: | |
| # Approximation de la position dans les tokens | |
| # (méthode simplifiée - pourrait être améliorée) | |
| char_to_token_ratio = len(tokens) / len(text) | |
| approx_token_start = int(start_pos * char_to_token_ratio) | |
| approx_token_end = min( | |
| len(tokens), | |
| approx_token_start + len(entity_tokens) | |
| ) | |
| # Attribution des labels BIO | |
| for i in range(approx_token_start, approx_token_end): | |
| if i < len(labels): | |
| if i == approx_token_start: | |
| labels[i] = b_label_id # B- pour le premier token | |
| else: | |
| labels[i] = i_label_id # I- pour les tokens suivants | |
| return labels | |
| """ | |
    def _text_to_bio_labels(self, text: str, entities_dict: Dict[str, List[str]]) -> List[int]:
        """Converts a text and its entities into BIO labels using character offsets (robust)."""
        # Encode with offsets
        encoding = self.tokenizer(
            text,
            return_offsets_mapping=True,
            add_special_tokens=False
        )
        tokens = encoding.tokens()
        offsets = encoding["offset_mapping"]
        labels = [self.label2id["O"]] * len(tokens)  # Initialize with "O"
        # Map entity types to BIO labels
        entity_type_mapping = {
            "exam_types": ("B-EXAM_TYPES", "I-EXAM_TYPES"),
            "specialties": ("B-SPECIALTIES", "I-SPECIALTIES"),
            "anatomical_regions": ("B-ANATOMICAL_REGIONS", "I-ANATOMICAL_REGIONS"),
            "pathologies": ("B-PATHOLOGIES", "I-PATHOLOGIES"),
            "procedures": ("B-PROCEDURES", "I-PROCEDURES"),
            "measurements": ("B-MEASUREMENTS", "I-MEASUREMENTS"),
            "medications": ("B-MEDICATIONS", "I-MEDICATIONS"),
            "symptoms": ("B-SYMPTOMS", "I-SYMPTOMS")
        }
        # Label assignment
        for entity_type, entity_list in entities_dict.items():
            if entity_type in entity_type_mapping and entity_list:
                b_label, i_label = entity_type_mapping[entity_type]
                b_label_id = self.label2id[b_label]
                i_label_id = self.label2id[i_label]
                for entity in entity_list:
                    start_char = text.lower().find(entity.lower())
                    if start_char == -1:
                        continue
                    end_char = start_char + len(entity)
                    # Find every token that overlaps the entity span
                    entity_token_idxs = [
                        i for i, (tok_start, tok_end) in enumerate(offsets)
                        if tok_start < end_char and tok_end > start_char
                    ]
                    if not entity_token_idxs:
                        continue
                    # BIO assignment: B- on the first token, I- on the rest
                    for j, tok_idx in enumerate(entity_token_idxs):
                        labels[tok_idx] = b_label_id if j == 0 else i_label_id
        return labels
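
    # Worked example of the overlap test above (token offsets are assumed for
    # illustration). For text = "doppler veineux" and entity = "doppler":
    #
    #   offsets     = [(0, 3), (3, 7), (8, 15)]   # "dop", "pler", "veineux"
    #   start_char  = 0, end_char = 7
    #   overlapping = tokens 0 and 1 (tok_start < 7 and tok_end > 0)
    #   labels      = [B-EXAM_TYPES, I-EXAM_TYPES, O]
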
    def _prepare_training_data(self, templates_data: List[Dict]) -> Dict:
        """Prepares NER training data from the dataset."""
        if not templates_data:
            logger.warning("No template data provided")
            return {'train': MedicalNERDataset([], [], self.tokenizer)}
        texts = []
        labels = []
        logger.info(f"Preparing {len(templates_data)} samples for training")
        for sample in tqdm(templates_data, desc="Converting to BIO format"):
            try:
                text = sample['text']
                entities_dict = sample['labels']
                # Convert to BIO labels
                bio_labels = self._text_to_bio_labels(text, entities_dict)
                texts.append(text)
                labels.append(bio_labels)
            except Exception as e:
                logger.error(f"Error while processing a sample: {e}")
                continue
        if not texts:
            logger.error("No valid sample found for training")
            return {'train': MedicalNERDataset([], [], self.tokenizer)}
        # Train/validation split when there is enough data
        if len(texts) > 10:
            train_texts, val_texts, train_labels, val_labels = train_test_split(
                texts, labels, test_size=0.2, random_state=42
            )
            train_dataset = MedicalNERDataset(train_texts, train_labels, self.tokenizer)
            val_dataset = MedicalNERDataset(val_texts, val_labels, self.tokenizer)
            logger.info(f"Dataset split: {len(train_texts)} train, {len(val_texts)} validation")
            return {'train': train_dataset, 'eval': val_dataset}
        else:
            train_dataset = MedicalNERDataset(texts, labels, self.tokenizer)
            logger.info(f"Training dataset: {len(texts)} samples")
            return {'train': train_dataset}
    def fine_tune_on_templates(self, templates_data: List[Dict] = None,
                               dataset_path: str = "dataset.json",
                               output_dir: str = None,
                               epochs: int = 3):
        """Fine-tunes the NER model on medical templates."""
        if output_dir is None:
            output_dir = self.cache_dir / "medical_ner_model"
        # Load the data
        if templates_data is None:
            logger.info(f"Loading dataset from {dataset_path}")
            templates_data = self.load_dataset(dataset_path)
        if not templates_data:
            logger.error("No data available for training")
            return
        logger.info("Starting NER fine-tuning on medical templates")
        # Prepare the training data
        datasets = self._prepare_training_data(templates_data)
        if len(datasets['train']) == 0:
            logger.error("Empty training dataset")
            return
        # Data collator to handle padding
        data_collator = DataCollatorForTokenClassification(
            tokenizer=self.tokenizer,
            padding=True
        )
        # Training configuration. Note: `evaluation_strategy` was renamed
        # `eval_strategy` in recent transformers releases; adjust to your version.
        training_args = TrainingArguments(
            output_dir=str(output_dir),
            num_train_epochs=epochs,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir=f"{output_dir}/logs",
            logging_steps=50,
            save_strategy="epoch",
            evaluation_strategy="epoch" if 'eval' in datasets else "no",
            load_best_model_at_end=True if 'eval' in datasets else False,
            metric_for_best_model="eval_loss" if 'eval' in datasets else None,
            greater_is_better=False,
            remove_unused_columns=False,
        )
        # Metric computation
        def compute_metrics(eval_pred):
            predictions, labels = eval_pred
            predictions = np.argmax(predictions, axis=2)
            # Accuracy, ignoring the -100 labels
            mask = labels != -100
            accuracy = (predictions[mask] == labels[mask]).mean()
            return {"accuracy": accuracy}
        # Trainer
        trainer = Trainer(
            model=self.ner_model,
            args=training_args,
            train_dataset=datasets['train'],
            eval_dataset=datasets.get('eval'),
            tokenizer=self.tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics if 'eval' in datasets else None,
        )
        # Training
        logger.info("Starting training...")
        trainer.train()
        # Save
        trainer.save_model()
        self.tokenizer.save_pretrained(output_dir)
        # Reload the model and the pipeline
        self._load_ner_model()
        logger.info(f"Fine-tuning finished, model saved to {output_dir}")
        # Report final metrics when an eval split is available
        if 'eval' in datasets:
            eval_results = trainer.evaluate()
            logger.info(f"Final metrics: {eval_results}")
class AdvancedMedicalEmbedding:
    """Advanced medical embedding generator with cross-encoder reranking."""
    def __init__(self,
                 base_model: str = "almanach/camembert-bio-base",
                 cross_encoder_model: str = "auto"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.base_model_name = base_model
        # Main embedding model
        self._load_base_model()
        # Cross-encoder for reranking
        self._load_cross_encoder(cross_encoder_model)

    def _load_base_model(self):
        """Loads the base embedding model."""
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
            self.base_model = AutoModel.from_pretrained(self.base_model_name)
            self.base_model.to(self.device)
            logger.info(f"Base model loaded: {self.base_model_name}")
        except Exception as e:
            logger.error(f"Error loading base model: {e}")
            raise

    def _load_cross_encoder(self, model_name: str):
        """Loads the cross-encoder used for reranking."""
        if model_name == "auto":
            # Automatically pick the best available medical cross-encoder
            cross_encoders = [
                "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
                "emilyalsentzer/Bio_ClinicalBERT",
                self.base_model_name  # Fallback
            ]
            for model in cross_encoders:
                try:
                    self.cross_tokenizer = AutoTokenizer.from_pretrained(model)
                    self.cross_model = AutoModel.from_pretrained(model)
                    self.cross_model.to(self.device)
                    logger.info(f"Cross-encoder loaded: {model}")
                    break
                except Exception:
                    continue
        else:
            self.cross_tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.cross_model = AutoModel.from_pretrained(model_name)
            self.cross_model.to(self.device)
    def generate_embedding(self, text: str, entities: MedicalEntity = None) -> np.ndarray:
        """Generates an enriched embedding for a medical text."""
        # Tokenization
        inputs = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        ).to(self.device)
        # Embedding generation
        with torch.no_grad():
            outputs = self.base_model(**inputs)
            # Mean pooling over non-padding tokens
            attention_mask = inputs['attention_mask']
            token_embeddings = outputs.last_hidden_state
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        # Enrichment with NER entities
        if entities:
            embedding = self._enrich_with_ner_entities(embedding, entities)
        return embedding.cpu().numpy().flatten().astype(np.float32)

    def _enrich_with_ner_entities(self, base_embedding: torch.Tensor, entities: MedicalEntity) -> torch.Tensor:
        """Enriches the embedding with the entities extracted by NER."""
        # Collect the important entities together with their confidence scores
        entity_texts = []
        confidence_weights = []
        for entity_list in [entities.exam_types, entities.specialties,
                            entities.anatomical_regions, entities.pathologies]:
            for entity_text, confidence in entity_list:
                entity_texts.append(entity_text)
                confidence_weights.append(confidence)
        if not entity_texts:
            return base_embedding
        # Embed the concatenated entities
        entity_text_combined = " [SEP] ".join(entity_texts)
        entity_inputs = self.tokenizer(
            entity_text_combined,
            padding=True,
            truncation=True,
            max_length=256,
            return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            entity_outputs = self.base_model(**entity_inputs)
            entity_embedding = torch.mean(entity_outputs.last_hidden_state, dim=1)
        # Weighted fusion driven by the confidence scores
        avg_confidence = np.mean(confidence_weights) if confidence_weights else 0.5
        fusion_weight = min(0.4, avg_confidence)  # Entities contribute at most 40%
        enriched_embedding = (1 - fusion_weight) * base_embedding + fusion_weight * entity_embedding
        return enriched_embedding
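
    # Fusion arithmetic, concretely: with avg_confidence = 0.8 the weight is
    # capped at 0.4, giving 0.6 * base + 0.4 * entity_embedding; with
    # avg_confidence = 0.3 it is 0.7 * base + 0.3 * entity_embedding.
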
    def cross_encoder_rerank(self,
                             query: str,
                             candidates: List[Dict],
                             top_k: int = 3) -> List[Dict]:
        """Reranks candidates with the cross-encoder to refine the selection."""
        if len(candidates) <= top_k:
            return candidates
        reranked_candidates = []
        for candidate in candidates:
            # Build the query-candidate pair
            pair_text = f"{query} [SEP] {candidate['document']}"
            # Tokenization
            inputs = self.cross_tokenizer(
                pair_text,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt"
            ).to(self.device)
            # Cross-encoder similarity score
            with torch.no_grad():
                outputs = self.cross_model(**inputs)
                # Crude similarity proxy: sigmoid of the mean [CLS] activation
                # (the loaded model has no trained scoring head)
                cls_embedding = outputs.last_hidden_state[:, 0, :]
                similarity_score = torch.sigmoid(torch.mean(cls_embedding)).item()
            candidate_copy = candidate.copy()
            candidate_copy['cross_encoder_score'] = similarity_score
            candidate_copy['final_score'] = (
                0.6 * candidate['similarity_score'] +
                0.4 * similarity_score
            )
            reranked_candidates.append(candidate_copy)
        # Sort by final score
        reranked_candidates.sort(key=lambda x: x['final_score'], reverse=True)
        return reranked_candidates[:top_k]
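
    # Score blending, concretely: a candidate with similarity_score = 0.80 and
    # cross_encoder_score = 0.50 gets final_score = 0.6 * 0.80 + 0.4 * 0.50 = 0.68.
    # The 0.6 / 0.4 split is a tunable heuristic, not a calibrated weighting.
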
class MedicalTemplateVectorDB:
    """Vector database optimized for medical templates."""
    def __init__(self, db_path: str = "./medical_vector_db", collection_name: str = "medical_templates"):
        self.db_path = db_path
        self.collection_name = collection_name
        # ChromaDB with an optimized configuration
        self.client = chromadb.PersistentClient(
            path=db_path,
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True
            )
        )
        # Collection with an optimized distance metric
        try:
            self.collection = self.client.get_collection(collection_name)
            logger.info(f"Collection '{collection_name}' loaded")
        except Exception:
            self.collection = self.client.create_collection(
                name=collection_name,
                metadata={
                    "hnsw:space": "cosine",
                    "hnsw:M": 32,                 # Graph connectivity
                    "hnsw:ef_construction": 200,  # Build quality vs speed
                    "hnsw:ef_search": 50          # Search quality vs speed
                }
            )
            logger.info(f"Collection '{collection_name}' created with HNSW optimizations")
    def add_template(self,
                     template_id: str,
                     template_text: str,
                     embedding: np.ndarray,
                     entities: MedicalEntity,
                     metadata: Dict[str, Any] = None):
        """Adds a template with metadata enriched by NER."""
        # Automatic metadata derived from NER. ChromaDB only accepts scalar
        # metadata values (str, int, float, bool), so entity lists are stored
        # as comma-joined strings.
        all_confidences = [
            confidence for entity_list in [
                entities.exam_types, entities.specialties,
                entities.anatomical_regions, entities.pathologies
            ] for _, confidence in entity_list
        ]
        auto_metadata = {
            "exam_types": ", ".join(entity for entity, _ in entities.exam_types),
            "specialties": ", ".join(entity for entity, _ in entities.specialties),
            "anatomical_regions": ", ".join(entity for entity, _ in entities.anatomical_regions),
            "pathologies": ", ".join(entity for entity, _ in entities.pathologies),
            "procedures": ", ".join(entity for entity, _ in entities.medical_procedures),
            "text_length": len(template_text),
            "entity_confidence_avg": float(np.mean(all_confidences)) if all_confidences else 0.0
        }
        if metadata:
            for key, value in metadata.items():
                # Non-scalar caller metadata is JSON-serialized for the same reason
                auto_metadata[key] = value if isinstance(value, (str, int, float, bool)) \
                    else json.dumps(value, ensure_ascii=False)
        self.collection.add(
            embeddings=[embedding.tolist()],
            documents=[template_text],
            metadatas=[auto_metadata],
            ids=[template_id]
        )
        logger.info(f"Template {template_id} added with automatic NER metadata")
    def advanced_search(self,
                        query_embedding: np.ndarray,
                        n_results: int = 10,
                        entity_filters: Dict[str, List[str]] = None,
                        confidence_threshold: float = 0.0) -> List[Dict]:
        """Advanced search with filters based on NER entities."""
        # Filter on the average entity confidence (a scalar, so it can go
        # directly into the Chroma where clause)
        where_clause = {}
        if confidence_threshold > 0:
            where_clause["entity_confidence_avg"] = {"$gte": confidence_threshold}
        results = self.collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=n_results,
            where=where_clause if where_clause else None,
            include=["documents", "metadatas", "distances"]
        )
        # Format the results
        formatted_results = []
        for i in range(len(results['ids'][0])):
            formatted_results.append({
                'id': results['ids'][0][i],
                'document': results['documents'][0][i],
                'metadata': results['metadatas'][0][i],
                'similarity_score': 1 - results['distances'][0][i],
                'distance': results['distances'][0][i]
            })
        # Entity metadata is stored as comma-joined strings (see add_template),
        # so entity filters are applied here as case-insensitive substring
        # matches instead of Chroma's $in operator, which only compares scalars.
        if entity_filters:
            def matches(metadata: Dict[str, Any]) -> bool:
                for entity_type, entity_values in entity_filters.items():
                    if not entity_values:
                        continue
                    field = str(metadata.get(entity_type, "")).lower()
                    if not any(value.lower() in field for value in entity_values):
                        return False
                return True
            formatted_results = [r for r in formatted_results if matches(r['metadata'])]
        return formatted_results
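
# Hedged usage sketch (IDs, texts, and filter values are hypothetical):
#
#   db = MedicalTemplateVectorDB()
#   db.add_template("template_0001", "Echo-doppler veineux ...",
#                   embedding, entities, metadata={"source": "demo"})
#   hits = db.advanced_search(query_embedding, n_results=5,
#                             entity_filters={"exam_types": ["doppler"]})
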
class AdvancedMedicalTemplateProcessor:
    """Advanced processor with fine-tuned NER and cross-encoder reranking."""
    def __init__(self,
                 base_model: str = "almanach/camembert-bio-base",
                 db_path: str = "./advanced_medical_vector_db"):
        self.ner_extractor = AdvancedMedicalNER()
        self.embedding_generator = AdvancedMedicalEmbedding(base_model)
        self.vector_db = MedicalTemplateVectorDB(db_path)
        logger.info("Advanced medical processor initialized with fine-tuned NER and cross-encoder reranking")

    def process_templates_batch(self,
                                templates: List[Dict[str, str]] = None,
                                dataset_path: str = "dataset.json",
                                batch_size: int = 8,
                                fine_tune_ner: bool = False) -> None:
        """Advanced processing with an optional NER fine-tuning pass."""
        # Load the data when no templates are provided
        if templates is None:
            logger.info(f"Loading templates from {dataset_path}")
            templates = self.ner_extractor.load_dataset(dataset_path)
            # Convert from the dataset format to the expected format
            templates = [
                {
                    'id': f"template_{i:04d}",
                    'text': template['text'],
                    'metadata': {'labels': template.get('labels', {})}
                }
                for i, template in enumerate(templates)
            ]
        if fine_tune_ner:
            logger.info("Fine-tuning the NER model on the templates...")
            # Convert back to the fine-tuning format
            training_data = [
                {
                    'text': template['text'],
                    'labels': template['metadata'].get('labels', {})
                }
                for template in templates
            ]
            self.ner_extractor.fine_tune_on_templates(training_data)
        logger.info(f"Advanced processing of {len(templates)} templates")
        for i in tqdm(range(0, len(templates), batch_size), desc="Advanced processing"):
            batch = templates[i:i+batch_size]
            for template in batch:
                try:
                    template_id = template['id']
                    template_text = template['text']
                    metadata = template.get('metadata', {})
                    # Advanced NER
                    entities = self.ner_extractor.extract_entities(template_text)
                    # Enriched embedding
                    embedding = self.embedding_generator.generate_embedding(template_text, entities)
                    # Store with NER metadata
                    self.vector_db.add_template(
                        template_id=template_id,
                        template_text=template_text,
                        embedding=embedding,
                        entities=entities,
                        metadata=metadata
                    )
                except Exception as e:
                    logger.error(f"Error processing template {template.get('id', 'unknown')}: {e}")
                    continue

    def find_best_template_with_reranking(self,
                                          transcription: str,
                                          initial_candidates: int = 10,
                                          final_results: int = 3) -> List[Dict]:
        """Optimal search with cross-encoder reranking."""
        # 1. NER extraction on the transcription
        query_entities = self.ner_extractor.extract_entities(transcription)
        # 2. Enriched embedding generation
        query_embedding = self.embedding_generator.generate_embedding(transcription, query_entities)
        # 3. Automatic filters based on the extracted entities
        entity_filters = {}
        if query_entities.exam_types:
            entity_filters['exam_types'] = [entity[0] for entity in query_entities.exam_types]
        if query_entities.specialties:
            entity_filters['specialties'] = [entity[0] for entity in query_entities.specialties]
        if query_entities.anatomical_regions:
            entity_filters['anatomical_regions'] = [entity[0] for entity in query_entities.anatomical_regions]
        # 4. Initial vector search
        initial_candidates_results = self.vector_db.advanced_search(
            query_embedding=query_embedding,
            n_results=initial_candidates,
            entity_filters=entity_filters,
            confidence_threshold=0.6
        )
        # 5. Cross-encoder reranking
        if len(initial_candidates_results) > final_results:
            final_results_reranked = self.embedding_generator.cross_encoder_rerank(
                query=transcription,
                candidates=initial_candidates_results,
                top_k=final_results
            )
        else:
            final_results_reranked = initial_candidates_results
        # 6. Enrich the results with NER details
        for result in final_results_reranked:
            result['query_entities'] = {
                'exam_types': query_entities.exam_types,
                'specialties': query_entities.specialties,
                'anatomical_regions': query_entities.anatomical_regions,
                'pathologies': query_entities.pathologies
            }
        return final_results_reranked
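
    # End-to-end sketch of the retrieval path (paths, query text, and counts
    # are examples, not prescribed values):
    #
    #   processor = AdvancedMedicalTemplateProcessor()
    #   processor.process_templates_batch(dataset_path="dataset.json")
    #   matches = processor.find_best_template_with_reranking(
    #       "écho-doppler veineux des membres inférieurs",
    #       initial_candidates=10, final_results=3
    #   )
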
    def evaluate_ner_performance(self, test_dataset_path: str = None) -> Dict[str, float]:
        """Evaluates the performance of the fine-tuned NER model."""
        if test_dataset_path is None:
            logger.warning("No test dataset provided for evaluation")
            return {}
        test_data = self.ner_extractor.load_dataset(test_dataset_path)
        if not test_data:
            logger.error("Empty test dataset")
            return {}
        correct_predictions = 0
        total_entities = 0
        entity_type_stats = {}
        for sample in tqdm(test_data, desc="NER evaluation"):
            text = sample['text']
            true_entities = sample['labels']
            # Prediction
            predicted_entities = self.ner_extractor.extract_entities(text)
            # Convert to a comparable format
            predicted_dict = {
                'exam_types': [entity[0].lower() for entity in predicted_entities.exam_types],
                'specialties': [entity[0].lower() for entity in predicted_entities.specialties],
                'anatomical_regions': [entity[0].lower() for entity in predicted_entities.anatomical_regions],
                'pathologies': [entity[0].lower() for entity in predicted_entities.pathologies],
                'procedures': [entity[0].lower() for entity in predicted_entities.medical_procedures],
                'measurements': [entity[0].lower() for entity in predicted_entities.measurements],
                'medications': [entity[0].lower() for entity in predicted_entities.medications],
                'symptoms': [entity[0].lower() for entity in predicted_entities.symptoms]
            }
            # Comparison
            for entity_type, true_entities_list in true_entities.items():
                if entity_type in predicted_dict:
                    predicted_entities_list = predicted_dict[entity_type]
                    # Per-entity-type statistics
                    if entity_type not in entity_type_stats:
                        entity_type_stats[entity_type] = {'correct': 0, 'total': 0}
                    true_entities_lower = [entity.lower() for entity in true_entities_list]
                    for true_entity in true_entities_lower:
                        total_entities += 1
                        entity_type_stats[entity_type]['total'] += 1
                        if true_entity in predicted_entities_list:
                            correct_predictions += 1
                            entity_type_stats[entity_type]['correct'] += 1
        # Compute the metrics (recall-style: share of gold entities recovered)
        overall_accuracy = correct_predictions / total_entities if total_entities > 0 else 0
        metrics = {
            'overall_accuracy': overall_accuracy,
            'total_entities': total_entities,
            'correct_predictions': correct_predictions
        }
        # Per-entity-type metrics
        for entity_type, stats in entity_type_stats.items():
            if stats['total'] > 0:
                accuracy = stats['correct'] / stats['total']
                metrics[f'{entity_type}_accuracy'] = accuracy
                metrics[f'{entity_type}_total'] = stats['total']
        logger.info(f"NER evaluation finished - overall accuracy: {overall_accuracy:.4f}")
        return metrics
    def export_processed_templates(self, output_path: str = "processed_templates.json"):
        """Exports the processed templates with their embeddings and entities."""
        try:
            # Fetch every template from the vector database
            all_results = self.vector_db.collection.get(
                include=["documents", "metadatas", "embeddings"]
            )
            embeddings = all_results.get('embeddings')
            processed_templates = []
            for i in range(len(all_results['ids'])):
                template_data = {
                    'id': all_results['ids'][i],
                    'text': all_results['documents'][i],
                    'metadata': all_results['metadatas'][i],
                    # Coerce to a plain list so json.dump accepts numpy arrays;
                    # `is not None` avoids numpy's ambiguous truth value
                    'embedding': (np.asarray(embeddings[i]).tolist()
                                  if embeddings is not None else None)
                }
                processed_templates.append(template_data)
            # Save
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(processed_templates, f, ensure_ascii=False, indent=2)
            logger.info(f"Processed templates exported to {output_path}")
            logger.info(f"Number of exported templates: {len(processed_templates)}")
        except Exception as e:
            logger.error(f"Export error: {e}")
# Analysis and debugging utilities
class MedicalNERAnalyzer:
    """Analysis and debugging tools for the medical NER system."""
    def __init__(self, processor: AdvancedMedicalTemplateProcessor):
        self.processor = processor

    def analyze_text(self, text: str) -> Dict:
        """Complete analysis of a medical text."""
        # NER extraction
        entities = self.processor.ner_extractor.extract_entities(text)
        # Embedding generation
        embedding = self.processor.embedding_generator.generate_embedding(text, entities)
        # Statistics
        analysis = {
            'text': text,
            'text_length': len(text),
            'entities': {
                'exam_types': entities.exam_types,
                'specialties': entities.specialties,
                'anatomical_regions': entities.anatomical_regions,
                'pathologies': entities.pathologies,
                'procedures': entities.medical_procedures,
                'measurements': entities.measurements,
                'medications': entities.medications,
                'symptoms': entities.symptoms
            },
            'embedding_shape': embedding.shape,
            'entity_count_total': sum([
                len(entities.exam_types),
                len(entities.specialties),
                len(entities.anatomical_regions),
                len(entities.pathologies),
                len(entities.medical_procedures),
                len(entities.measurements),
                len(entities.medications),
                len(entities.symptoms)
            ]),
            'confidence_scores': {
                'exam_types': [conf for _, conf in entities.exam_types],
                'specialties': [conf for _, conf in entities.specialties],
                'anatomical_regions': [conf for _, conf in entities.anatomical_regions],
                'pathologies': [conf for _, conf in entities.pathologies]
            }
        }
        return analysis

    def compare_entities(self, text1: str, text2: str) -> Dict:
        """Compares the entities extracted from two texts."""
        entities1 = self.processor.ner_extractor.extract_entities(text1)
        entities2 = self.processor.ner_extractor.extract_entities(text2)

        def entities_to_set(entities):
            all_entities = set()
            for entity_list in [entities.exam_types, entities.specialties,
                                entities.anatomical_regions, entities.pathologies]:
                for entity, _ in entity_list:
                    all_entities.add(entity.lower())
            return all_entities

        set1 = entities_to_set(entities1)
        set2 = entities_to_set(entities2)
        return {
            'text1_entities': list(set1),
            'text2_entities': list(set2),
            'common_entities': list(set1.intersection(set2)),
            'unique_to_text1': list(set1.difference(set2)),
            'unique_to_text2': list(set2.difference(set1)),
            # Jaccard similarity of the two entity sets
            'similarity_ratio': len(set1.intersection(set2)) / len(set1.union(set2)) if set1.union(set2) else 0
        }
    def generate_entity_report(self, dataset_path: str) -> Dict:
        """Generates a statistical report on the entities in the dataset."""
        dataset = self.processor.ner_extractor.load_dataset(dataset_path)
        entity_stats = {
            'exam_types': {},
            'specialties': {},
            'anatomical_regions': {},
            'pathologies': {},
            'procedures': {},
            'measurements': {},
            'medications': {},
            'symptoms': {}
        }
        total_samples = len(dataset)
        for sample in tqdm(dataset, desc="Analyzing dataset"):
            labels = sample.get('labels', {})
            for entity_type, entities in labels.items():
                if entity_type in entity_stats:
                    for entity in entities:
                        entity_lower = entity.lower()
                        if entity_lower not in entity_stats[entity_type]:
                            entity_stats[entity_type][entity_lower] = 0
                        entity_stats[entity_type][entity_lower] += 1
        # Build the report
        report = {
            'total_samples': total_samples,
            'entity_statistics': {}
        }
        for entity_type, entity_counts in entity_stats.items():
            if entity_counts:
                sorted_entities = sorted(entity_counts.items(), key=lambda x: x[1], reverse=True)
                report['entity_statistics'][entity_type] = {
                    'unique_count': len(entity_counts),
                    'total_occurrences': sum(entity_counts.values()),
                    'top_10': sorted_entities[:10],
                    'average_occurrences': sum(entity_counts.values()) / len(entity_counts)
                }
        return report
# Advanced usage example
def main():
    """Example of using the advanced system, including fine-tuning."""
    # Initialize the advanced processor
    processor = AdvancedMedicalTemplateProcessor()
    # 1. Process the templates with NER fine-tuning
    print("=== STEP 1: Processing and fine-tuning ===")
    processor.process_templates_batch(
        dataset_path="dataset.json",
        fine_tune_ner=True,  # Enables fine-tuning
        batch_size=8
    )
    # 2. Evaluate NER performance (optional, if a test dataset is available)
    print("\n=== STEP 2: Performance evaluation ===")
    # metrics = processor.evaluate_ner_performance("test_dataset.json")
    # print(f"Evaluation metrics: {metrics}")
    # 3. Analyze a medical text
    print("\n=== STEP 3: Text analysis ===")
    analyzer = MedicalNERAnalyzer(processor)
    # Sample French transcription (the models are French biomedical models)
    test_text = """madame bacon nicole bilan œdème droit gonalgies ostéophytes
incontinence veineuse modérée portions surale droite crurale gauche saphéniennes"""
    analysis = analyzer.analyze_text(test_text)
    print("Text analysis:")
    print(f"- Total number of entities: {analysis['entity_count_total']}")
    print(f"- Detected exam types: {analysis['entities']['exam_types']}")
    print(f"- Anatomical regions: {analysis['entities']['anatomical_regions']}")
    print(f"- Pathologies: {analysis['entities']['pathologies']}")
    # 4. Search with reranking
    print("\n=== STEP 4: Search with reranking ===")
    best_matches = processor.find_best_template_with_reranking(
        transcription=test_text,
        initial_candidates=15,
        final_results=3
    )
    # Display the results
    for i, match in enumerate(best_matches):
        print(f"\n--- Match {i+1} ---")
        print(f"Template ID: {match['id']}")
        print(f"Final score: {match.get('final_score', match['similarity_score']):.4f}")
        print(f"Cross-encoder score: {match.get('cross_encoder_score', 'N/A')}")
        print(f"Text excerpt: {match['document'][:200]}...")
        # Show the entities detected in the query
        query_entities = match.get('query_entities', {})
        for entity_type, entities in query_entities.items():
            if entities:
                print(f"  - {entity_type}: {[f'{e[0]} ({e[1]:.2f})' for e in entities[:3]]}")
    # 5. Export the processed templates
    print("\n=== STEP 5: Exporting the results ===")
    processor.export_processed_templates("processed_medical_templates.json")
    # 6. Generate a dataset report
    print("\n=== STEP 6: Dataset report ===")
    report = analyzer.generate_entity_report("dataset.json")
    print(f"Report generated for {report['total_samples']} samples")
    for entity_type, stats in report['entity_statistics'].items():
        if stats['unique_count'] > 0:
            print(f"\n{entity_type.upper()}:")
            print(f"  - Unique entities: {stats['unique_count']}")
            print(f"  - Total occurrences: {stats['total_occurrences']}")
            print(f"  - Top 3: {stats['top_10'][:3]}")

if __name__ == "__main__":
    main()