import os import subprocess import math import difflib import tempfile import numpy as np import pandas as pd from tqdm.auto import tqdm import torch import torch.nn as nn import torch.nn.functional as F from dataclasses import dataclass from sacremoses import MosesDetokenizer from flask import Flask, request, jsonify, render_template import traceback # Import traceback for better error logging import re # Import the regular expression module # --- Constants and Paths --- # Ensure these files are in the same directory as app.py or provide correct paths FINETUNED_MODEL_PATH = "hoc_best.pt" BPE_CODES_PATH = "bpecodes" DICT_TXT_PATH = "dict.txt" FASTBPE_BIN_PATH = "./fastbpe_exec" # Assumes fast executable is alongside app.py HALLMARKS = [ # Keep this consistent with training/evaluation "activating invasion and metastasis", "avoiding immune destruction", "cellular energetics", "enabling replicative immortality", "evading growth suppressors", "genomic instability and mutation", "inducing angiogenesis", "resisting cell death", "sustaining proliferative signaling", "tumor promoting inflammation", ] # --- Model Architecture Definitions (Copy from your notebook) --- # NOTE: Make sure these classes are IDENTICAL to the ones used for training # including GPTConfig, LayerNorm, CausalSelfAttention, MLP, Block, GPT, GPTWithSoftPrompt @dataclass class GPTConfig: block_size: int vocab_size: int n_layer: int n_head: int n_embd: int dropout: float = 0.0 bias: bool = True class LayerNorm(nn.Module): # (Copied from notebook) def __init__(self, ndim, bias): super().__init__() self.weight = nn.Parameter(torch.ones(ndim)) self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None def forward(self, x): return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5) class CausalSelfAttention(nn.Module): # (Copied from notebook) def __init__(self, config): super().__init__() assert config.n_embd % config.n_head == 0 self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) self.attn_dropout = nn.Dropout(config.dropout) self.resid_dropout = nn.Dropout(config.dropout) self.n_head = config.n_head self.n_embd = config.n_embd self.flash = hasattr(F, 'scaled_dot_product_attention') # Check for flash attention if not self.flash: # print("Warning: Flash Attention not available.") # Optional warning # Make the buffer persistent otherwise device mismatches during forward pass self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) .view(1, 1, config.block_size, config.block_size), persistent=True) #else: # print("Using Flash Attention.") # Optional info def forward(self, x): B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim q, k, v = self.c_attn(x).split(self.n_embd, dim=2) k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) if self.flash: # efficient attention using Flash Attention CUDA kernels y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.attn_dropout.p if self.training else 0.0, is_causal=True) else: # manual implementation of attention att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # Ensure bias buffer is used correctly # Check if bias buffer exists before using it if hasattr(self, 'bias'): att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) else: # Fallback if somehow bias wasn't registered (shouldn't happen with persistent=True) mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T) att = att.masked_fill(mask == 0, float('-inf')) att = F.softmax(att, dim=-1) att = self.attn_dropout(att) y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side # output projection y = self.resid_dropout(self.c_proj(y)) return y class MLP(nn.Module): # (Copied from notebook) def __init__(self, config): super().__init__() self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) self.gelu = nn.GELU() self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) self.dropout = nn.Dropout(config.dropout) def forward(self, x): x = self.c_fc(x) x = self.gelu(x) x = self.c_proj(x) x = self.dropout(x) return x class Block(nn.Module): # (Copied from notebook) def __init__(self, config): super().__init__() self.ln1 = LayerNorm(config.n_embd, bias=config.bias) self.attn = CausalSelfAttention(config) self.ln2 = LayerNorm(config.n_embd, bias=config.bias) self.mlp = MLP(config) def forward(self, x): x = x + self.attn(self.ln1(x)) x = x + self.mlp(self.ln2(x)) return x class GPT(nn.Module): # (Copied from notebook - simplified _init_weights and removed generate) def __init__(self, config): super().__init__() #assert config.vocab_size is not None #assert config.block_size is not None self.config = config self.transformer = nn.ModuleDict(dict( wte = nn.Embedding(config.vocab_size, config.n_embd), wpe = nn.Embedding(config.block_size, config.n_embd), drop = nn.Dropout(config.dropout), h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), ln_f = LayerNorm(config.n_embd, bias=config.bias), )) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying # init all weights self.apply(self._init_weights) # apply special scaled init to the residual projections, per GPT-2 paper for pn, p in self.named_parameters(): if pn.endswith('c_proj.weight'): torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) def _init_weights(self, module): if isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward(self, idx, targets=None): device = idx.device b, t = idx.size() if t > self.config.block_size: # Crop sequence if longer than block size print(f"Warning: Input sequence length ({t}) > block size ({self.config.block_size}). Cropping.") idx = idx[:, -self.config.block_size:] t = self.config.block_size #assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) # forward the GPT model itself tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) x = self.transformer.drop(tok_emb + pos_emb) for block in self.transformer.h: x = block(x) x = self.transformer.ln_f(x) if targets is not None: # if we are given some desired targets also calculate the loss logits = self.lm_head(x) loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) return logits, loss else: # inference-time mini-optimization: only forward the lm_head on the very last position logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim # Check for NaN/Inf in logits before returning if torch.isnan(logits).any() or torch.isinf(logits).any(): print("WARNING: NaN or Inf detected in logits during inference.") # Handle appropriately - maybe return an error indicator or zero logits? # For now, just print warning. return logits, None class GPTWithSoftPrompt(nn.Module): # (Copied from notebook - simplified) def __init__(self, base_gpt: GPT, prompt_len=1): super().__init__() self.config = base_gpt.config self.transformer = base_gpt.transformer self.lm_head = base_gpt.lm_head C = self.config.n_embd self.soft_prompt = nn.Parameter(torch.zeros(1, prompt_len, C)) # Keep on CPU first nn.init.normal_(self.soft_prompt, mean=0.0, std=0.02) def forward(self, idx, targets=None): B, T = idx.shape device = idx.device # Get device from input tensor # Make sure soft_prompt is on the same device as input soft_prompt_on_device = self.soft_prompt.to(device) # token + pos tok_emb = self.transformer.wte(idx) # (B,T,C) pos = torch.arange(0, T, dtype=torch.long, device=device) pos_emb = self.transformer.wpe(pos) # (T,C) x_tokens = tok_emb + pos_emb # prepend soft prompt soft = soft_prompt_on_device.expand(B, -1, -1) # (B,P,C) # --- FIX: Define P before the if/else block --- P = soft.size(1) # Get soft prompt length x = torch.cat([soft, x_tokens], dim=1) # (B,P+T,C) # --- Standard Transformer forward pass --- x = self.transformer.drop(x) for block in self.transformer.h: x = block(x) x = self.transformer.ln_f(x) logits = self.lm_head(x) # (B,P+T,V) # --- End Standard --- if targets is None: # Inference: return logits for the last token of the *original* sequence # We need the prediction *after* the last input token, which is at index T (P+T-1 overall) # Use P which is now defined # Ensure index is within bounds target_logit_index = P + T - 1 if target_logit_index >= logits.size(1): print(f"Warning: Calculated logit index {target_logit_index} out of bounds for logits shape {logits.shape}. Returning last logit.") target_logit_index = -1 # Fallback to last logit final_logits = logits[:, target_logit_index, :] # Check for NaN/Inf if torch.isnan(final_logits).any() or torch.isinf(final_logits).any(): print(f"WARNING: NaN or Inf detected in final_logits at index {target_logit_index}.") # Handle appropriately - maybe return zeros or raise an error? # For now, just print warning. Let the calling function handle it. return final_logits, None # Return (B, V) else: # Training loss calculation (copied from notebook) # P is already defined above pad_ignore = torch.full((B, P), -1, dtype=targets.dtype, device=device) full_targets = torch.cat([pad_ignore, targets], dim=1) logits_lm = logits[:, :-1, :].contiguous() targets_lm = full_targets[:, 1:].contiguous() loss = F.cross_entropy( logits_lm.view(-1, logits_lm.size(-1)), targets_lm.view(-1), ignore_index=-1 ) # Check for NaN/Inf in loss if torch.isnan(loss) or torch.isinf(loss): print("WARNING: NaN or Inf detected in loss calculation.") # Potentially add debugging info here (e.g., print shapes, inputs) return logits, loss # --- Constrained generation method (from Section 2.9) --- @torch.no_grad() def generate_labels(self, idx, allowed_mask, max_new_tokens=24, temperature=0.0): self.eval() # Ensure model is in eval mode B = idx.size(0) # Add soft prompt length to effective block size consideration P = self.soft_prompt.size(1) # Correct effective block size based on GPT class logic effective_block_size = self.config.block_size # GPT forward handles cropping # Start with input index out = idx.clone() # Clone to avoid modifying original input # Ensure allowed_mask is on the correct device allowed_mask = allowed_mask.to(idx.device) finished = torch.zeros(B, dtype=torch.bool, device=idx.device) # Get global eos_id safely global eos_id current_eos_id = eos_id # Use the globally loaded eos_id for step in range(max_new_tokens): # Crop context if it exceeds model's block size (GPT forward handles this internally now) # ctx = out if out.size(1) <= effective_block_size else out[:, -effective_block_size:] ctx = out # Pass the current sequence # Forward pass - expects shape (B, T), model handles soft prompt internally # It returns logits for the *next* token prediction after the last token in ctx logits, _ = self(ctx) # Gets logits for last token prediction, shape (B, V) # Check for NaN/Inf in logits if torch.isnan(logits).any() or torch.isinf(logits).any(): print(f"WARNING: NaN or Inf detected in logits during generation step {step}. Stopping generation.") # Return what we have so far, excluding potentially bad last token return out[:, idx.size(1):] # Or handle error differently # Apply constraint mask # Ensure mask shape matches logits shape if logits.shape != allowed_mask.shape: print(f"Warning: Logits shape {logits.shape} doesn't match mask shape {allowed_mask.shape}. Reshaping mask.") # This assumes mask needs batch dim added current_mask = allowed_mask.unsqueeze(0).expand_as(logits) else: current_mask = allowed_mask logits = logits + current_mask # Sample next token if temperature <= 0: # Greedy decoding next_id = torch.argmax(logits, dim=-1) # (B,) else: # Temperature sampling probs = F.softmax(logits / temperature, dim=-1) # Check for NaN/Inf in probs if torch.isnan(probs).any() or torch.isinf(probs).any(): print(f"WARNING: NaN or Inf detected in probabilities during generation step {step}. Using argmax fallback.") next_id = torch.argmax(logits, dim=-1) # Fallback to greedy else: try: next_id = torch.multinomial(probs, num_samples=1).squeeze(1) # (B,) except RuntimeError as e: print(f"WARNING: torch.multinomial failed: {e}. Using argmax fallback.") next_id = torch.argmax(logits, dim=-1) # Fallback to greedy # Handle finished sequences (force EOS) and update output # Check if current_eos_id is valid if not isinstance(current_eos_id, int): print(f"Warning: Global eos_id is not an integer ({current_eos_id}). Defaulting to 0.") current_eos_id = 0 next_id = next_id.masked_fill(finished, current_eos_id) # Use the validated eos_id # Check if next_id contains invalid values (e.g., negative) if (next_id < 0).any(): print(f"WARNING: Negative token ID generated: {next_id}. Clipping to 0.") next_id = torch.clamp(next_id, min=0) # Append the next token ID out = torch.cat([out, next_id.unsqueeze(1)], dim=1) # Update finished status finished |= (next_id == current_eos_id) # Stop if all sequences in the batch are finished if bool(finished.all()): # print(f"Generation finished early at step {step+1}") # Optional debug info break # else: # print(f"Generation reached max_new_tokens ({max_new_tokens})") # Optional debug info # Return only the generated part (excluding the initial idx length) return out[:, idx.size(1):] # --- Tokenizer Helper Functions --- # Added robustness and error checks # Global tokenizer maps and special IDs, loaded once at startup token2id, id2token = {}, {} eos_id = 0 # Default, will be overwritten pad_id = 0 # Default, will be overwritten detokenizer = None def load_tokenizer_data(dict_path): global token2id, id2token, eos_id, pad_id, detokenizer print(f"Loading vocabulary from {dict_path}...") local_token2id, local_id2token = {}, {} try: with open(dict_path, encoding="utf-8") as f: for i, line in enumerate(f): parts = line.split() # Split by whitespace if not parts: continue # Skip empty lines tok = parts[0] if tok in local_token2id: print(f"Warning: Duplicate token '{tok}' found at line {i+1}. Keeping first occurrence.") continue local_token2id[tok] = i local_id2token[i] = tok # Assign to global variables only after successful loading token2id = local_token2id id2token = local_id2token # Use a known special token ID if is missing, otherwise default might be wrong # Try multiple common EOS tokens possible_eos = ["", "<|endoftext|>", "[EOS]"] found_eos = False for eos_tok in possible_eos: if eos_tok in token2id: eos_id = token2id[eos_tok] found_eos = True print(f"Found EOS token '{eos_tok}' with ID: {eos_id}") break if not found_eos: # If no common EOS found, fall back to the highest index or 0 eos_id = max(token2id.values()) if token2id else 0 print(f"Warning: Standard EOS tokens not found. Using highest index ({eos_id}) as EOS ID.") # Assign pad_id, often same as eos_id or a specific token pad_id = token2id.get("", eos_id) # Prefer if exists, else use eos_id print(f"Using PAD ID: {pad_id}") detokenizer = MosesDetokenizer(lang='en') # Initialize once print(f"Vocabulary loaded. Size: {len(token2id)}") if not detokenizer: raise ValueError("MosesDetokenizer failed to initialize.") except FileNotFoundError: print(f"ERROR: Vocabulary file not found at {dict_path}") raise except Exception as e: print(f"ERROR: Failed to load tokenizer data from {dict_path}: {e}") raise def bpe_encode_lines(lines, shard_size=500, desc="BPE Encode"): """ Encodes lines using external fastBPE binary. Added error checking. """ global BPE_CODES_PATH, FASTBPE_BIN_PATH # --- Input Validation --- if not isinstance(lines, list): print(f"Warning: bpe_encode_lines expected a list, got {type(lines)}. Attempting conversion.") try: lines = list(lines) except TypeError: raise ValueError("Input 'lines' must be a list or convertible to a list.") if not lines: return [] # --- Path and Executable Checks --- abs_fastbpe_path = os.path.abspath(FASTBPE_BIN_PATH) abs_bpe_codes_path = os.path.abspath(BPE_CODES_PATH) if not os.path.exists(abs_fastbpe_path): raise FileNotFoundError(f"fastBPE executable not found at {abs_fastbpe_path}") if not os.path.exists(abs_bpe_codes_path): raise FileNotFoundError(f"BPE codes file not found at {abs_bpe_codes_path}") if not os.access(abs_fastbpe_path, os.X_OK): print(f"Warning: fastBPE binary at {abs_fastbpe_path} is not executable. Attempting chmod...") try: os.chmod(abs_fastbpe_path, 0o755) except OSError as e: raise PermissionError(f"Failed to make fastBPE executable: {e}. Please check permissions.") out_tokens = [] # Process in chunks with tempfile.TemporaryDirectory() as td: for start in range(0, len(lines), shard_size): chunk = lines[start:start+shard_size] src_path = os.path.join(td, f"src_{start}.txt") dst_path = os.path.join(td, f"dst_{start}.bpe") try: # Write chunk to temp file, ensuring strings with open(src_path, "w", encoding="utf-8") as f: for s in chunk: f.write(str(s or "").strip() + "\n") # Ensure string conversion # Run fastBPE cmd = [abs_fastbpe_path, "applybpe", dst_path, src_path, abs_bpe_codes_path] # print(f"Running command: {' '.join(cmd)}") # Debug command process = subprocess.run( cmd, capture_output=True, text=True, check=False # Don't check=True here, handle error below ) # Check for errors specifically if process.returncode != 0: # Log more details on failure print(f"ERROR: fastBPE failed (exit code {process.returncode}) on chunk starting at index {start}.") print(f"Command: {' '.join(cmd)}") print(f"Stderr:\n{process.stderr}") # Optionally print some input data print(f"First line of input chunk: {chunk[0] if chunk else 'N/A'}") raise subprocess.CalledProcessError(process.returncode, cmd, output=process.stdout, stderr=process.stderr) # Read results if successful with open(dst_path, "r", encoding="utf-8") as f: for line in f: out_tokens.append(line.strip().split()) except subprocess.CalledProcessError as e: # Handle specific subprocess errors (already printed details) raise # Re-raise to stop execution except Exception as e: print(f"ERROR: Unexpected error during BPE encoding chunk starting at index {start}: {e}") traceback.print_exc() # Print full traceback for unexpected errors raise # Re-raise return out_tokens def tokens_to_ids(bpe_tokens): """ Converts BPE token strings to IDs using the global map. Added checks. """ global token2id, pad_id if not isinstance(bpe_tokens, list): raise ValueError(f"Input 'bpe_tokens' must be a list, got {type(bpe_tokens)}.") ids = [] oov_count = 0 for t in bpe_tokens: if not isinstance(t, str): print(f"Warning: Non-string token found in bpe_tokens: {t}. Using pad_id.") ids.append(pad_id) oov_count += 1 continue id_val = token2id.get(t, pad_id) ids.append(id_val) if id_val == pad_id and t not in token2id: oov_count += 1 # print(f"Warning: OOV token '{t}' mapped to pad_id {pad_id}") # Reduce noise if oov_count > 0: print(f"Info: Found {oov_count} OOV tokens in sequence of length {len(bpe_tokens)}.") return ids, oov_count def ids_to_tokens(ids): """ Converts IDs back to token strings. Added checks. """ global id2token if not isinstance(ids, list): raise ValueError(f"Input 'ids' must be a list, got {type(ids)}.") tokens = [] for i in ids: # Ensure ID is a valid integer before lookup try: # Handle potential floats or NaNs from generation if isinstance(i, float) and math.isnan(i): token = "" else: int_i = int(i) token = id2token.get(int_i, "") except (ValueError, TypeError): print(f"Warning: Could not convert ID '{i}' to int. Using ''.") token = "" tokens.append(token) return tokens def bpe_decode_tokens(bpe_tokens): """ Converts BPE token strings back to readable text. Added checks. """ global detokenizer if detokenizer is None: raise RuntimeError("Detokenizer not initialized. Call load_tokenizer_data first.") if not isinstance(bpe_tokens, list): raise ValueError(f"Input 'bpe_tokens' must be a list, got {type(bpe_tokens)}.") # Ensure all items are strings before joining try: str_tokens = [str(t) for t in bpe_tokens] except Exception as e: print(f"Error converting tokens to strings: {e}. Tokens: {bpe_tokens}") return "" s = ' '.join(str_tokens).replace('@@ ', '') try: # Detokenizer might fail on empty or unusual input return detokenizer.detokenize(s.split()) if s.strip() else "" except Exception as e: print(f"Error during detokenization: {e}. Input string: '{s}'") return "" # --- Prediction Helper Functions --- def to_canonical(pred_chunk: str): """ Maps a predicted text chunk to a canonical hallmark name. Added checks. """ global HALLMARKS # Ensure input is a string if not isinstance(pred_chunk, str): # print(f"Warning: to_canonical received non-string input: {pred_chunk}. Returning None.") return None s = pred_chunk.strip().lower() low = [L.lower() for L in HALLMARKS] if not s: return None if s in low: return HALLMARKS[low.index(s)] # Use difflib for fuzzy matching try: best = difflib.get_close_matches(s, low, n=1, cutoff=0.7) return HALLMARKS[low.index(best[0])] if best else None except Exception as e: print(f"Error during difflib matching for '{s}': {e}") return None # Return None on error def build_allowed_token_mask(vocab_size, device): """ Builds the mask for constrained decoding. Added error checks. """ global HALLMARKS, token2id, eos_id, pad_id allowed = set() # --- Input Validation --- if vocab_size <= 0: raise ValueError("Vocabulary size must be positive.") if not token2id: raise RuntimeError("Tokenizer vocabulary (token2id) not loaded.") print("Encoding hallmarks for mask...") try: # Ensure HALLMARKS is a list of strings if not isinstance(HALLMARKS, list) or not all(isinstance(h, str) for h in HALLMARKS): raise ValueError("HALLMARKS must be a list of strings.") hallmark_bpes = bpe_encode_lines(HALLMARKS, desc="BPE Hallmarks (for mask)") for bpe_list in hallmark_bpes: ids, _ = tokens_to_ids(bpe_list) allowed.update(ids) print(f"Encoded {len(HALLMARKS)} hallmarks.") except Exception as e: print(f"ERROR: Failed to BPE encode or convert hallmarks for mask: {e}") raise print("Encoding separators for mask...") SEPS = [", ", ",", "; ", ";", "|"] try: sep_bpes = bpe_encode_lines(SEPS, desc="BPE Separators (for mask)") for bpe_list in sep_bpes: ids, _ = tokens_to_ids(bpe_list) allowed.update(ids) print(f"Encoded {len(SEPS)} separators.") except Exception as e: print(f"ERROR: Failed to BPE encode or convert separators for mask: {e}") raise # Add EOS token - Check if eos_id is valid if not isinstance(eos_id, int) or eos_id < 0 or eos_id >= vocab_size: print(f"Warning: Invalid EOS ID ({eos_id}). Defaulting mask EOS to 0.") effective_eos_id = 0 else: effective_eos_id = eos_id allowed.add(effective_eos_id) print(f"Total allowed token IDs (including EOS {effective_eos_id}): {len(allowed)}") # Create the mask tensor on CPU first mask = torch.full((vocab_size,), float('-inf'), device=torch.device('cpu')) try: # Filter out potential invalid IDs before creating list for indexing # Ensure pad_id is valid if used for filtering effective_pad_id = pad_id if isinstance(pad_id, int) and 0 <= pad_id < vocab_size else -1 # Use -1 if pad_id is invalid valid_allowed_ids = [] for id_ in allowed: if isinstance(id_, int) and 0 <= id_ < vocab_size: # Check type and range # Filter out pad_id unless it's the same as the effective_eos_id if id_ != effective_pad_id or id_ == effective_eos_id: valid_allowed_ids.append(id_) # else: print(f"Warning: Invalid ID {id_} in allowed set skipped.") # Reduce noise if not valid_allowed_ids: raise ValueError("No valid token IDs found to allow in the mask.") # Check ranges again after filtering (belt and braces) max_valid_id = max(valid_allowed_ids) min_valid_id = min(valid_allowed_ids) if max_valid_id >= vocab_size or min_valid_id < 0: # This should ideally not happen if filtering worked raise IndexError(f"Filtered allowed IDs still out of range [{min_valid_id}, {max_valid_id}] for vocab size {vocab_size}.") # Apply mask mask[valid_allowed_ids] = 0.0 # Use list directly print(f"Mask created with {len(valid_allowed_ids)} allowed indices.") except IndexError as e: print(f"ERROR: Index error while creating mask. Vocab size: {vocab_size}. Error: {e}") # Find problematic IDs more carefully problem_ids = [i for i in allowed if not isinstance(i, int) or i < 0 or i >= vocab_size] print(f"Problematic IDs in allowed set: {problem_ids}") raise except Exception as e: print(f"ERROR: Unexpected error creating mask: {e}") traceback.print_exc() raise # Move final mask to target device try: target_device = torch.device(device) # Ensure device is a torch.device object return mask.to(target_device) except Exception as e: print(f"Error moving mask to device '{device}': {e}") raise # --- Global Variables for Loaded Model and Assets --- inference_model = None ALLOWED_MASK = None model_device = "cpu" config = None # Added global config # --- Initialization Function --- def initialize_model_and_tokenizer(): global inference_model, ALLOWED_MASK, model_device, token2id, config # Add config print("Initializing model...") # Determine device model_device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {model_device}") # Load tokenizer data first (essential for vocab size) try: load_tokenizer_data(DICT_TXT_PATH) if not token2id: # Check if loading actually populated the dict raise ValueError("Tokenizer loading failed to populate token2id dictionary.") except Exception as e: print(f"FATAL: Could not load tokenizer data. Cannot proceed. Error: {e}") return False # Indicate failure # Define model config (MUST match finetuning config) try: # Ensure config is globally accessible after definition config = GPTConfig( vocab_size=len(token2id), # Get vocab size from loaded data block_size=128, # Match training n_layer=6, # Match training n_head=6, # Match training n_embd=384, # Match training dropout=0.1, # Match training (dropout is off in eval mode) bias=True # Match training ) print(f"Model Config: {config}") except Exception as e: print(f"FATAL: Error creating GPTConfig: {e}") return False # Instantiate base and wrapped model (on CPU initially) try: base_gpt = GPT(config) inference_model = GPTWithSoftPrompt(base_gpt, prompt_len=1) except Exception as e: print(f"FATAL: Error instantiating model: {e}") traceback.print_exc() return False # Load finetuned weights print(f"Loading finetuned weights from: {FINETUNED_MODEL_PATH}") if not os.path.exists(FINETUNED_MODEL_PATH): print(f"ERROR: Model weights file not found at {FINETUNED_MODEL_PATH}") return False try: # Load state dict onto CPU first state_dict = torch.load(FINETUNED_MODEL_PATH, map_location='cpu') # Clean state dict keys (handle DDP 'module.' prefix) cleaned_state_dict = {} for k, v in state_dict.items(): name = k[7:] if k.startswith('module.') else k cleaned_state_dict[name] = v # Load into model missing_keys, unexpected_keys = inference_model.load_state_dict(cleaned_state_dict, strict=False) if missing_keys: # Filter out non-persistent buffer keys if necessary (though strict=False should handle this) missing_persistent = [k for k in missing_keys if inference_model.get_parameter(k) is not None or inference_model.get_buffer(k) is not None] if missing_persistent: print("Warning: Missing persistent keys during state dict load:", missing_persistent) if unexpected_keys: print("Warning: Unexpected keys during state dict load:", unexpected_keys) print("Weights loaded successfully.") except Exception as e: print(f"Error loading state dict from {FINETUNED_MODEL_PATH}: {e}") print("Ensure the model architecture matches the saved checkpoint and the file is not corrupted.") traceback.print_exc() return False # Move model to target device and set to eval mode try: inference_model.to(model_device) inference_model.eval() print(f"Model moved to device: {model_device} and set to eval mode.") except Exception as e: print(f"Error moving model to device '{model_device}': {e}") traceback.print_exc() return False # Build the allowed token mask (after model is on device) print("Building allowed token mask...") try: if config.vocab_size <= 0: raise ValueError("Vocabulary size must be positive to build mask.") # Ensure model_device is valid before passing device_obj = torch.device(model_device) ALLOWED_MASK = build_allowed_token_mask(config.vocab_size, device_obj) print("Allowed token mask created.") except Exception as e: print(f"ERROR: Failed to build allowed token mask: {e}") traceback.print_exc() return False return True # Indicate success # --- Inference Function --- def predict_hallmarks(abstract_text): global inference_model, ALLOWED_MASK, model_device, token2id, eos_id # --- Pre-computation Checks --- if inference_model is None: print("Error: Inference model is not loaded.") return ["Error: Model not loaded"] if ALLOWED_MASK is None: print("Error: Allowed mask is not built.") return ["Error: Mask not built"] if not token2id: print("Error: Tokenizer vocabulary not loaded.") return ["Error: Tokenizer not loaded"] # --- Input Validation --- if not isinstance(abstract_text, str): print(f"Warning: Received non-string abstract text type: {type(abstract_text)}. Attempting conversion.") try: abstract_text = str(abstract_text) except Exception: return ["Error: Invalid input type"] if not abstract_text.strip(): print("Warning: Received empty or whitespace-only abstract text.") return [] # Return empty list for empty input try: # --- 1. Preprocess and Tokenize Input --- print("Tokenizing input abstract...") cleaned_abstract = " ".join(abstract_text.split()) if not cleaned_abstract: print("Warning: Input abstract contains only whitespace after cleaning.") return [] bpe_tokens_list = bpe_encode_lines([cleaned_abstract]) if not bpe_tokens_list or not bpe_tokens_list[0]: # Check if list or first element is empty print("Warning: BPE encoding resulted in empty tokens.") return [] bpe_tokens = bpe_tokens_list[0] input_ids_list, oov = tokens_to_ids(bpe_tokens) if oov > 0: print(f"Info: Input contained {oov} OOV tokens.") # Add EOS token input_ids = input_ids_list + [eos_id] # Convert to tensor and move to device input_tensor = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(model_device) # --- 2. Generate Predictions --- print("Generating predictions...") with torch.no_grad(): generated_ids_tensor = inference_model.generate_labels( input_tensor, allowed_mask=ALLOWED_MASK, max_new_tokens=30, temperature=0.0 ) # --- 3. Decode and Post-process --- print("Decoding and cleaning predictions...") if generated_ids_tensor is None or generated_ids_tensor.numel() == 0: print("Warning: Generation resulted in empty tensor.") generated_ids = [] else: # Ensure tensor is on CPU before converting to list generated_ids = generated_ids_tensor[0].cpu().tolist() if not generated_ids: print("No tokens generated.") return [] generated_tokens = ids_to_tokens(generated_ids) # Remove tokens after EOS if present try: eos_token_str = id2token.get(eos_id, "") # Get string representation if eos_token_str in generated_tokens: eos_index = generated_tokens.index(eos_token_str) generated_tokens = generated_tokens[:eos_index] except ValueError: pass # EOS not found is okay # Decode BPE tokens to string generated_text = bpe_decode_tokens(generated_tokens).strip().lower() print(f"Raw generated text: '{generated_text}'") # Split potential multi-labels and map to canonical parts = [] if generated_text: potential_parts = re.split(r'[;,|]\s*', generated_text) parts = [p.strip() for p in potential_parts if p.strip()] if not parts: # Handle case with no delimiters parts = [generated_text] predicted_labels = [] seen_labels = set() for p in parts: canonical_label = to_canonical(p) if canonical_label and canonical_label not in seen_labels: predicted_labels.append(canonical_label) seen_labels.add(canonical_label) print(f"Final predicted labels: {predicted_labels}") return predicted_labels # --- Error Handling --- except FileNotFoundError as fnf_err: print(f"ERROR during prediction (File Not Found - likely BPE related): {fnf_err}") traceback.print_exc() return ["Error: BPE file processing error"] except PermissionError as perm_err: print(f"ERROR during prediction (Permission Error - likely fastBPE): {perm_err}") traceback.print_exc() return ["Error: BPE execution permission"] except RuntimeError as run_err: if "CUDA out of memory" in str(run_err): print(f"ERROR: CUDA Out of Memory during prediction. Input length: {len(input_ids) if 'input_ids' in locals() else 'N/A'}") traceback.print_exc() return ["Error: Input too long (OOM)"] else: print(f"ERROR during prediction (PyTorch RuntimeError): {run_err}") traceback.print_exc() return ["Error: Model runtime error"] except Exception as e: print(f"ERROR during prediction (General Exception): {e}") traceback.print_exc() return [f"Error: An unexpected error occurred"] # --- Flask App --- app = Flask(__name__) # --- Load Model on Startup --- model_initialized = False @app.before_request def ensure_model_loaded(): """ Ensures model is loaded before handling the first request. """ global model_initialized if not model_initialized: print("First request received, attempting to initialize model...") # Add basic locking if deploying with multiple workers (though not fully thread-safe here) # For true multi-worker safety, model loading should happen before workers fork. try: if initialize_model_and_tokenizer(): model_initialized = True print("Model initialization successful.") else: print("FATAL: Model initialization failed during first request.") # We won't raise an error here, but subsequent requests will fail until fixed. except Exception as init_err: print(f"FATAL: Exception during model initialization: {init_err}") traceback.print_exc() # --- Routes --- @app.route('/') def home(): """ Renders the HTML frontend page. """ # Check if initialization failed and show an error page if so? # For simplicity, we assume initialization works or subsequent predict calls fail. return render_template('index.html') @app.route('/predict', methods=['POST']) def predict(): """ Handles prediction requests from the frontend. """ global model_initialized # Check if model is ready if not model_initialized: print("Error: Model not initialized when /predict called.") # Return a specific status code like Service Unavailable return jsonify({'error': 'Model is not ready. Please try again later.'}), 503 # Validate request format if not request.is_json: return jsonify({'error': 'Request must be JSON'}), 400 data = request.get_json() abstract = data.get('abstract') # Validate input abstract if not abstract: return jsonify({'error': 'Missing "abstract" field in JSON request'}), 400 if not isinstance(abstract, str): return jsonify({'error': '"abstract" field must be a string'}), 400 if len(abstract.strip()) == 0: print("Received empty abstract, returning empty prediction.") return jsonify({'predictions': []}) MAX_ABSTRACT_LEN = 10000 # Define max length if len(abstract) > MAX_ABSTRACT_LEN: print(f"Received overly long abstract ({len(abstract)} chars), rejecting.") return jsonify({'error': f'Input abstract is too long (max {MAX_ABSTRACT_LEN} chars)'}), 413 # Payload Too Large print(f"\n--- Received Prediction Request ---") print(f"Input Abstract (first 100 chars): {abstract[:100]}...") try: # Perform prediction predictions = predict_hallmarks(abstract) print(f"--- Prediction Complete ---") # Check if the result indicates an internal error occurred if isinstance(predictions, list) and len(predictions) > 0 and predictions[0].startswith("Error:"): print(f"Internal error during prediction: {predictions[0]}") # Return a generic server error to the client return jsonify({'error': 'An internal error occurred during prediction.'}), 500 else: # Return successful predictions return jsonify({'predictions': predictions}) except Exception as e: # Catch unexpected errors in the route handler itself print(f"--- Prediction Failed Unexpectedly in Route ---") print(f"Error: {e}") traceback.print_exc() return jsonify({'error': 'An internal server error occurred.'}), 500 # --- Run the App --- if __name__ == '__main__': # Initialize model eagerly when running script directly if not model_initialized: print("Running script directly, initializing model eagerly...") if initialize_model_and_tokenizer(): model_initialized = True print("Model initialization successful.") else: print("FATAL: Model initialization failed. Cannot start Flask server.") exit(1) # Exit if model fails to load on startup # Check fastBPE path validity before starting server abs_fastbpe_path = os.path.abspath(FASTBPE_BIN_PATH) if not os.path.exists(abs_fastbpe_path): print(f"ERROR: fastBPE binary not found at '{abs_fastbpe_path}'.") print("Please ensure fastBPE is compiled and the path is correct relative to app.py.") exit(1) if not os.access(abs_fastbpe_path, os.X_OK): print(f"ERROR: fastBPE binary at '{abs_fastbpe_path}' is not executable.") print("Attempting to make it executable with 'chmod +x'...") try: os.chmod(abs_fastbpe_path, 0o755) print(f"Successfully made '{abs_fastbpe_path}' executable.") except OSError as e: print(f"ERROR: Failed to make fastBPE executable: {e}") print("Please set execute permissions manually (e.g., 'chmod +x ./fast').") exit(1) print("Starting Flask server...") # Use host='0.0.0.0' to make it accessible on your network # Set debug=False for production environments app.run(host='0.0.0.0', port=5000, debug=False) # Changed debug to False