import os
import subprocess
import math
import difflib
import tempfile
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from sacremoses import MosesDetokenizer
from flask import Flask, request, jsonify, render_template
import traceback  # Import traceback for better error logging
import re  # Import the regular expression module

# --- Constants and Paths ---
# Ensure these files are in the same directory as app.py or provide correct paths
FINETUNED_MODEL_PATH = "hoc_best.pt"
BPE_CODES_PATH = "bpecodes"
DICT_TXT_PATH = "dict.txt"
FASTBPE_BIN_PATH = "./fastbpe_exec"  # Assumes the fastBPE executable is alongside app.py
HALLMARKS = [  # Keep this consistent with training/evaluation
    "activating invasion and metastasis", "avoiding immune destruction",
    "cellular energetics", "enabling replicative immortality",
    "evading growth suppressors", "genomic instability and mutation",
    "inducing angiogenesis", "resisting cell death",
    "sustaining proliferative signaling", "tumor promoting inflammation",
]

# --- Model Architecture Definitions (Copy from your notebook) ---
# NOTE: Make sure these classes are IDENTICAL to the ones used for training,
# including GPTConfig, LayerNorm, CausalSelfAttention, MLP, Block, GPT, GPTWithSoftPrompt

@dataclass  # GPTConfig is constructed with keyword arguments below, so it must be a dataclass
class GPTConfig:
    block_size: int
    vocab_size: int
    n_layer: int
    n_head: int
    n_embd: int
    dropout: float = 0.0
    bias: bool = True

class LayerNorm(nn.Module):
    # (Copied from notebook)
    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, x):
        return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)

class CausalSelfAttention(nn.Module):
    # (Copied from notebook)
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.flash = hasattr(F, 'scaled_dot_product_attention')  # Check for flash attention
        if not self.flash:
            # print("Warning: Flash Attention not available.")  # Optional warning
            # Make the buffer persistent, otherwise device mismatches can occur during the forward pass
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                 .view(1, 1, config.block_size, config.block_size), persistent=True)
        # else:
        #     print("Using Flash Attention.")  # Optional info

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.attn_dropout.p if self.training else 0.0, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            # Ensure the bias buffer is used correctly; check that it exists before using it
            if hasattr(self, 'bias'):
                att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
            else:
                # Fallback if somehow bias wasn't registered (shouldn't happen with persistent=True)
                mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
                att = att.masked_fill(mask == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

class MLP(nn.Module):
    # (Copied from notebook)
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):
    # (Copied from notebook)
    def __init__(self, config):
        super().__init__()
        self.ln1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class GPT(nn.Module):
    # (Copied from notebook - simplified _init_weights and removed generate)
    def __init__(self, config):
        super().__init__()
        # assert config.vocab_size is not None
        # assert config.block_size is not None
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight  # https://paperswithcode.com/method/weight-tying
        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        if t > self.config.block_size:
            # Crop sequence if longer than block size
            print(f"Warning: Input sequence length ({t}) > block size ({self.config.block_size}). Cropping.")
            idx = idx[:, -self.config.block_size:]
            t = self.config.block_size
        # assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device)  # shape (t)
        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            return logits, loss
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
            # Check for NaN/Inf in logits before returning
            if torch.isnan(logits).any() or torch.isinf(logits).any():
                print("WARNING: NaN or Inf detected in logits during inference.")
                # Handle appropriately - maybe return an error indicator or zero logits?
                # For now, just print a warning.
            return logits, None

class GPTWithSoftPrompt(nn.Module):
    # (Copied from notebook - simplified)
    def __init__(self, base_gpt: GPT, prompt_len=1):
        super().__init__()
        self.config = base_gpt.config
        self.transformer = base_gpt.transformer
        self.lm_head = base_gpt.lm_head
        C = self.config.n_embd
        self.soft_prompt = nn.Parameter(torch.zeros(1, prompt_len, C))  # Keep on CPU first
        nn.init.normal_(self.soft_prompt, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        device = idx.device  # Get device from input tensor
        # Make sure soft_prompt is on the same device as input
        soft_prompt_on_device = self.soft_prompt.to(device)
        # token + pos
        tok_emb = self.transformer.wte(idx)  # (B,T,C)
        pos = torch.arange(0, T, dtype=torch.long, device=device)
        pos_emb = self.transformer.wpe(pos)  # (T,C)
        x_tokens = tok_emb + pos_emb
        # prepend soft prompt
        soft = soft_prompt_on_device.expand(B, -1, -1)  # (B,P,C)
        # --- FIX: Define P before the if/else block ---
        P = soft.size(1)  # Get soft prompt length
        x = torch.cat([soft, x_tokens], dim=1)  # (B,P+T,C)
        # --- Standard Transformer forward pass ---
        x = self.transformer.drop(x)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)  # (B,P+T,V)
        # --- End Standard ---
        if targets is None:
            # Inference: return logits for the last token of the *original* sequence.
            # We need the prediction *after* the last input token, which sits at index P+T-1
            # of the combined (soft prompt + tokens) sequence. Ensure the index is within bounds.
            target_logit_index = P + T - 1
            if target_logit_index >= logits.size(1):
                print(f"Warning: Calculated logit index {target_logit_index} out of bounds for logits shape {logits.shape}. Returning last logit.")
                target_logit_index = -1  # Fallback to last logit
            final_logits = logits[:, target_logit_index, :]
            # Check for NaN/Inf
            if torch.isnan(final_logits).any() or torch.isinf(final_logits).any():
                print(f"WARNING: NaN or Inf detected in final_logits at index {target_logit_index}.")
                # Handle appropriately - maybe return zeros or raise an error?
                # For now, just print a warning and let the calling function handle it.
            return final_logits, None  # Return (B, V)
        else:
            # Training loss calculation (copied from notebook)
            # P is already defined above
            pad_ignore = torch.full((B, P), -1, dtype=targets.dtype, device=device)
            full_targets = torch.cat([pad_ignore, targets], dim=1)
            logits_lm = logits[:, :-1, :].contiguous()
            targets_lm = full_targets[:, 1:].contiguous()
            loss = F.cross_entropy(
                logits_lm.view(-1, logits_lm.size(-1)),
                targets_lm.view(-1),
                ignore_index=-1
            )
            # Check for NaN/Inf in loss
            if torch.isnan(loss) or torch.isinf(loss):
                print("WARNING: NaN or Inf detected in loss calculation.")
                # Potentially add debugging info here (e.g., print shapes, inputs)
            return logits, loss

    # --- Constrained generation method (from Section 2.9) ---
    def generate_labels(self, idx, allowed_mask, max_new_tokens=24, temperature=0.0):
        self.eval()  # Ensure model is in eval mode
        B = idx.size(0)
        # Add soft prompt length to effective block size consideration
        P = self.soft_prompt.size(1)
        # Correct effective block size based on GPT class logic
        effective_block_size = self.config.block_size  # GPT forward handles cropping
        # Start with input index
        out = idx.clone()  # Clone to avoid modifying original input
        # Ensure allowed_mask is on the correct device
        allowed_mask = allowed_mask.to(idx.device)
        finished = torch.zeros(B, dtype=torch.bool, device=idx.device)
        # Get global eos_id safely
        global eos_id
        current_eos_id = eos_id  # Use the globally loaded eos_id
        for step in range(max_new_tokens):
            # Crop context if it exceeds model's block size (GPT forward handles this internally now)
            # ctx = out if out.size(1) <= effective_block_size else out[:, -effective_block_size:]
            ctx = out  # Pass the current sequence
            # Forward pass - expects shape (B, T), model handles soft prompt internally.
            # It returns logits for the *next* token prediction after the last token in ctx.
            logits, _ = self(ctx)  # Gets logits for last token prediction, shape (B, V)
            # Check for NaN/Inf in logits
            if torch.isnan(logits).any() or torch.isinf(logits).any():
                print(f"WARNING: NaN or Inf detected in logits during generation step {step}. Stopping generation.")
                # Return what we have so far, excluding potentially bad last token
                return out[:, idx.size(1):]  # Or handle error differently
            # Apply constraint mask; ensure mask shape matches logits shape
            if logits.shape != allowed_mask.shape:
                # This assumes the mask needs a batch dim added
                current_mask = allowed_mask.unsqueeze(0).expand_as(logits)
            else:
                current_mask = allowed_mask
            logits = logits + current_mask
            # Sample next token
            if temperature <= 0:
                # Greedy decoding
                next_id = torch.argmax(logits, dim=-1)  # (B,)
            else:
                # Temperature sampling
                probs = F.softmax(logits / temperature, dim=-1)
                # Check for NaN/Inf in probs
                if torch.isnan(probs).any() or torch.isinf(probs).any():
                    print(f"WARNING: NaN or Inf detected in probabilities during generation step {step}. Using argmax fallback.")
                    next_id = torch.argmax(logits, dim=-1)  # Fallback to greedy
                else:
                    try:
                        next_id = torch.multinomial(probs, num_samples=1).squeeze(1)  # (B,)
                    except RuntimeError as e:
                        print(f"WARNING: torch.multinomial failed: {e}. Using argmax fallback.")
                        next_id = torch.argmax(logits, dim=-1)  # Fallback to greedy
            # Handle finished sequences (force EOS) and update output.
            # Check that current_eos_id is valid.
            if not isinstance(current_eos_id, int):
                print(f"Warning: Global eos_id is not an integer ({current_eos_id}). Defaulting to 0.")
                current_eos_id = 0
            next_id = next_id.masked_fill(finished, current_eos_id)  # Use the validated eos_id
            # Check if next_id contains invalid values (e.g., negative)
            if (next_id < 0).any():
                print(f"WARNING: Negative token ID generated: {next_id}. Clipping to 0.")
                next_id = torch.clamp(next_id, min=0)
            # Append the next token ID
            out = torch.cat([out, next_id.unsqueeze(1)], dim=1)
            # Update finished status
            finished |= (next_id == current_eos_id)
            # Stop if all sequences in the batch are finished
            if bool(finished.all()):
                # print(f"Generation finished early at step {step+1}")  # Optional debug info
                break
        # else:
        #     print(f"Generation reached max_new_tokens ({max_new_tokens})")  # Optional debug info
        # Return only the generated part (excluding the initial idx length)
        return out[:, idx.size(1):]

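# Usage note (illustrative, mirrors the call made in predict_hallmarks() below):
#   gen_ids = model.generate_labels(input_tensor, allowed_mask=ALLOWED_MASK,
#                                   max_new_tokens=30, temperature=0.0)
# returns only the newly generated token IDs, shape (B, num_generated_tokens),
# which are then decoded and mapped to canonical hallmark names.
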
# --- Tokenizer Helper Functions ---
# Added robustness and error checks

# Global tokenizer maps and special IDs, loaded once at startup
token2id, id2token = {}, {}
eos_id = 0  # Default, will be overwritten
pad_id = 0  # Default, will be overwritten
detokenizer = None


def load_tokenizer_data(dict_path):
    global token2id, id2token, eos_id, pad_id, detokenizer
    print(f"Loading vocabulary from {dict_path}...")
    local_token2id, local_id2token = {}, {}
    try:
        with open(dict_path, encoding="utf-8") as f:
            for i, line in enumerate(f):
                parts = line.split()  # Split by whitespace
                if not parts: continue  # Skip empty lines
                tok = parts[0]
                if tok in local_token2id:
                    print(f"Warning: Duplicate token '{tok}' found at line {i+1}. Keeping first occurrence.")
                    continue
                local_token2id[tok] = i
                local_id2token[i] = tok
        # Assign to global variables only after successful loading
        token2id = local_token2id
        id2token = local_id2token
        # Use a known special token ID if </s> is missing, otherwise the default might be wrong.
        # Try multiple common EOS tokens.
        possible_eos = ["</s>", "<|endoftext|>", "[EOS]"]
        found_eos = False
        for eos_tok in possible_eos:
            if eos_tok in token2id:
                eos_id = token2id[eos_tok]
                found_eos = True
                print(f"Found EOS token '{eos_tok}' with ID: {eos_id}")
                break
        if not found_eos:
            # If no common EOS found, fall back to the highest index or 0
            eos_id = max(token2id.values()) if token2id else 0
            print(f"Warning: Standard EOS tokens not found. Using highest index ({eos_id}) as EOS ID.")
        # Assign pad_id, often the same as eos_id or a specific <pad> token
        pad_id = token2id.get("<pad>", eos_id)  # Prefer <pad> if it exists, else use eos_id
        print(f"Using PAD ID: {pad_id}")
        detokenizer = MosesDetokenizer(lang='en')  # Initialize once
        print(f"Vocabulary loaded. Size: {len(token2id)}")
        if not detokenizer:
            raise ValueError("MosesDetokenizer failed to initialize.")
    except FileNotFoundError:
        print(f"ERROR: Vocabulary file not found at {dict_path}")
        raise
    except Exception as e:
        print(f"ERROR: Failed to load tokenizer data from {dict_path}: {e}")
        raise

def bpe_encode_lines(lines, shard_size=500, desc="BPE Encode"):
    """ Encodes lines using the external fastBPE binary. Added error checking. """
    global BPE_CODES_PATH, FASTBPE_BIN_PATH
    # --- Input Validation ---
    if not isinstance(lines, list):
        print(f"Warning: bpe_encode_lines expected a list, got {type(lines)}. Attempting conversion.")
        try:
            lines = list(lines)
        except TypeError:
            raise ValueError("Input 'lines' must be a list or convertible to a list.")
    if not lines: return []
    # --- Path and Executable Checks ---
    abs_fastbpe_path = os.path.abspath(FASTBPE_BIN_PATH)
    abs_bpe_codes_path = os.path.abspath(BPE_CODES_PATH)
    if not os.path.exists(abs_fastbpe_path):
        raise FileNotFoundError(f"fastBPE executable not found at {abs_fastbpe_path}")
    if not os.path.exists(abs_bpe_codes_path):
        raise FileNotFoundError(f"BPE codes file not found at {abs_bpe_codes_path}")
    if not os.access(abs_fastbpe_path, os.X_OK):
        print(f"Warning: fastBPE binary at {abs_fastbpe_path} is not executable. Attempting chmod...")
        try:
            os.chmod(abs_fastbpe_path, 0o755)
        except OSError as e:
            raise PermissionError(f"Failed to make fastBPE executable: {e}. Please check permissions.")
    out_tokens = []
    # Process in chunks
    with tempfile.TemporaryDirectory() as td:
        for start in range(0, len(lines), shard_size):
            chunk = lines[start:start+shard_size]
            src_path = os.path.join(td, f"src_{start}.txt")
            dst_path = os.path.join(td, f"dst_{start}.bpe")
            try:
                # Write chunk to temp file, ensuring strings
                with open(src_path, "w", encoding="utf-8") as f:
                    for s in chunk:
                        f.write(str(s or "").strip() + "\n")  # Ensure string conversion
                # Run fastBPE
                cmd = [abs_fastbpe_path, "applybpe", dst_path, src_path, abs_bpe_codes_path]
                # print(f"Running command: {' '.join(cmd)}")  # Debug command
                process = subprocess.run(
                    cmd,
                    capture_output=True, text=True, check=False  # Don't use check=True here; handle errors below
                )
                # Check for errors specifically
                if process.returncode != 0:
                    # Log more details on failure
                    print(f"ERROR: fastBPE failed (exit code {process.returncode}) on chunk starting at index {start}.")
                    print(f"Command: {' '.join(cmd)}")
                    print(f"Stderr:\n{process.stderr}")
                    # Optionally print some input data
                    print(f"First line of input chunk: {chunk[0] if chunk else 'N/A'}")
                    raise subprocess.CalledProcessError(process.returncode, cmd, output=process.stdout, stderr=process.stderr)
                # Read results if successful
                with open(dst_path, "r", encoding="utf-8") as f:
                    for line in f:
                        out_tokens.append(line.strip().split())
            except subprocess.CalledProcessError:
                # Handle specific subprocess errors (details already printed above)
                raise  # Re-raise to stop execution
            except Exception as e:
                print(f"ERROR: Unexpected error during BPE encoding chunk starting at index {start}: {e}")
                traceback.print_exc()  # Print full traceback for unexpected errors
                raise  # Re-raise
    return out_tokens

def tokens_to_ids(bpe_tokens):
    """ Converts BPE token strings to IDs using the global map. Added checks. """
    global token2id, pad_id
    if not isinstance(bpe_tokens, list):
        raise ValueError(f"Input 'bpe_tokens' must be a list, got {type(bpe_tokens)}.")
    ids = []
    oov_count = 0
    for t in bpe_tokens:
        if not isinstance(t, str):
            print(f"Warning: Non-string token found in bpe_tokens: {t}. Using pad_id.")
            ids.append(pad_id)
            oov_count += 1
            continue
        id_val = token2id.get(t, pad_id)
        ids.append(id_val)
        if id_val == pad_id and t not in token2id:
            oov_count += 1
            # print(f"Warning: OOV token '{t}' mapped to pad_id {pad_id}")  # Reduce noise
    if oov_count > 0:
        print(f"Info: Found {oov_count} OOV tokens in sequence of length {len(bpe_tokens)}.")
    return ids, oov_count

def ids_to_tokens(ids):
    """ Converts IDs back to token strings. Added checks. """
    global id2token
    if not isinstance(ids, list):
        raise ValueError(f"Input 'ids' must be a list, got {type(ids)}.")
    tokens = []
    for i in ids:
        # Ensure ID is a valid integer before lookup
        try:
            # Handle potential floats or NaNs from generation
            if isinstance(i, float) and math.isnan(i):
                token = "<nan>"
            else:
                int_i = int(i)
                token = id2token.get(int_i, "<unk>")
        except (ValueError, TypeError):
            print(f"Warning: Could not convert ID '{i}' to int. Using '<unk>'.")
            token = "<unk>"
        tokens.append(token)
    return tokens

def bpe_decode_tokens(bpe_tokens):
    """ Converts BPE token strings back to readable text. Added checks. """
    global detokenizer
    if detokenizer is None:
        raise RuntimeError("Detokenizer not initialized. Call load_tokenizer_data first.")
    if not isinstance(bpe_tokens, list):
        raise ValueError(f"Input 'bpe_tokens' must be a list, got {type(bpe_tokens)}.")
    # Ensure all items are strings before joining
    try:
        str_tokens = [str(t) for t in bpe_tokens]
    except Exception as e:
        print(f"Error converting tokens to strings: {e}. Tokens: {bpe_tokens}")
        return "<decoding error>"
    s = ' '.join(str_tokens).replace('@@ ', '')
    try:
        # Detokenizer might fail on empty or unusual input
        return detokenizer.detokenize(s.split()) if s.strip() else ""
    except Exception as e:
        print(f"Error during detokenization: {e}. Input string: '{s}'")
        return "<detokenization error>"

# --- Prediction Helper Functions ---
def to_canonical(pred_chunk: str):
    """ Maps a predicted text chunk to a canonical hallmark name. Added checks. """
    global HALLMARKS
    # Ensure input is a string
    if not isinstance(pred_chunk, str):
        # print(f"Warning: to_canonical received non-string input: {pred_chunk}. Returning None.")
        return None
    s = pred_chunk.strip().lower()
    low = [h.lower() for h in HALLMARKS]
    if not s: return None
    if s in low:
        return HALLMARKS[low.index(s)]
    # Use difflib for fuzzy matching
    try:
        best = difflib.get_close_matches(s, low, n=1, cutoff=0.7)
        return HALLMARKS[low.index(best[0])] if best else None
    except Exception as e:
        print(f"Error during difflib matching for '{s}': {e}")
        return None  # Return None on error

def build_allowed_token_mask(vocab_size, device):
    """ Builds the mask for constrained decoding. Added error checks. """
    global HALLMARKS, token2id, eos_id, pad_id
    allowed = set()
    # --- Input Validation ---
    if vocab_size <= 0:
        raise ValueError("Vocabulary size must be positive.")
    if not token2id:
        raise RuntimeError("Tokenizer vocabulary (token2id) not loaded.")
    print("Encoding hallmarks for mask...")
    try:
        # Ensure HALLMARKS is a list of strings
        if not isinstance(HALLMARKS, list) or not all(isinstance(h, str) for h in HALLMARKS):
            raise ValueError("HALLMARKS must be a list of strings.")
        hallmark_bpes = bpe_encode_lines(HALLMARKS, desc="BPE Hallmarks (for mask)")
        for bpe_list in hallmark_bpes:
            ids, _ = tokens_to_ids(bpe_list)
            allowed.update(ids)
        print(f"Encoded {len(HALLMARKS)} hallmarks.")
    except Exception as e:
        print(f"ERROR: Failed to BPE encode or convert hallmarks for mask: {e}")
        raise
    print("Encoding separators for mask...")
    SEPS = [", ", ",", "; ", ";", "|"]
    try:
        sep_bpes = bpe_encode_lines(SEPS, desc="BPE Separators (for mask)")
        for bpe_list in sep_bpes:
            ids, _ = tokens_to_ids(bpe_list)
            allowed.update(ids)
        print(f"Encoded {len(SEPS)} separators.")
    except Exception as e:
        print(f"ERROR: Failed to BPE encode or convert separators for mask: {e}")
        raise
    # Add EOS token - check that eos_id is valid
    if not isinstance(eos_id, int) or eos_id < 0 or eos_id >= vocab_size:
        print(f"Warning: Invalid EOS ID ({eos_id}). Defaulting mask EOS to 0.")
        effective_eos_id = 0
    else:
        effective_eos_id = eos_id
    allowed.add(effective_eos_id)
    print(f"Total allowed token IDs (including EOS {effective_eos_id}): {len(allowed)}")
    # Create the mask tensor on CPU first
    mask = torch.full((vocab_size,), float('-inf'), device=torch.device('cpu'))
    try:
        # Filter out potential invalid IDs before creating the index list.
        # Ensure pad_id is valid if used for filtering.
        effective_pad_id = pad_id if isinstance(pad_id, int) and 0 <= pad_id < vocab_size else -1  # Use -1 if pad_id is invalid
        valid_allowed_ids = []
        for id_ in allowed:
            if isinstance(id_, int) and 0 <= id_ < vocab_size:  # Check type and range
                # Filter out pad_id unless it's the same as the effective_eos_id
                if id_ != effective_pad_id or id_ == effective_eos_id:
                    valid_allowed_ids.append(id_)
            # else: print(f"Warning: Invalid ID {id_} in allowed set skipped.")  # Reduce noise
        if not valid_allowed_ids:
            raise ValueError("No valid token IDs found to allow in the mask.")
        # Check ranges again after filtering (belt and braces)
        max_valid_id = max(valid_allowed_ids)
        min_valid_id = min(valid_allowed_ids)
        if max_valid_id >= vocab_size or min_valid_id < 0:
            # This should ideally not happen if filtering worked
            raise IndexError(f"Filtered allowed IDs still out of range [{min_valid_id}, {max_valid_id}] for vocab size {vocab_size}.")
        # Apply mask
        mask[valid_allowed_ids] = 0.0  # Use the list directly
        print(f"Mask created with {len(valid_allowed_ids)} allowed indices.")
    except IndexError as e:
        print(f"ERROR: Index error while creating mask. Vocab size: {vocab_size}. Error: {e}")
        # Find problematic IDs more carefully
        problem_ids = [i for i in allowed if not isinstance(i, int) or i < 0 or i >= vocab_size]
        print(f"Problematic IDs in allowed set: {problem_ids}")
        raise
    except Exception as e:
        print(f"ERROR: Unexpected error creating mask: {e}")
        traceback.print_exc()
        raise
    # Move final mask to target device
    try:
        target_device = torch.device(device)  # Ensure device is a torch.device object
        return mask.to(target_device)
    except Exception as e:
        print(f"Error moving mask to device '{device}': {e}")
        raise

# --- Global Variables for Loaded Model and Assets ---
inference_model = None
ALLOWED_MASK = None
model_device = "cpu"
config = None  # Added global config


# --- Initialization Function ---
def initialize_model_and_tokenizer():
    global inference_model, ALLOWED_MASK, model_device, token2id, config  # Add config
    print("Initializing model...")
    # Determine device
    model_device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {model_device}")
    # Load tokenizer data first (essential for vocab size)
    try:
        load_tokenizer_data(DICT_TXT_PATH)
        if not token2id:  # Check if loading actually populated the dict
            raise ValueError("Tokenizer loading failed to populate token2id dictionary.")
    except Exception as e:
        print(f"FATAL: Could not load tokenizer data. Cannot proceed. Error: {e}")
        return False  # Indicate failure
    # Define model config (MUST match finetuning config)
    try:
        # Ensure config is globally accessible after definition
        config = GPTConfig(
            vocab_size=len(token2id),  # Get vocab size from loaded data
            block_size=128,            # Match training
            n_layer=6,                 # Match training
            n_head=6,                  # Match training
            n_embd=384,                # Match training
            dropout=0.1,               # Match training (dropout is off in eval mode)
            bias=True                  # Match training
        )
        print(f"Model Config: {config}")
    except Exception as e:
        print(f"FATAL: Error creating GPTConfig: {e}")
        return False
    # Instantiate base and wrapped model (on CPU initially)
    try:
        base_gpt = GPT(config)
        inference_model = GPTWithSoftPrompt(base_gpt, prompt_len=1)
    except Exception as e:
        print(f"FATAL: Error instantiating model: {e}")
        traceback.print_exc()
        return False
    # Load finetuned weights
    print(f"Loading finetuned weights from: {FINETUNED_MODEL_PATH}")
    if not os.path.exists(FINETUNED_MODEL_PATH):
        print(f"ERROR: Model weights file not found at {FINETUNED_MODEL_PATH}")
        return False
    try:
        # Load state dict onto CPU first
        state_dict = torch.load(FINETUNED_MODEL_PATH, map_location='cpu')
        # Clean state dict keys (handle DDP 'module.' prefix)
        cleaned_state_dict = {}
        for k, v in state_dict.items():
            name = k[7:] if k.startswith('module.') else k
            cleaned_state_dict[name] = v
        # Load into model
        missing_keys, unexpected_keys = inference_model.load_state_dict(cleaned_state_dict, strict=False)
        if missing_keys:
            # Keep only keys that correspond to real parameters/buffers in the model
            # (get_parameter/get_buffer raise AttributeError on unknown names, so use name sets instead).
            known_names = {n for n, _ in inference_model.named_parameters()} | {n for n, _ in inference_model.named_buffers()}
            missing_persistent = [k for k in missing_keys if k in known_names]
            if missing_persistent:
                print("Warning: Missing persistent keys during state dict load:", missing_persistent)
        if unexpected_keys:
            print("Warning: Unexpected keys during state dict load:", unexpected_keys)
        print("Weights loaded successfully.")
    except Exception as e:
        print(f"Error loading state dict from {FINETUNED_MODEL_PATH}: {e}")
        print("Ensure the model architecture matches the saved checkpoint and the file is not corrupted.")
        traceback.print_exc()
        return False
    # Move model to target device and set to eval mode
    try:
        inference_model.to(model_device)
        inference_model.eval()
        print(f"Model moved to device: {model_device} and set to eval mode.")
    except Exception as e:
        print(f"Error moving model to device '{model_device}': {e}")
        traceback.print_exc()
        return False
    # Build the allowed token mask (after model is on device)
    print("Building allowed token mask...")
    try:
        if config.vocab_size <= 0:
            raise ValueError("Vocabulary size must be positive to build mask.")
        # Ensure model_device is valid before passing
        device_obj = torch.device(model_device)
        ALLOWED_MASK = build_allowed_token_mask(config.vocab_size, device_obj)
        print("Allowed token mask created.")
    except Exception as e:
        print(f"ERROR: Failed to build allowed token mask: {e}")
        traceback.print_exc()
        return False
    return True  # Indicate success

# --- Inference Function ---
def predict_hallmarks(abstract_text):
    global inference_model, ALLOWED_MASK, model_device, token2id, eos_id
    # --- Pre-computation Checks ---
    if inference_model is None:
        print("Error: Inference model is not loaded.")
        return ["Error: Model not loaded"]
    if ALLOWED_MASK is None:
        print("Error: Allowed mask is not built.")
        return ["Error: Mask not built"]
    if not token2id:
        print("Error: Tokenizer vocabulary not loaded.")
        return ["Error: Tokenizer not loaded"]
    # --- Input Validation ---
    if not isinstance(abstract_text, str):
        print(f"Warning: Received non-string abstract text type: {type(abstract_text)}. Attempting conversion.")
        try:
            abstract_text = str(abstract_text)
        except Exception:
            return ["Error: Invalid input type"]
    if not abstract_text.strip():
        print("Warning: Received empty or whitespace-only abstract text.")
        return []  # Return empty list for empty input
    try:
        # --- 1. Preprocess and Tokenize Input ---
        print("Tokenizing input abstract...")
        cleaned_abstract = " ".join(abstract_text.split())
        if not cleaned_abstract:
            print("Warning: Input abstract contains only whitespace after cleaning.")
            return []
        bpe_tokens_list = bpe_encode_lines([cleaned_abstract])
        if not bpe_tokens_list or not bpe_tokens_list[0]:  # Check if the list or its first element is empty
            print("Warning: BPE encoding resulted in empty tokens.")
            return []
        bpe_tokens = bpe_tokens_list[0]
        input_ids_list, oov = tokens_to_ids(bpe_tokens)
        if oov > 0:
            print(f"Info: Input contained {oov} OOV tokens.")
        # Add EOS token
        input_ids = input_ids_list + [eos_id]
        # Convert to tensor and move to device
        input_tensor = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(model_device)
        # --- 2. Generate Predictions ---
        print("Generating predictions...")
        with torch.no_grad():
            generated_ids_tensor = inference_model.generate_labels(
                input_tensor,
                allowed_mask=ALLOWED_MASK,
                max_new_tokens=30,
                temperature=0.0
            )
        # --- 3. Decode and Post-process ---
        print("Decoding and cleaning predictions...")
        if generated_ids_tensor is None or generated_ids_tensor.numel() == 0:
            print("Warning: Generation resulted in empty tensor.")
            generated_ids = []
        else:
            # Ensure tensor is on CPU before converting to list
            generated_ids = generated_ids_tensor[0].cpu().tolist()
        if not generated_ids:
            print("No tokens generated.")
            return []
        generated_tokens = ids_to_tokens(generated_ids)
        # Remove tokens after EOS if present
        try:
            eos_token_str = id2token.get(eos_id, "</s>")  # Get string representation
            if eos_token_str in generated_tokens:
                eos_index = generated_tokens.index(eos_token_str)
                generated_tokens = generated_tokens[:eos_index]
        except ValueError:
            pass  # EOS not found is okay
        # Decode BPE tokens to string
        generated_text = bpe_decode_tokens(generated_tokens).strip().lower()
        print(f"Raw generated text: '{generated_text}'")
        # Split potential multi-labels and map to canonical
        parts = []
        if generated_text:
            potential_parts = re.split(r'[;,|]\s*', generated_text)
            parts = [p.strip() for p in potential_parts if p.strip()]
            if not parts:  # Handle case with no delimiters
                parts = [generated_text]
        predicted_labels = []
        seen_labels = set()
        for p in parts:
            canonical_label = to_canonical(p)
            if canonical_label and canonical_label not in seen_labels:
                predicted_labels.append(canonical_label)
                seen_labels.add(canonical_label)
        print(f"Final predicted labels: {predicted_labels}")
        return predicted_labels
    # --- Error Handling ---
    except FileNotFoundError as fnf_err:
        print(f"ERROR during prediction (File Not Found - likely BPE related): {fnf_err}")
        traceback.print_exc()
        return ["Error: BPE file processing error"]
    except PermissionError as perm_err:
        print(f"ERROR during prediction (Permission Error - likely fastBPE): {perm_err}")
        traceback.print_exc()
        return ["Error: BPE execution permission"]
    except RuntimeError as run_err:
        if "CUDA out of memory" in str(run_err):
            print(f"ERROR: CUDA Out of Memory during prediction. Input length: {len(input_ids) if 'input_ids' in locals() else 'N/A'}")
            traceback.print_exc()
            return ["Error: Input too long (OOM)"]
        else:
            print(f"ERROR during prediction (PyTorch RuntimeError): {run_err}")
            traceback.print_exc()
            return ["Error: Model runtime error"]
    except Exception as e:
        print(f"ERROR during prediction (General Exception): {e}")
        traceback.print_exc()
        return ["Error: An unexpected error occurred"]

# --- Flask App ---
app = Flask(__name__)

# --- Load Model on Startup ---
model_initialized = False


# Runs before every request; the model_initialized flag makes it a cheap no-op after the first call.
@app.before_request
def ensure_model_loaded():
    """ Ensures the model is loaded before handling the first request. """
    global model_initialized
    if not model_initialized:
        print("First request received, attempting to initialize model...")
        # Add basic locking if deploying with multiple workers (this is not fully thread-safe).
        # For true multi-worker safety, model loading should happen before workers fork.
        try:
            if initialize_model_and_tokenizer():
                model_initialized = True
                print("Model initialization successful.")
            else:
                print("FATAL: Model initialization failed during first request.")
                # We don't raise an error here, but subsequent requests will fail until this is fixed.
        except Exception as init_err:
            print(f"FATAL: Exception during model initialization: {init_err}")
            traceback.print_exc()

# --- Routes ---
@app.route('/')
def home():
    """ Renders the HTML frontend page. """
    # Check if initialization failed and show an error page if so?
    # For simplicity, we assume initialization works or subsequent predict calls fail.
    return render_template('index.html')

@app.route('/predict', methods=['POST'])
def predict():
    """ Handles prediction requests from the frontend. """
    global model_initialized
    # Check if model is ready
    if not model_initialized:
        print("Error: Model not initialized when /predict called.")
        # Return a specific status code like Service Unavailable
        return jsonify({'error': 'Model is not ready. Please try again later.'}), 503
    # Validate request format
    if not request.is_json:
        return jsonify({'error': 'Request must be JSON'}), 400
    data = request.get_json()
    abstract = data.get('abstract')
    # Validate input abstract
    if not abstract:
        return jsonify({'error': 'Missing "abstract" field in JSON request'}), 400
    if not isinstance(abstract, str):
        return jsonify({'error': '"abstract" field must be a string'}), 400
    if len(abstract.strip()) == 0:
        print("Received empty abstract, returning empty prediction.")
        return jsonify({'predictions': []})
    MAX_ABSTRACT_LEN = 10000  # Define max length
    if len(abstract) > MAX_ABSTRACT_LEN:
        print(f"Received overly long abstract ({len(abstract)} chars), rejecting.")
        return jsonify({'error': f'Input abstract is too long (max {MAX_ABSTRACT_LEN} chars)'}), 413  # Payload Too Large
    print("\n--- Received Prediction Request ---")
    print(f"Input Abstract (first 100 chars): {abstract[:100]}...")
    try:
        # Perform prediction
        predictions = predict_hallmarks(abstract)
        print("--- Prediction Complete ---")
        # Check if the result indicates an internal error occurred
        if isinstance(predictions, list) and len(predictions) > 0 and predictions[0].startswith("Error:"):
            print(f"Internal error during prediction: {predictions[0]}")
            # Return a generic server error to the client
            return jsonify({'error': 'An internal error occurred during prediction.'}), 500
        else:
            # Return successful predictions
            return jsonify({'predictions': predictions})
    except Exception as e:
        # Catch unexpected errors in the route handler itself
        print("--- Prediction Failed Unexpectedly in Route ---")
        print(f"Error: {e}")
        traceback.print_exc()
        return jsonify({'error': 'An internal server error occurred.'}), 500

# --- Run the App ---
if __name__ == '__main__':
    # Initialize the model eagerly when running the script directly
    if not model_initialized:
        print("Running script directly, initializing model eagerly...")
        if initialize_model_and_tokenizer():
            model_initialized = True
            print("Model initialization successful.")
        else:
            print("FATAL: Model initialization failed. Cannot start Flask server.")
            exit(1)  # Exit if model fails to load on startup
    # Check fastBPE path validity before starting the server
    abs_fastbpe_path = os.path.abspath(FASTBPE_BIN_PATH)
    if not os.path.exists(abs_fastbpe_path):
        print(f"ERROR: fastBPE binary not found at '{abs_fastbpe_path}'.")
        print("Please ensure fastBPE is compiled and the path is correct relative to app.py.")
        exit(1)
    if not os.access(abs_fastbpe_path, os.X_OK):
        print(f"ERROR: fastBPE binary at '{abs_fastbpe_path}' is not executable.")
        print("Attempting to make it executable with 'chmod +x'...")
        try:
            os.chmod(abs_fastbpe_path, 0o755)
            print(f"Successfully made '{abs_fastbpe_path}' executable.")
        except OSError as e:
            print(f"ERROR: Failed to make fastBPE executable: {e}")
            print(f"Please set execute permissions manually (e.g., 'chmod +x {FASTBPE_BIN_PATH}').")
            exit(1)
    print("Starting Flask server...")
    # Use host='0.0.0.0' to make it accessible on your network
    # Set debug=False for production environments
    app.run(host='0.0.0.0', port=5000, debug=False)  # Changed debug to False
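
# Illustrative client call (assumes the server above is running on localhost:5000;
# the abstract text is made up):
#   curl -X POST http://localhost:5000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"abstract": "We show that loss of p53 allows tumor cells to escape apoptosis..."}'
# Expected response shape: {"predictions": ["resisting cell death", ...]}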