""" Foundation 1.2 Clinical trial query system with 355M foundation model """ import gradio as gr import os from pathlib import Path import pickle import numpy as np from sentence_transformers import SentenceTransformer import logging from rank_bm25 import BM25Okapi import re from two_llm_system_FIXED import expand_query_with_355m, generate_clinical_response_with_xupract, rank_trials_with_355m logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Initialize hf_token = os.getenv("HF_TOKEN") # Paths for data storage # Files will be downloaded from HF Dataset on first run DATASET_FILE = Path(__file__).parent / "complete_dataset_WITH_RESULTS_FULL.txt" CHUNKS_FILE = Path(__file__).parent / "dataset_chunks_TRIAL_AWARE.pkl" EMBEDDINGS_FILE = Path(__file__).parent / "dataset_embeddings_TRIAL_AWARE_FIXED.npy" # FIXED version to avoid cache INVERTED_INDEX_FILE = Path(__file__).parent / "inverted_index_TRIAL_AWARE.pkl" # Pre-built inverted index (638MB) # HF Dataset containing the large files DATASET_REPO = "gmkdigitalmedia/foundation1.2-data" # Global storage embedder = None doc_chunks = [] doc_embeddings = None bm25_index = None # BM25 index for fast keyword search inverted_index = None # Inverted index for instant drug lookup # ============================================================================ # RAG FUNCTIONS # ============================================================================ def load_embedder(): """Load L6 embedding model (matches generated embeddings)""" global embedder if embedder is None: logger.info("Loading MiniLM-L6 embedding model...") # Force CPU to avoid CUDA init in main process embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu') logger.info("L6 model loaded on CPU") def build_inverted_index(chunks): """ Build targeted inverted index for clinical search Maps drugs, diseases, companies, and endpoints to trial indices for O(1) lookup Indexes ONLY what matters: 1. INTERVENTION - drug/device names 2. 
CONDITIONS - diseases being treated 3. SPONSOR/COLLABORATOR/MANUFACTURER - company names 4. OUTCOME - trial endpoints (what's being measured) Does NOT index trial names (unnecessary noise) """ import time t_start = time.time() inv_index = {} logger.info("Building targeted index: drugs, diseases, companies, endpoints...") # Generic words to skip skip_words = { 'with', 'versus', 'combination', 'treatment', 'therapy', 'study', 'trial', 'phase', 'double', 'blind', 'placebo', 'group', 'control', 'active', 'randomized', 'multicenter', 'open', 'label', 'crossover' } for idx, chunk_data in enumerate(chunks): if idx % 100000 == 0 and idx > 0: logger.info(f" Indexed {idx:,}/{len(chunks):,} trials...") text = chunk_data[1] if isinstance(chunk_data, tuple) else chunk_data text_lower = text.lower() # 1. DRUGS from INTERVENTION field intervention_match = re.search(r'intervention[:\s]+([^\n]+)', text_lower) if intervention_match: intervention_text = intervention_match.group(1) drugs = re.split(r'[,;\-\s]+', intervention_text) for drug in drugs: drug = drug.strip('.,;:() ') if len(drug) > 3 and drug not in skip_words: if drug not in inv_index: inv_index[drug] = [] if idx not in inv_index[drug]: inv_index[drug].append(idx) # 2. DISEASES from CONDITIONS field conditions_match = re.search(r'conditions?[:\s]+([^\n]+)', text_lower) if conditions_match: conditions_text = conditions_match.group(1) diseases = re.split(r'[,;\|]+', conditions_text) for disease in diseases: disease = disease.strip('.,;:() ') # Split multi-word conditions and index each significant word disease_words = re.findall(r'\b\w{4,}\b', disease) for word in disease_words: if word not in skip_words: if word not in inv_index: inv_index[word] = [] if idx not in inv_index[word]: inv_index[word].append(idx) # 3. 
COMPANIES from SPONSOR field sponsor_match = re.search(r'sponsor[:\s]+([^\n]+)', text_lower) if sponsor_match: sponsor_text = sponsor_match.group(1) sponsors = re.split(r'[,;\|]+', sponsor_text) for sponsor in sponsors: sponsor = sponsor.strip('.,;:() ') if len(sponsor) > 3: if sponsor not in inv_index: inv_index[sponsor] = [] if idx not in inv_index[sponsor]: inv_index[sponsor].append(idx) # 4. COMPANIES from COLLABORATOR field collab_match = re.search(r'collaborator[:\s]+([^\n]+)', text_lower) if collab_match: collab_text = collab_match.group(1) collaborators = re.split(r'[,;\|]+', collab_text) for collab in collaborators: collab = collab.strip('.,;:() ') if len(collab) > 3: if collab not in inv_index: inv_index[collab] = [] if idx not in inv_index[collab]: inv_index[collab].append(idx) # 5. COMPANIES from MANUFACTURER field manuf_match = re.search(r'manufacturer[:\s]+([^\n]+)', text_lower) if manuf_match: manuf_text = manuf_match.group(1) manufacturers = re.split(r'[,;\|]+', manuf_text) for manuf in manufacturers: manuf = manuf.strip('.,;:() ') if len(manuf) > 3: if manuf not in inv_index: inv_index[manuf] = [] if idx not in inv_index[manuf]: inv_index[manuf].append(idx) # 6. 
ENDPOINTS from OUTCOME fields # Look for outcome measures (what's being measured) outcome_matches = re.findall(r'outcome[:\s]+([^\n]+)', text_lower) for outcome_match in outcome_matches[:5]: # First 5 outcomes only # Extract meaningful endpoint terms endpoint_words = re.findall(r'\b\w{5,}\b', outcome_match) # 5+ char words for word in endpoint_words[:3]: # First 3 words per outcome if word not in skip_words and word not in {'outcome', 'measure', 'primary', 'secondary'}: if word not in inv_index: inv_index[word] = [] if idx not in inv_index[word]: inv_index[word].append(idx) t_elapsed = time.time() - t_start logger.info(f"✓ Targeted index built in {t_elapsed:.1f}s with {len(inv_index):,} terms") # Log sample entries for debugging (drugs, diseases, companies, endpoints) sample_terms = { 'drugs': ['keytruda', 'opdivo', 'humira'], 'diseases': ['cancer', 'diabetes', 'melanoma'], 'companies': ['novartis', 'pfizer', 'merck'], 'endpoints': ['survival', 'response', 'remission'] } for category, terms in sample_terms.items(): logger.info(f" {category.upper()} samples:") for term in terms: if term in inv_index: logger.info(f" '{term}' -> {len(inv_index[term])} trials") return inv_index def download_from_dataset(filename): """Download file from HF Dataset if not present locally""" from huggingface_hub import hf_hub_download import tempfile # Use /tmp for downloads (has write permissions in Docker) download_dir = Path("/tmp/foundation_data") download_dir.mkdir(exist_ok=True) local_file = download_dir / filename if local_file.exists(): logger.info(f"Found cached {filename}") return local_file try: logger.info(f"Downloading {filename} from {DATASET_REPO}...") downloaded_file = hf_hub_download( repo_id=DATASET_REPO, filename=filename, repo_type="dataset", local_dir=download_dir, local_dir_use_symlinks=False ) logger.info(f"Downloaded {filename}") return Path(downloaded_file) except Exception as e: logger.error(f"Failed to download {filename}: {e}") return None def 
load_embeddings(): """Load pre-generated embeddings (download from dataset if needed)""" global doc_chunks, doc_embeddings, bm25_index # Try to download if not present - store paths returned by download chunks_path = CHUNKS_FILE embeddings_path = EMBEDDINGS_FILE dataset_path = DATASET_FILE if not CHUNKS_FILE.exists(): downloaded = download_from_dataset("dataset_chunks_TRIAL_AWARE.pkl") if downloaded: chunks_path = downloaded if not EMBEDDINGS_FILE.exists(): downloaded = download_from_dataset("dataset_embeddings_TRIAL_AWARE_FIXED.npy") # FIXED version if downloaded: embeddings_path = downloaded if not DATASET_FILE.exists(): downloaded = download_from_dataset("complete_dataset_WITH_RESULTS_FULL.txt") if downloaded: dataset_path = downloaded if chunks_path.exists() and embeddings_path.exists(): try: logger.info("Loading embeddings from disk...") with open(chunks_path, 'rb') as f: doc_chunks = pickle.load(f) # Load embeddings loaded_embeddings = np.load(embeddings_path, allow_pickle=True) logger.info(f"Loaded embeddings type: {type(loaded_embeddings)}") # Check if it's already a proper numpy array if isinstance(loaded_embeddings, np.ndarray) and loaded_embeddings.ndim == 2: doc_embeddings = loaded_embeddings logger.info(f"✓ Embeddings are proper numpy array with shape: {doc_embeddings.shape}") elif isinstance(loaded_embeddings, list): logger.info(f"Converting embeddings from list to numpy array (memory efficient)...") # Convert in chunks to avoid memory spike chunk_size = 10000 total = len(loaded_embeddings) # DEBUG: Print first 3 items to see format logger.info(f"DEBUG: Total embeddings: {total}") logger.info(f"DEBUG: Type of first item: {type(loaded_embeddings[0])}") # Check if this is actually the chunks file (wrong file uploaded) if isinstance(loaded_embeddings[0], tuple) and len(loaded_embeddings[0]) == 2: if isinstance(loaded_embeddings[0][0], int) and isinstance(loaded_embeddings[0][1], str): raise ValueError( f"ERROR: The embeddings file contains (int, string) 
tuples!\n" f"This looks like the CHUNKS file was uploaded as the embeddings file.\n\n" f"First item: {loaded_embeddings[0][:2]}\n\n" f"Please re-upload the correct file:\n" f" CORRECT: dataset_embeddings_TRIAL_AWARE.npy (numpy array, 855 MB)\n" f" WRONG: dataset_chunks_TRIAL_AWARE.pkl (tuples, 2.8 GB)\n\n" f"The local file at /mnt/c/Users/ibm/Documents/HF/kg_to_model/dataset_embeddings_TRIAL_AWARE.npy is correct." ) if isinstance(loaded_embeddings[0], tuple): logger.info(f"DEBUG: Tuple length: {len(loaded_embeddings[0])}") for i, item in enumerate(loaded_embeddings[0][:5] if len(loaded_embeddings[0]) > 5 else loaded_embeddings[0]): logger.info(f"DEBUG: Tuple element {i}: type={type(item)}, preview={str(item)[:100]}") # Get embedding dimension from first item first_emb = loaded_embeddings[0] emb_idx = None # Initialize # Handle different formats if isinstance(first_emb, tuple): # Try both positions - could be (id, emb) or (emb, id) logger.info(f"DEBUG: Trying to find embedding vector in tuple...") emb_vector = None for idx, elem in enumerate(first_emb): if isinstance(elem, (list, np.ndarray)): emb_vector = elem emb_idx = idx logger.info(f"DEBUG: Found embedding at position {idx}") break if emb_vector is None: raise ValueError(f"No embedding vector found in tuple. 
Tuple contains: {[type(x) for x in first_emb]}") emb_dim = len(emb_vector) logger.info(f"DEBUG: Embedding dimension: {emb_dim}") elif isinstance(first_emb, list): emb_dim = len(first_emb) emb_idx = None elif isinstance(first_emb, np.ndarray): emb_dim = first_emb.shape[0] emb_idx = None else: raise ValueError(f"Unknown embedding format: {type(first_emb)}") logger.info(f"Creating array for {total} embeddings of dimension {emb_dim}") # Pre-allocate array doc_embeddings = np.zeros((total, emb_dim), dtype=np.float32) # Fill in chunks for i in range(0, total, chunk_size): end = min(i + chunk_size, total) # Extract embeddings from tuples if needed if isinstance(first_emb, tuple) and emb_idx is not None: # Extract just the embedding vector from each tuple at the correct position batch = [item[emb_idx] for item in loaded_embeddings[i:end]] doc_embeddings[i:end] = batch else: doc_embeddings[i:end] = loaded_embeddings[i:end] if i % 50000 == 0: logger.info(f"Converted {i}/{total} embeddings...") logger.info(f"✓ Converted to array with shape: {doc_embeddings.shape}") else: doc_embeddings = loaded_embeddings logger.info(f"Embeddings already numpy array with shape: {doc_embeddings.shape}") logger.info(f"Loaded {len(doc_chunks)} chunks with embeddings") # Skip BM25 (too memory-heavy for Docker), use inverted index only global inverted_index # Try to load pre-built inverted index (638MB) - MUCH faster than building (15 minutes) if INVERTED_INDEX_FILE.exists(): logger.info(f"Loading pre-built inverted index from {INVERTED_INDEX_FILE.name}...") try: with open(INVERTED_INDEX_FILE, 'rb') as f: inverted_index = pickle.load(f) logger.info(f"✓ Loaded pre-built inverted index with {len(inverted_index):,} terms (instant vs 15min build)") except Exception as e: logger.warning(f"Failed to load pre-built index: {e}, building from scratch...") inverted_index = build_inverted_index(doc_chunks) else: logger.info("Pre-built inverted index not found, building from scratch (this takes 15 
def filter_trial_for_clinical_summary(trial_text):
    """
    Reduce one trial record to the lines worth showing to the LLM.

    The record is processed line by line, classified by its leading field
    marker (after stripping surrounding whitespace):

    - Blank lines and baseline / adverse-event / outcome-metadata lines are
      dropped.
    - ``OUTCOME:``, ``OUTCOME_DESCRIPTION:`` and ``RESULT_VALUE:`` lines are
      kept, but only the first 5 of each kind.
    - Core fields (IDs, titles, summary, description, conditions,
      intervention, sponsor/collaborator/manufacturer, eligibility) are
      always kept.
    - Anything else is dropped.

    Kept lines are emitted unmodified (original indentation preserved) and
    re-joined with newlines. A falsy input (``None`` or ``""``) is returned
    unchanged.
    """
    if not trial_text:
        return trial_text

    # Overwhelming statistical noise: never kept.
    noise_prefixes = (
        'BASELINE:', 'SERIOUS_ADVERSE_EVENT:', 'OTHER_ADVERSE_EVENT:',
        'OUTCOME_TYPE:', 'OUTCOME_TIME_FRAME:', 'OUTCOME_SAFETY:',
        'OUTCOME_OTHER:', 'OUTCOME_NUMBER:',
    )
    # Repetitive fields: marker -> how many occurrences may pass through.
    capped_prefixes = {
        'OUTCOME:': 5,
        'OUTCOME_DESCRIPTION:': 5,
        'RESULT_VALUE:': 5,
    }
    # Core trial information: what disease, what drug, who runs it.
    core_prefixes = (
        'NCT_ID:', 'TITLE:', 'OFFICIAL_TITLE:',
        'SUMMARY:', 'DESCRIPTION:',
        'CONDITIONS:', 'INTERVENTION:',
        'SPONSOR:', 'COLLABORATOR:', 'MANUFACTURER:',
        'ELIGIBILITY:',
    )

    seen = dict.fromkeys(capped_prefixes, 0)
    kept = []

    for raw_line in trial_text.split('\n'):
        content = raw_line.strip()

        if not content or content.startswith(noise_prefixes):
            continue

        capped = next((p for p in capped_prefixes if content.startswith(p)), None)
        if capped is not None:
            seen[capped] += 1
            if seen[capped] <= capped_prefixes[capped]:
                kept.append(raw_line)
            continue

        if content.startswith(core_prefixes):
            kept.append(raw_line)

    return '\n'.join(kept)
KEYWORD SCORING WITH BM25 (Fast!) t_kw = time.time() # Use inverted index for drug lookup (lightweight, no BM25) global bm25_index, inverted_index keyword_scores = {} if inverted_index is not None: # Check if any query terms are in our drug/intervention inverted index inv_index_candidates = set() for term in query_terms: if term in inverted_index: inv_index_candidates.update(inverted_index[term]) logger.info(f"[INVERTED INDEX] Found {len(inverted_index[term])} trials for '{term}'") # FAST PATH: If we have inverted index hits (drug names), score those trials if inv_index_candidates: logger.info(f"[FAST PATH] Checking {len(inv_index_candidates)} inverted index candidates") # CRITICAL: Identify which terms are specific drugs (low frequency) drug_specific_terms = set() for term in query_terms: if term in inverted_index and len(inverted_index[term]) < 100: # This term appears in <100 trials - likely a specific drug name! drug_specific_terms.add(term) logger.info(f"[DRUG SPECIFIC] '{term}' found in {len(inverted_index[term])} trials - treating as drug name") for idx in inv_index_candidates: # No BM25, use simple match count as base score base_score = 1.0 # Check if this trial contains a drug-specific term chunk_data = doc_chunks[idx] chunk_text = chunk_data[1] if isinstance(chunk_data, tuple) else chunk_data chunk_lower = chunk_text.lower() has_drug_match = False for drug_term in drug_specific_terms: if drug_term in chunk_lower: has_drug_match = True break # MASSIVE PRIORITY for drug-specific trials if has_drug_match: # Drug-specific trials get GUARANTEED top ranking score = 1000.0 + base_score logger.info(f"[DRUG PRIORITY] Trial {idx} contains specific drug - score={score:.1f}") else: # Regular inverted index hits (generic terms) if base_score <= 0: base_score = 0.1 score = base_score # Apply field-specific boosting for non-drug terms max_field_boost = 1.0 for term in query_terms: if term not in chunk_lower or term in drug_specific_terms: continue # INTERVENTION field - 
medium priority for non-drug terms if f'intervention: {term}' in chunk_lower or f'intervention:{term}' in chunk_lower: max_field_boost = max(max_field_boost, 3.0) # TITLE field - low priority elif 'title:' in chunk_lower: title_pos = chunk_lower.find('title:') term_pos = chunk_lower.find(term) if title_pos < term_pos < title_pos + 200: max_field_boost = max(max_field_boost, 2.0) score *= max_field_boost keyword_scores[idx] = score else: logger.info(f"[FALLBACK] No inverted index hits, using pure semantic search") logger.info(f"[HYBRID] Inverted index scoring: {len(keyword_scores)} trials matched ({time.time()-t_kw:.2f}s)") # 2. SEMANTIC SCORING load_embedder() t_sem = time.time() query_embedding = embedder.encode([query])[0] semantic_similarities = np.dot(doc_embeddings, query_embedding) logger.info(f"[HYBRID] Semantic scoring complete ({time.time()-t_sem:.2f}s)") # 3. MERGE SCORES # Normalize both scores to 0-1 range if keyword_scores: max_kw = max(keyword_scores.values()) keyword_scores_norm = {idx: score/max_kw for idx, score in keyword_scores.items()} else: keyword_scores_norm = {} max_sem = semantic_similarities.max() min_sem = semantic_similarities.min() semantic_scores_norm = (semantic_similarities - min_sem) / (max_sem - min_sem + 1e-10) # Combined score: 50% keyword (with IDF/field boost), 50% semantic (context) # Balanced approach: IDF-weighted keywords + semantic understanding combined_scores = np.zeros(len(doc_chunks)) for idx in range(len(doc_chunks)): kw_score = keyword_scores_norm.get(idx, 0.0) sem_score = semantic_scores_norm[idx] # If keyword match exists, balance keyword + semantic if kw_score > 0: combined_scores[idx] = 0.5 * kw_score + 0.5 * sem_score else: # Pure semantic if no keyword match combined_scores[idx] = sem_score # Get top K by combined score (get more candidates to sort by recency) # We'll get 10 candidates, then sort by NCT ID to find the 3 most recent candidate_k = max(top_k * 3, 10) # Get 3x requested, minimum 10 top_indices = 
np.argsort(combined_scores)[-candidate_k:][::-1] logger.info(f"[HYBRID] Top 3 combined scores: {combined_scores[top_indices[:3]]}") logger.info(f"[HYBRID] Top 3 keyword scores: {[keyword_scores_norm.get(i, 0.0) for i in top_indices[:3]]}") logger.info(f"[HYBRID] Top 3 semantic scores: {[semantic_scores_norm[i] for i in top_indices[:3]]}") # Extract text and scores for 355M ranking # Format as (score, text) tuples for rank_trials_with_355m candidate_trials_for_ranking = [(combined_scores[i], doc_chunks[i][1] if isinstance(doc_chunks[i], tuple) else doc_chunks[i]) for i in top_indices] # SORT BY NCT ID (higher = newer) before 355M ranking def extract_nct_number(trial_tuple): """Extract NCT number from trial text for sorting (higher = newer)""" _, text = trial_tuple match = re.search(r'NCT_ID:\s*NCT(\d+)', text) return int(match.group(1)) if match else 0 # Sort candidates by NCT ID (descending = newest first) candidate_trials_for_ranking.sort(key=extract_nct_number, reverse=True) # Log top 5 NCT IDs to show recency sorting top_ncts = [] for score, text in candidate_trials_for_ranking[:5]: match = re.search(r'NCT_ID:\s*(NCT\d+)', text) if match: top_ncts.append(match.group(1)) logger.info(f"[NCT SORT] Top 5 candidates by recency: {top_ncts}") # SKIP 355M RANKING - It's broken (gives 0.50 to everything) and wastes 10 seconds # Just use the hybrid-scored + recency-sorted candidates logger.info(f"[FAST MODE] Using hybrid search + recency sort (skipping broken 355M ranking)") ranked_trials = candidate_trials_for_ranking # Take top K from ranked results top_ranked = ranked_trials[:top_k] logger.info(f"[FAST MODE] Selected top {len(top_ranked)} trials (hybrid score + recency)") # Extract just the text raw_chunks = [trial_text for _, trial_text in top_ranked] # Apply clinical filter to each trial context_chunks = [filter_trial_for_clinical_summary(chunk) for chunk in raw_chunks] if context_chunks: first_trial_preview = context_chunks[0][:200] logger.info(f"[HYBRID] First 
def keyword_search_query_text(query, max_results=10, hf_token=None):
    """Search the dataset file using ALL meaningful words from the full query.

    The dataset is streamed line by line and split into trials at lines that
    start with ``NCT_ID:`` or ``TRIAL NCT``. Each trial is scored by how many
    of the search terms occur in it as a case-insensitive substring.

    Args:
        query: Free-text user query.
        max_results: Maximum number of trials to return.
        hf_token: Unused here; kept so existing callers keep working.

    Returns:
        A list of ``(score, trial_text)`` tuples, highest score first, at most
        ``max_results`` long. Returns ``[]`` when nothing matches or reading
        fails, and ``""`` when the dataset file is missing or the query yields
        no search terms. All failure values are falsy.
    """
    if not DATASET_FILE.exists():
        logger.error("Dataset file not found")
        return ""

    # Remove common stopwords but keep medical/clinical terms
    stopwords = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
                 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'should',
                 'could', 'may', 'might', 'must', 'can', 'of', 'at', 'by', 'for', 'with',
                 'about', 'as', 'into', 'through', 'during', 'to', 'from', 'in', 'on',
                 'what', 'you', 'know', 'that', 'relevant'}

    # Strip punctuation BEFORE filtering. Otherwise "know?" slips past the
    # stop-word check, and a punctuation-only token such as "..." collapses to
    # "" which is a substring of every trial and inflates every score.
    tokens = (word.strip('?.,!;:()[]{}') for word in query.lower().split())
    search_terms = [tok for tok in tokens if len(tok) >= 3 and tok not in stopwords]

    if not search_terms:
        logger.warning("No search terms extracted from query")
        return ""

    logger.info(f"Search terms from full query: {search_terms}")

    # Store trials with match scores
    trials_with_scores = []

    def record(trial_text):
        """Score one complete trial and keep it if any term matched."""
        trial_lower = trial_text.lower()
        score = sum(1 for term in search_terms if term in trial_lower)
        if score > 0:
            trials_with_scores.append((score, trial_text))

    try:
        current_lines = []
        with open(DATASET_FILE, 'r', encoding='utf-8', errors='ignore') as f:
            for line in f:
                # A marker line closes the previous trial and opens a new one
                if line.startswith(("NCT_ID:", "TRIAL NCT")):
                    if current_lines:
                        record("".join(current_lines))
                    current_lines = [line]
                else:
                    current_lines.append(line)

        # Check last trial
        if current_lines:
            record("".join(current_lines))

        # Sort by score (highest first) and take top results
        trials_with_scores.sort(reverse=True, key=lambda x: x[0])
        matching_trials = trials_with_scores[:max_results]

        if matching_trials:
            logger.info(f"Keyword search found {len(matching_trials)} trials")
            return matching_trials  # Return list of (score, trial) tuples
        else:
            logger.warning("Keyword search found no matching trials")
            return []

    except Exception as e:
        logger.error(f"Keyword search failed: {e}")
        return []
condition_matches = sum(1 for c in conditions if c in trial_lower) # Only include trials that match at least the drug (if drug was specified) if drugs: if drug_matches > 0: score = drug_matches * 10 + condition_matches trials_with_scores.append((score, current_trial)) elif condition_matches > 0: # No drug specified, just match conditions trials_with_scores.append((condition_matches, current_trial)) current_trial = line else: current_trial += line # Check last trial if current_trial: trial_lower = current_trial.lower() drug_matches = sum(1 for d in drugs if d in trial_lower) condition_matches = sum(1 for c in conditions if c in trial_lower) if drugs: if drug_matches > 0: score = drug_matches * 10 + condition_matches trials_with_scores.append((score, current_trial)) elif condition_matches > 0: trials_with_scores.append((condition_matches, current_trial)) # Sort by score (highest first) and take top results trials_with_scores.sort(reverse=True, key=lambda x: x[0]) matching_trials = [trial for score, trial in trials_with_scores[:max_results]] if matching_trials: context = "\n\n---\n\n".join(matching_trials) if len(context) > 6000: context = context[:6000] + "..." 
def parse_entities_from_query(conversation, hf_token=None):
    """Parse entities from query using both 355M and 8B models + regex fallback.

    Args:
        conversation: The user query text.
        hf_token: Token forwarded to the 8B extractor.

    Returns:
        dict with two lists, ``'drugs'`` and ``'conditions'``. Model-derived
        entries come first, in output order and not de-duplicated. Regex
        matches are appended only if no entry with the same lower-cased text
        is already present.
    """
    entities = {'drugs': [], 'conditions': []}

    # NOTE(review): both extractor helpers are defined outside the visible part
    # of this file; they are assumed to return text (or None) with one
    # "label: value" entry per line - confirm against their definitions.
    extracted_355m = extract_entities_with_small_model(conversation)
    extracted_8b = extract_entities_with_8b(conversation, hf_token=hf_token)

    # Combine both extractions
    extracted = (extracted_355m or "") + "\n" + (extracted_8b or "")

    # Parse model output: any line carrying a known label contributes the rest
    # of that line (label removed) as an entity.
    for line in extracted.split('\n'):
        lower_line = line.lower()
        if 'drug:' in lower_line or 'medication:' in lower_line:
            drug = re.sub(r'(drug:|medication:)', '', line, flags=re.IGNORECASE).strip()
            if drug:
                entities['drugs'].append(drug)
        elif 'condition:' in lower_line or 'disease:' in lower_line:
            condition = re.sub(r'(condition:|disease:)', '', line, flags=re.IGNORECASE).strip()
            if condition:
                entities['conditions'].append(condition)

    # Regex fallback for standard drug naming patterns (case-sensitive on purpose:
    # they rely on the leading capital)
    drug_patterns = [
        r'\b([A-Z][a-z]+mab)\b',  # Monoclonal antibodies: -mab suffix
        r'\b([A-Z][a-z]+nib)\b',  # Kinase inhibitors: -nib suffix
        r'\b([A-Z]\d+[A-Z]+\d+)\b'  # Alphanumeric codes like F8IL10
    ]
    known_drugs = {d.lower() for d in entities['drugs']}
    for pattern in drug_patterns:
        for match in re.findall(pattern, conversation):
            key = match.lower()
            if key not in known_drugs:
                known_drugs.add(key)
                entities['drugs'].append(match)

    condition_patterns = [
        r'\b(sjogren\'?s?|lupus|myelofibrosis|rheumatoid arthritis)\b'
    ]
    known_conditions = {c.lower() for c in entities['conditions']}
    for pattern in condition_patterns:
        for match in re.findall(pattern, conversation, re.IGNORECASE):
            # Compare lower-cased on BOTH sides. The pattern is case-insensitive,
            # so "Lupus" must be recognised as a duplicate of "lupus".
            key = match.lower()
            if key not in known_conditions:
                known_conditions.add(key)
                entities['conditions'].append(match)

    logger.info(f"Extracted entities: {entities}")
    return entities
huggingface_hub import InferenceClient logger.info("[QUERY PARSER] Analyzing user query with LLM...") client = InferenceClient(token=hf_token, timeout=30) parse_prompt = f"""Extract key information from this clinical trial query. Query: "{query}" Extract and return in this EXACT format: DRUGS: [list drug/treatment names, or "none"] DISEASES: [list diseases/conditions, or "none"] COMPANIES: [list company/sponsor names, or "none"] ENDPOINTS: [list trial endpoints/outcomes, or "none"] SEARCH_TERMS: [optimized search keywords] Examples: Query: "What Novartis drugs treat melanoma?" DRUGS: none DISEASES: melanoma COMPANIES: Novartis ENDPOINTS: none SEARCH_TERMS: Novartis melanoma treatment drugs Query: "Tell me about Keytruda for lung cancer" DRUGS: Keytruda DISEASES: lung cancer COMPANIES: none ENDPOINTS: none SEARCH_TERMS: Keytruda lung cancer Now parse the query above:""" response = client.chat_completion( model="meta-llama/Llama-3.1-70B-Instruct", messages=[{"role": "user", "content": parse_prompt}], max_tokens=256, temperature=0.1 # Low temp for consistent parsing ) parsed = response.choices[0].message.content.strip() logger.info(f"[QUERY PARSER] Extracted entities:\n{parsed}") # Parse the response into dict result = { 'raw_parsed': parsed, 'drugs': [], 'diseases': [], 'companies': [], 'endpoints': [], 'search_terms': query # fallback } lines = parsed.split('\n') for line in lines: line = line.strip() if line.startswith('DRUGS:'): drugs = line.replace('DRUGS:', '').strip() if drugs.lower() != 'none': result['drugs'] = [d.strip() for d in drugs.split(',')] elif line.startswith('DISEASES:'): diseases = line.replace('DISEASES:', '').strip() if diseases.lower() != 'none': result['diseases'] = [d.strip() for d in diseases.split(',')] elif line.startswith('COMPANIES:'): companies = line.replace('COMPANIES:', '').strip() if companies.lower() != 'none': result['companies'] = [c.strip() for c in companies.split(',')] elif line.startswith('ENDPOINTS:'): endpoints = 
line.replace('ENDPOINTS:', '').strip() if endpoints.lower() != 'none': result['endpoints'] = [e.strip() for e in endpoints.split(',')] elif line.startswith('SEARCH_TERMS:'): result['search_terms'] = line.replace('SEARCH_TERMS:', '').strip() logger.info(f"[QUERY PARSER] ✓ Drugs: {result['drugs']}, Diseases: {result['diseases']}, Companies: {result['companies']}") return result except Exception as e: logger.warning(f"[QUERY PARSER] Failed: {e}, using original query") return { 'drugs': [], 'diseases': [], 'companies': [], 'endpoints': [], 'search_terms': query, 'raw_parsed': '' } def generate_llama_response(query, rag_context, hf_token=None): """ Generate response using FAST Groq API (10x faster than HF) Speed comparison: - HuggingFace: ~40 tokens/sec = 15 seconds - Groq: ~300 tokens/sec = 2 seconds (FREE!) """ try: # Try Groq first (much faster), fallback to HuggingFace groq_api_key = os.getenv("GROQ_API_KEY") if groq_api_key: logger.info("Generating response with Llama-3.1-70B via GROQ (fast)...") from groq import Groq client = Groq(api_key=groq_api_key) # Simplified prompt for faster generation system_prompt = """You are a medical research assistant. Answer based ONLY on the provided clinical trial data. Be concise and cite NCT IDs.""" user_prompt = f"""Clinical trials: {rag_context[:6000]} Question: {query} Provide a concise answer citing specific NCT trial IDs.""" response = client.chat.completions.create( model="llama-3.1-70b-versatile", # Groq's optimized 70B messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} ], max_tokens=512, # Shorter for speed temperature=0.3, timeout=30 ) return response.choices[0].message.content.strip() else: # Fallback to HuggingFace (slower) logger.info("Generating response with Llama-3.1-70B via HuggingFace (slow)...") from huggingface_hub import InferenceClient client = InferenceClient(token=hf_token, timeout=120) system_prompt = """You are a medical research assistant. 
Answer based ONLY on the provided clinical trial data. Be concise and cite NCT IDs.""" user_prompt = f"""Clinical trials: {rag_context[:6000]} Question: {query} Provide a concise answer citing specific NCT trial IDs.""" messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} ] response = client.chat_completion( model="meta-llama/Meta-Llama-3.1-70B-Instruct", messages=messages, max_tokens=512, # Reduced from 2048 for speed temperature=0.3 ) return response.choices[0].message.content.strip() except Exception as e: logger.error(f"Llama error: {e}") return f"Llama API error: {str(e)}" def process_query_simple_test(conversation): """TEST JUST THE RAG - no models""" try: import time output = [] output.append(f"QUERY: {conversation}\n") # Check if embeddings loaded if doc_embeddings is None or len(doc_chunks) == 0: return "FAIL: Embeddings not loaded" output.append(f"✓ Embeddings loaded: {len(doc_chunks)} chunks\n") output.append(f"✓ Embeddings shape: {doc_embeddings.shape}\n") # Try to search start = time.time() context = retrieve_context_with_embeddings(conversation, top_k=3) search_time = time.time() - start if not context: return "".join(output) + "\nFAIL: RAG returned empty" output.append(f"✓ RAG search took: {search_time:.2f}s\n") output.append(f"✓ Retrieved {context.count('NCT')} trials\n\n") output.append("FIRST 1000 CHARS:\n") output.append(context[:1000]) return "".join(output) except Exception as e: import traceback return f"ERROR IN RAG TEST:\n{str(e)}\n\nTRACEBACK:\n{traceback.format_exc()}" def process_query(conversation): """ Complete pipeline with LLM query parsing and natural language generation Flow: 0. LLM Parser - Extract drugs, diseases, companies, endpoints (~2-3s) 1. RAG Search - Hybrid search using optimized query (~2s) 2. Skipped - 355M ranking removed (was broken) 3. 
LLM Response - Llama 70B generates natural language (~15s) Total: ~20 seconds """ import time import traceback import sys # MASTER try/except - catches EVERYTHING try: start_time = time.time() output_parts = [f"QUERY: {conversation}\n\n"] # Step 0: Parse query with LLM to extract structured info try: step0_start = time.time() logger.info("Step 0: Parsing query with LLM...") output_parts.append("✓ Step 0: LLM query parser started...\n") parsed_query = parse_query_with_llm(conversation, hf_token=hf_token) # Use optimized search terms from parser search_query = parsed_query['search_terms'] step0_time = time.time() - step0_start output_parts.append(f"✓ Step 0 Complete: Extracted entities ({step0_time:.1f}s)\n") output_parts.append(f" Drugs: {parsed_query['drugs']}\n") output_parts.append(f" Diseases: {parsed_query['diseases']}\n") output_parts.append(f" Companies: {parsed_query['companies']}\n") output_parts.append(f" Optimized search: {search_query}\n") logger.info(f"Query parsing successful in {step0_time:.1f}s") except Exception as e: error_msg = f"✗ Step 0 WARNING (LLM Parser): {str(e)}, using original query" logger.warning(error_msg) output_parts.append(f"{error_msg}\n") search_query = conversation # Fallback to original # Step 1: RAG search (using optimized search query) try: step1_start = time.time() logger.info("Step 1: RAG search...") output_parts.append("✓ Step 1: RAG search started...\n") context = retrieve_context_with_embeddings(search_query, top_k=3) if not context: return "No matching trials found in RAG search." 
# No limit - use complete trials step1_time = time.time() - step1_start output_parts.append(f"✓ Step 1 Complete: Found {context.count('NCT')} trials ({step1_time:.1f}s)\n") logger.info(f"RAG search successful - found trials in {step1_time:.1f}s") except Exception as e: error_msg = f"✗ Step 1 FAILED (RAG search): {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) return error_msg # Step 2: Skipped (355M ranking removed - was broken) output_parts.append("✓ Step 2: Skipped (using hybrid search + recency)\n") # Step 3: Llama 70B try: step3_start = time.time() logger.info("Step 3: Generating response with Llama-3.1-70B...") output_parts.append("✓ Step 3: Llama 70B generation started...\n") llama_response = generate_llama_response(conversation, context, hf_token=hf_token) step3_time = time.time() - step3_start output_parts.append(f"✓ Step 3 Complete: Llama 70B response generated ({step3_time:.1f}s)\n") logger.info(f"Llama 70B generation successful in {step3_time:.1f}s") except Exception as e: error_msg = f"✗ Step 3 FAILED (Llama 70B): {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) llama_response = f"[Llama 70B error: {str(e)}]" output_parts.append(f"✗ Step 3 Failed: {str(e)}\n") total_time = time.time() - start_time # Format output - handle missing variables try: context_display = context if 'context' in locals() else "[No context retrieved]" clinical_display = clinical_context_355m if 'clinical_context_355m' in locals() else "[355M not run]" llama_display = llama_response if 'llama_response' in locals() else "[Llama 70B not run]" output = f"""{''.join(output_parts)} CLINICAL SUMMARY (Llama-3.1-70B-Instruct): {llama_display} --- RAG RETRIEVED TRIALS (Top 3 Most Relevant): {context_display} --- Total Time: {total_time:.1f}s """ return output except Exception as e: # Absolute fallback error_info = f""" CRITICAL ERROR IN OUTPUT FORMATTING: {str(e)} TRACEBACK: {traceback.format_exc()} OUTPUT PARTS: {''.join(output_parts)} Variables defined: 
{locals().keys()} """ logger.error(error_info) return error_info # MASTER EXCEPTION HANDLER - catches ANY unhandled error except Exception as master_error: master_error_msg = f""" ======================================== MASTER ERROR HANDLER CAUGHT EXCEPTION ======================================== Error Type: {type(master_error).__name__} Error Message: {str(master_error)} FULL TRACEBACK: {traceback.format_exc()} System Info: - Python version: {sys.version} - Error at line: {sys.exc_info()[2].tb_lineno if sys.exc_info()[2] else 'unknown'} ======================================== """ logger.error(master_error_msg) return master_error_msg # ============================================================================ # GRADIO INTERFACE # ============================================================================ with gr.Blocks(title="Foundation 1.2") as demo: gr.Markdown("# Foundation 1.2 - Clinical Trial AI") query_input = gr.Textbox( label="Ask about clinical trials", placeholder="Example: What are the results for ianalumab in Sjogren's syndrome?", lines=3 ) submit_btn = gr.Button("Generate Response", variant="primary") output = gr.Textbox( label="AI Response", lines=30 ) submit_btn.click( fn=process_query, # Full pipeline: RAG + 355M + Llama inputs=query_input, outputs=output ) gr.Markdown(""" **Production RAG Pipeline - Optimized for Clinical Accuracy** **Search (3-Stage Hybrid):** 1. Keyword matching (70%) + Semantic search (30%) → 10 candidates 2. 355M Clinical Trial GPT re-ranks by relevance 3. 
Returns top 3 trials with best clinical relevance scores **Generation (Qwen2.5-14B-Instruct):** - 14B parameter model via HuggingFace Inference API - Structured clinical summaries with clear headings - Cites specific NCT trial IDs - Includes actual trial results and efficacy data - High-quality medical reasoning and analysis *355M model used for ranking (not generation) + Qwen2.5-14B for responses* """) # ============================================================================ # STARTUP # ============================================================================ # Embeddings will be loaded by FastAPI startup event in app.py # Do NOT load here - causes Docker permission errors logger.info("=== Foundation 1.2 Module Loaded ===") logger.info("Call load_embeddings() to initialize the system") if DATASET_FILE.exists(): file_size_mb = DATASET_FILE.stat().st_size / (1024 * 1024) logger.info(f"✓ Dataset file found: {file_size_mb:.0f}MB") else: logger.error("✗ Dataset file not found!") logger.info("=== Startup Complete ===") if __name__ == "__main__": demo.launch()