# Bio-gpt / app.py
# Snapshot of the repository file view (user: Priyansu19; raw / history / blame view; 46.3 kB).
# Last commit: a89ac91 - "Update Dockerfile to use fastbpe_exec"
import os
import subprocess
import math
import difflib
import tempfile
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from sacremoses import MosesDetokenizer
from flask import Flask, request, jsonify, render_template
import traceback # Import traceback for better error logging
import re # Import the regular expression module
# --- Constants and Paths ---
# Ensure these files are in the same directory as app.py or provide correct paths
# Checkpoint file whose state dict is loaded into the model at start-up.
FINETUNED_MODEL_PATH = "hoc_best.pt"
# BPE merge-codes file passed to the fastBPE binary's `applybpe` command.
BPE_CODES_PATH = "bpecodes"
# Vocabulary file; the first whitespace-separated field of each line is a token.
DICT_TXT_PATH = "dict.txt"
FASTBPE_BIN_PATH = "./fastbpe_exec" # Assumes the fastBPE executable is alongside app.py
# Canonical label names. Generated text is mapped back onto these strings, and
# their BPE tokens define which tokens constrained decoding may emit.
HALLMARKS = [ # Keep this consistent with training/evaluation
    "activating invasion and metastasis", "avoiding immune destruction",
    "cellular energetics", "enabling replicative immortality",
    "evading growth suppressors", "genomic instability and mutation",
    "inducing angiogenesis", "resisting cell death",
    "sustaining proliferative signaling", "tumor promoting inflammation",
]
# --- Model Architecture Definitions (Copy from your notebook) ---
# NOTE: Make sure these classes are IDENTICAL to the ones used for training
# including GPTConfig, LayerNorm, CausalSelfAttention, MLP, Block, GPT, GPTWithSoftPrompt
@dataclass
class GPTConfig:
    """Hyperparameters consumed by the model classes defined in this file."""
    # Number of rows in the position-embedding table and side length of the
    # causal mask buffer, i.e. the longest sequence the model can index.
    block_size: int
    # Number of rows in the token-embedding table / width of the LM head output.
    vocab_size: int
    # Number of transformer blocks stacked in the model.
    n_layer: int
    # Number of attention heads; n_embd must be divisible by it.
    n_head: int
    # Embedding / hidden width.
    n_embd: int
    # Probability used for every nn.Dropout layer.
    dropout: float = 0.0
    # Whether Linear layers and LayerNorm carry a bias term.
    bias: bool = True
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with an optional bias term."""

    def __init__(self, ndim, bias):
        """Create a learnable scale of size ``ndim`` and, if ``bias`` is truthy, a shift."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, x):
        """Normalise ``x`` over its trailing dimension and apply scale (and shift)."""
        epsilon = 1e-5
        normalized_shape = self.weight.shape
        return F.layer_norm(x, normalized_shape, self.weight, self.bias, epsilon)
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which a position only attends to itself and earlier positions."""

    def __init__(self, config):
        """Build the fused q/k/v projection, the output projection and both dropouts."""
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # A single Linear layer emits queries, keys and values side by side.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Use the fused kernel when this PyTorch build exposes it.
        self.flash = hasattr(F, 'scaled_dot_product_attention')
        if not self.flash:
            # Lower-triangular mask, registered as a persistent buffer so it
            # moves with the module across devices.
            side = config.block_size
            lower_triangle = torch.tril(torch.ones(side, side))
            self.register_buffer("bias", lower_triangle.view(1, 1, side, side), persistent=True)

    def forward(self, x):
        """Apply causal attention to ``x`` of shape (B, T, C) and return the same shape."""
        B, T, C = x.size()
        head_dim = C // self.n_head
        # Split the fused projection and move the head axis forward: (B, nh, T, hs).
        q, k, v = (
            piece.view(B, T, self.n_head, head_dim).transpose(1, 2)
            for piece in self.c_attn(x).split(self.n_embd, dim=2)
        )
        if self.flash:
            drop_p = self.attn_dropout.p if self.training else 0.0
            y = F.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=drop_p, is_causal=True
            )
        else:
            # Manual attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T).
            scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            if hasattr(self, 'bias'):
                keep = self.bias[:, :, :T, :T]
            else:
                # Fallback when the buffer was never registered.
                keep = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
            scores = scores.masked_fill(keep == 0, float('-inf'))
            weights = self.attn_dropout(F.softmax(scores, dim=-1))
            y = weights @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        # Put the heads back side by side, then project.
        merged = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(merged))
class MLP(nn.Module):
    """Position-wise feed-forward network: expand to 4x width, GELU, project back, dropout."""

    def __init__(self, config):
        """Create the two Linear layers, the activation and the dropout."""
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Run ``x`` through expand -> GELU -> project -> dropout."""
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
class Block(nn.Module):
    """One pre-norm transformer layer: attention and MLP, each wrapped in a residual connection."""

    def __init__(self, config):
        """Instantiate both layer norms, the attention module and the MLP."""
        super().__init__()
        self.ln1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after the attention residual and then the MLP residual."""
        attended = x + self.attn(self.ln1(x))
        return attended + self.mlp(self.ln2(attended))
class GPT(nn.Module):
    """Decoder-only transformer language model with learned position embeddings.

    The token-embedding matrix and the LM head share one weight tensor
    (weight tying).
    """

    def __init__(self, config):
        """Build the embeddings, ``config.n_layer`` blocks, final norm and LM head."""
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: https://paperswithcode.com/method/weight-tying
        self.transformer.wte.weight = self.lm_head.weight
        # Initialise all weights, then re-initialise the residual projections
        # with the depth-scaled std described in the GPT-2 paper.
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zeros for Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on token IDs ``idx`` of shape (b, t).

        Sequences longer than ``config.block_size`` are cropped to their last
        ``block_size`` tokens; ``targets`` (if given) is cropped the same way so
        that logits and targets stay aligned.

        Returns:
            ``(logits, loss)``. With ``targets`` the logits cover every position
            (b, t, vocab) and ``loss`` is the cross entropy (target -1 ignored).
            Without ``targets`` only the last position is projected
            (b, 1, vocab) and ``loss`` is ``None``.
        """
        device = idx.device
        b, t = idx.size()
        block_size = self.config.block_size
        if t > block_size:
            print(f"Warning: Input sequence length ({t}) > block size ({self.config.block_size}). Cropping.")
            idx = idx[:, -block_size:]
            if targets is not None:
                # Keep targets aligned with the cropped inputs; without this the
                # loss below fails on mismatched sizes. `.contiguous()` is needed
                # because the loss flattens targets with `.view(-1)`.
                targets = targets[:, -block_size:].contiguous()
            t = block_size
        pos = torch.arange(0, t, dtype=torch.long, device=device)  # shape (t)

        tok_emb = self.transformer.wte(idx)  # (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            return logits, loss

        # Inference: only the last position is needed; the list index [-1]
        # preserves the time dimension.
        logits = self.lm_head(x[:, [-1], :])
        if torch.isnan(logits).any() or torch.isinf(logits).any():
            # Diagnostic only; the logits are still returned to the caller.
            print("WARNING: NaN or Inf detected in logits during inference.")
        return logits, None
class GPTWithSoftPrompt(nn.Module):
    """Wrap a GPT so that a learned "soft prompt" is prepended to every input.

    The wrapper shares ``transformer`` and ``lm_head`` with the base model and
    adds one parameter, ``soft_prompt`` of shape (1, prompt_len, n_embd). The
    prompt vectors are concatenated in front of the token embeddings; they do
    not receive position embeddings.
    """

    def __init__(self, base_gpt: "GPT", prompt_len=1):
        """Share the base model's modules and create the soft-prompt parameter."""
        super().__init__()
        self.config = base_gpt.config
        self.transformer = base_gpt.transformer
        self.lm_head = base_gpt.lm_head
        C = self.config.n_embd
        self.soft_prompt = nn.Parameter(torch.zeros(1, prompt_len, C))
        nn.init.normal_(self.soft_prompt, mean=0.0, std=0.02)

    def _max_context_tokens(self):
        """Return how many real tokens fit next to the soft prompt.

        The position table has ``block_size`` rows and the non-flash attention
        mask is ``block_size`` wide, so prompt + tokens must not exceed it.
        """
        return max(self.config.block_size - self.soft_prompt.size(1), 1)

    def forward(self, idx, targets=None):
        """Run the model on token IDs ``idx`` of shape (B, T).

        Inputs longer than the usable context are cropped to their most recent
        tokens (``targets`` is cropped identically) instead of indexing past
        the position-embedding table.

        Returns:
            Inference (``targets is None``): ``(logits, None)`` where logits has
            shape (B, V) and scores the token following the last input token.
            Training: ``(logits, loss)`` with logits of shape (B, P+T, V) and
            the next-token cross entropy; soft-prompt positions and targets
            equal to -1 are ignored.
        """
        B, T = idx.shape
        device = idx.device

        max_tokens = self._max_context_tokens()
        if T > max_tokens:
            print(f"Warning: Input sequence length ({T}) > usable context ({max_tokens}). Cropping.")
            idx = idx[:, -max_tokens:]
            if targets is not None:
                targets = targets[:, -max_tokens:]
            T = max_tokens

        # Token + position embeddings for the real tokens.
        tok_emb = self.transformer.wte(idx)  # (B, T, C)
        pos = torch.arange(0, T, dtype=torch.long, device=device)
        pos_emb = self.transformer.wpe(pos)  # (T, C)
        x_tokens = tok_emb + pos_emb

        # Prepend the soft prompt, moved to the input's device.
        soft = self.soft_prompt.to(device).expand(B, -1, -1)  # (B, P, C)
        P = soft.size(1)
        x = torch.cat([soft, x_tokens], dim=1)  # (B, P+T, C)

        x = self.transformer.drop(x)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is None:
            # Only the final position (index P+T-1) predicts the next token, so
            # project just that one instead of the whole sequence.
            final_logits = self.lm_head(x[:, -1, :])  # (B, V)
            if torch.isnan(final_logits).any() or torch.isinf(final_logits).any():
                # Diagnostic only; the caller decides how to react.
                print(f"WARNING: NaN or Inf detected in final_logits at index {P + T - 1}.")
            return final_logits, None

        logits = self.lm_head(x)  # (B, P+T, V)
        # Soft-prompt positions have no target: fill them with the ignore index.
        pad_ignore = torch.full((B, P), -1, dtype=targets.dtype, device=device)
        full_targets = torch.cat([pad_ignore, targets], dim=1)
        # Shift by one so position i is scored against token i+1.
        logits_lm = logits[:, :-1, :].contiguous()
        targets_lm = full_targets[:, 1:].contiguous()
        loss = F.cross_entropy(
            logits_lm.view(-1, logits_lm.size(-1)),
            targets_lm.view(-1),
            ignore_index=-1
        )
        if torch.isnan(loss) or torch.isinf(loss):
            print("WARNING: NaN or Inf detected in loss calculation.")
        return logits, loss

    # --- Constrained generation method (from Section 2.9) ---
    @torch.no_grad()
    def generate_labels(self, idx, allowed_mask, max_new_tokens=24, temperature=0.0):
        """Generate up to ``max_new_tokens`` tokens under an additive vocabulary mask.

        Args:
            idx: (B, T) tensor of prompt token IDs.
            allowed_mask: tensor added to the logits at every step; it must
                broadcast against (B, V). Entries of ``-inf`` forbid a token.
            max_new_tokens: upper bound on generated tokens per sequence.
            temperature: ``<= 0`` selects greedy decoding, otherwise softmax
                sampling at that temperature.

        Returns:
            (B, n_generated) tensor holding only the newly generated IDs. A
            sequence that has emitted EOS keeps emitting EOS; generation stops
            early once every sequence has finished, or when the logits contain
            NaN/Inf.

        Reads the module-level ``eos_id``.
        """
        self.eval()
        device = idx.device
        batch = idx.size(0)
        prompt_len = idx.size(1)
        max_ctx = self._max_context_tokens()

        out = idx.clone()  # never modify the caller's tensor
        allowed_mask = allowed_mask.to(device)
        finished = torch.zeros(batch, dtype=torch.bool, device=device)

        current_eos_id = eos_id
        if not isinstance(current_eos_id, int):
            print(f"Warning: Global eos_id is not an integer ({current_eos_id}). Defaulting to 0.")
            current_eos_id = 0

        for step in range(max_new_tokens):
            # Keep only the most recent tokens that fit in the model's context.
            ctx = out[:, -max_ctx:]
            logits, _ = self(ctx)  # (B, V)

            if torch.isnan(logits).any() or torch.isinf(logits).any():
                print(f"WARNING: NaN or Inf detected in logits during generation step {step}. Stopping generation.")
                return out[:, prompt_len:]

            # Broadcasting applies a (V,) mask to every row of (B, V).
            logits = logits + allowed_mask

            if temperature <= 0:
                next_id = torch.argmax(logits, dim=-1)  # greedy, (B,)
            else:
                probs = F.softmax(logits / temperature, dim=-1)
                if torch.isnan(probs).any() or torch.isinf(probs).any():
                    print(f"WARNING: NaN or Inf detected in probabilities during generation step {step}. Using argmax fallback.")
                    next_id = torch.argmax(logits, dim=-1)
                else:
                    try:
                        next_id = torch.multinomial(probs, num_samples=1).squeeze(1)  # (B,)
                    except RuntimeError as e:
                        print(f"WARNING: torch.multinomial failed: {e}. Using argmax fallback.")
                        next_id = torch.argmax(logits, dim=-1)

            # Sequences that already ended keep emitting EOS.
            next_id = next_id.masked_fill(finished, current_eos_id)
            if (next_id < 0).any():
                print(f"WARNING: Negative token ID generated: {next_id}. Clipping to 0.")
                next_id = torch.clamp(next_id, min=0)

            out = torch.cat([out, next_id.unsqueeze(1)], dim=1)
            finished |= (next_id == current_eos_id)
            if bool(finished.all()):
                break

        return out[:, prompt_len:]
# --- Tokenizer Helper Functions ---
# Added robustness and error checks
# Global tokenizer maps and special IDs, loaded once at startup
# token2id maps a token string to its integer ID and id2token is the inverse;
# both are filled by load_tokenizer_data.
token2id, id2token = {}, {}
eos_id = 0 # Default, will be overwritten
pad_id = 0 # Default, will be overwritten
# MosesDetokenizer instance created by load_tokenizer_data; None until then.
detokenizer = None
def load_tokenizer_data(dict_path):
    """Populate the module-level tokenizer state from a vocabulary file.

    The first whitespace-separated field of every non-empty line is a token and
    its ID is the zero-based line index; a repeated token keeps its first ID.
    Sets ``token2id``, ``id2token``, ``eos_id``, ``pad_id`` and ``detokenizer``.

    Raises:
        FileNotFoundError: ``dict_path`` does not exist.
        Exception: any other failure is logged and re-raised unchanged.
    """
    global token2id, id2token, eos_id, pad_id, detokenizer
    print(f"Loading vocabulary from {dict_path}...")
    forward_map = {}
    reverse_map = {}
    try:
        with open(dict_path, encoding="utf-8") as handle:
            for line_no, raw_line in enumerate(handle):
                fields = raw_line.split()
                if len(fields) == 0:
                    continue  # blank line
                word = fields[0]
                if word not in forward_map:
                    forward_map[word] = line_no
                    reverse_map[line_no] = word
                else:
                    print(f"Warning: Duplicate token '{word}' found at line {line_no+1}. Keeping first occurrence.")

        # Publish the maps only once the whole file has been read.
        token2id, id2token = forward_map, reverse_map

        # Pick the first EOS spelling the vocabulary knows about.
        eos_candidates = ("</s>", "<|endoftext|>", "[EOS]")
        matched = next((cand for cand in eos_candidates if cand in token2id), None)
        if matched is not None:
            eos_id = token2id[matched]
            print(f"Found EOS token '{matched}' with ID: {eos_id}")
        else:
            # No known EOS spelling: use the largest ID (0 for an empty vocabulary).
            eos_id = max(token2id.values()) if token2id else 0
            print(f"Warning: Standard EOS tokens not found. Using highest index ({eos_id}) as EOS ID.")

        # Padding uses <pad> when present, otherwise the EOS ID.
        pad_id = token2id["<pad>"] if "<pad>" in token2id else eos_id
        print(f"Using PAD ID: {pad_id}")

        detokenizer = MosesDetokenizer(lang='en')
        print(f"Vocabulary loaded. Size: {len(token2id)}")
        if not detokenizer:
            raise ValueError("MosesDetokenizer failed to initialize.")
    except FileNotFoundError:
        print(f"ERROR: Vocabulary file not found at {dict_path}")
        raise
    except Exception as e:
        print(f"ERROR: Failed to load tokenizer data from {dict_path}: {e}")
        raise
def bpe_encode_lines(lines, shard_size=500, desc="BPE Encode"):
    """Apply BPE to each input line by shelling out to the fastBPE binary.

    Lines are processed in shards of ``shard_size`` through temporary files.
    ``desc`` is accepted but not used by this function.

    Returns:
        A list of token lists, one per line that fastBPE wrote to its output.

    Raises:
        ValueError: ``lines`` is neither a list nor convertible to one.
        FileNotFoundError: the binary or the BPE codes file is missing.
        PermissionError: the binary is not executable and chmod failed.
        subprocess.CalledProcessError: fastBPE exited with a non-zero status.
    """
    global BPE_CODES_PATH, FASTBPE_BIN_PATH
    if not isinstance(lines, list):
        print(f"Warning: bpe_encode_lines expected a list, got {type(lines)}. Attempting conversion.")
        try:
            lines = list(lines)
        except TypeError:
            raise ValueError("Input 'lines' must be a list or convertible to a list.")
    if len(lines) == 0:
        return []

    binary = os.path.abspath(FASTBPE_BIN_PATH)
    codes = os.path.abspath(BPE_CODES_PATH)
    if not os.path.exists(binary):
        raise FileNotFoundError(f"fastBPE executable not found at {binary}")
    if not os.path.exists(codes):
        raise FileNotFoundError(f"BPE codes file not found at {codes}")
    if not os.access(binary, os.X_OK):
        print(f"Warning: fastBPE binary at {binary} is not executable. Attempting chmod...")
        try:
            os.chmod(binary, 0o755)
        except OSError as e:
            raise PermissionError(f"Failed to make fastBPE executable: {e}. Please check permissions.")

    encoded = []
    with tempfile.TemporaryDirectory() as workdir:
        for offset in range(0, len(lines), shard_size):
            shard = lines[offset:offset + shard_size]
            src_file = os.path.join(workdir, f"src_{offset}.txt")
            dst_file = os.path.join(workdir, f"dst_{offset}.bpe")
            try:
                # One stripped line per item; falsy items become an empty line.
                with open(src_file, "w", encoding="utf-8") as sink:
                    sink.writelines(str(item or "").strip() + "\n" for item in shard)

                argv = [binary, "applybpe", dst_file, src_file, codes]
                # check=False: the exit status is inspected below so that
                # diagnostics can be printed before raising.
                result = subprocess.run(argv, capture_output=True, text=True, check=False)
                if result.returncode != 0:
                    print(f"ERROR: fastBPE failed (exit code {result.returncode}) on chunk starting at index {offset}.")
                    print(f"Command: {' '.join(argv)}")
                    print(f"Stderr:\n{result.stderr}")
                    print(f"First line of input chunk: {shard[0] if shard else 'N/A'}")
                    raise subprocess.CalledProcessError(
                        result.returncode, argv, output=result.stdout, stderr=result.stderr
                    )

                with open(dst_file, "r", encoding="utf-8") as source:
                    encoded.extend(row.strip().split() for row in source)
            except subprocess.CalledProcessError:
                # Details were already printed above; just propagate.
                raise
            except Exception as e:
                print(f"ERROR: Unexpected error during BPE encoding chunk starting at index {offset}: {e}")
                traceback.print_exc()
                raise
    return encoded
def tokens_to_ids(bpe_tokens):
    """Look up each BPE token in the global vocabulary.

    Tokens missing from ``token2id`` and items that are not strings map to
    ``pad_id``; both are counted as out-of-vocabulary.

    Returns:
        ``(ids, oov_count)``.

    Raises:
        ValueError: ``bpe_tokens`` is not a list.
    """
    global token2id, pad_id
    if not isinstance(bpe_tokens, list):
        raise ValueError(f"Input 'bpe_tokens' must be a list, got {type(bpe_tokens)}.")
    ids = []
    oov_count = 0
    for piece in bpe_tokens:
        if isinstance(piece, str):
            known = piece in token2id
            ids.append(token2id[piece] if known else pad_id)
            if not known:
                oov_count += 1
        else:
            print(f"Warning: Non-string token found in bpe_tokens: {piece}. Using pad_id.")
            ids.append(pad_id)
            oov_count += 1
    if oov_count > 0:
        print(f"Info: Found {oov_count} OOV tokens in sequence of length {len(bpe_tokens)}.")
    return ids, oov_count
def ids_to_tokens(ids):
    """Map token IDs back to token strings using the global ``id2token`` table.

    The function is deliberately tolerant of bad IDs:
      * a float NaN becomes ``"<nan>"``;
      * an ID that cannot be converted to ``int`` (wrong type, unparsable
        string, or an infinite float) becomes ``"<unk>"`` and is logged;
      * an integer with no vocabulary entry becomes ``"<unk>"``.

    Raises:
        ValueError: ``ids`` is not a list.
    """
    global id2token
    if not isinstance(ids, list):
        raise ValueError(f"Input 'ids' must be a list, got {type(ids)}.")
    tokens = []
    for raw_id in ids:
        if isinstance(raw_id, float) and math.isnan(raw_id):
            tokens.append("<nan>")
            continue
        try:
            key = int(raw_id)
        except (ValueError, TypeError, OverflowError):
            # OverflowError covers int(float('inf')), which would otherwise
            # escape and abort decoding.
            print(f"Warning: Could not convert ID '{raw_id}' to int. Using '<unk>'.")
            tokens.append("<unk>")
            continue
        tokens.append(id2token.get(key, "<unk>"))
    return tokens
def bpe_decode_tokens(bpe_tokens):
    """Join BPE tokens into readable text.

    Continuation markers (``@@`` at the end of a token) are removed so that
    word pieces are glued back together, then the global Moses detokenizer is
    applied to the resulting words.

    Returns:
        The decoded string, ``""`` for empty input, or one of the sentinel
        strings ``"<decoding error>"`` / ``"<detokenization error>"`` when a
        step fails.

    Raises:
        RuntimeError: the detokenizer has not been initialised.
        ValueError: ``bpe_tokens`` is not a list.
    """
    global detokenizer
    if detokenizer is None:
        raise RuntimeError("Detokenizer not initialized. Call load_tokenizer_data first.")
    if not isinstance(bpe_tokens, list):
        raise ValueError(f"Input 'bpe_tokens' must be a list, got {type(bpe_tokens)}.")
    try:
        str_tokens = [str(t) for t in bpe_tokens]
    except Exception as e:
        print(f"Error converting tokens to strings: {e}. Tokens: {bpe_tokens}")
        return "<decoding error>"
    # The extra trailing space lets the final token's "@@" match as well, so a
    # sequence cut off mid-word does not leak the marker into the text.
    s = (' '.join(str_tokens) + ' ').replace('@@ ', '').strip()
    if not s:
        return ""
    try:
        return detokenizer.detokenize(s.split())
    except Exception as e:
        print(f"Error during detokenization: {e}. Input string: '{s}'")
        return "<detokenization error>"
# --- Prediction Helper Functions ---
def to_canonical(pred_chunk: str):
    """Resolve a piece of generated text to one of the names in ``HALLMARKS``.

    Matching is case-insensitive. An exact match wins; otherwise the closest
    name according to ``difflib`` (cutoff 0.7) is used.

    Returns:
        The canonical name, or ``None`` for non-string input, blank input, no
        sufficiently close match, or an error during fuzzy matching.
    """
    global HALLMARKS
    if not isinstance(pred_chunk, str):
        return None
    query = pred_chunk.strip().lower()
    lowered = [name.lower() for name in HALLMARKS]
    if query == "":
        return None
    if query in lowered:
        return HALLMARKS[lowered.index(query)]
    try:
        candidates = difflib.get_close_matches(query, lowered, n=1, cutoff=0.7)
        if candidates:
            return HALLMARKS[lowered.index(candidates[0])]
        return None
    except Exception as e:
        print(f"Error during difflib matching for '{query}': {e}")
        return None
def build_allowed_token_mask(vocab_size, device):
    """Build the additive logit mask used for constrained decoding.

    The allowed set is every token ID that appears in the BPE encoding of a
    hallmark name or of a separator, plus the EOS ID. The returned tensor has
    shape ``(vocab_size,)`` with ``0.0`` at allowed IDs and ``-inf`` elsewhere.
    The padding ID is excluded unless it coincides with the EOS ID.

    Raises:
        ValueError: non-positive ``vocab_size``, malformed ``HALLMARKS``, or no
            valid ID left after filtering.
        RuntimeError: the vocabulary has not been loaded.
        Exception: errors from BPE encoding or tensor handling are logged and
            re-raised.
    """
    global HALLMARKS, token2id, eos_id, pad_id
    allowed = set()
    if vocab_size <= 0:
        raise ValueError("Vocabulary size must be positive.")
    if not token2id:
        raise RuntimeError("Tokenizer vocabulary (token2id) not loaded.")

    print("Encoding hallmarks for mask...")
    try:
        hallmarks_ok = isinstance(HALLMARKS, list) and all(isinstance(h, str) for h in HALLMARKS)
        if not hallmarks_ok:
            raise ValueError("HALLMARKS must be a list of strings.")
        for pieces in bpe_encode_lines(HALLMARKS, desc="BPE Hallmarks (for mask)"):
            piece_ids, _ = tokens_to_ids(pieces)
            allowed.update(piece_ids)
        print(f"Encoded {len(HALLMARKS)} hallmarks.")
    except Exception as e:
        print(f"ERROR: Failed to BPE encode or convert hallmarks for mask: {e}")
        raise

    print("Encoding separators for mask...")
    SEPS = [", ", ",", "; ", ";", "|"]
    try:
        for pieces in bpe_encode_lines(SEPS, desc="BPE Separators (for mask)"):
            piece_ids, _ = tokens_to_ids(pieces)
            allowed.update(piece_ids)
        print(f"Encoded {len(SEPS)} separators.")
    except Exception as e:
        print(f"ERROR: Failed to BPE encode or convert separators for mask: {e}")
        raise

    # EOS must always be allowed; fall back to ID 0 when the global is unusable.
    eos_usable = isinstance(eos_id, int) and 0 <= eos_id < vocab_size
    if eos_usable:
        effective_eos_id = eos_id
    else:
        print(f"Warning: Invalid EOS ID ({eos_id}). Defaulting mask EOS to 0.")
        effective_eos_id = 0
    allowed.add(effective_eos_id)
    print(f"Total allowed token IDs (including EOS {effective_eos_id}): {len(allowed)}")

    # Start from "everything forbidden" on the CPU.
    mask = torch.full((vocab_size,), float('-inf'), device=torch.device('cpu'))
    try:
        pad_usable = isinstance(pad_id, int) and 0 <= pad_id < vocab_size
        effective_pad_id = pad_id if pad_usable else -1  # -1 never matches a real ID
        valid_allowed_ids = [
            tid for tid in allowed
            if isinstance(tid, int) and 0 <= tid < vocab_size
            and (tid != effective_pad_id or tid == effective_eos_id)
        ]
        if not valid_allowed_ids:
            raise ValueError("No valid token IDs found to allow in the mask.")
        # Defensive re-check of the range after filtering.
        highest = max(valid_allowed_ids)
        lowest = min(valid_allowed_ids)
        if highest >= vocab_size or lowest < 0:
            raise IndexError(f"Filtered allowed IDs still out of range [{lowest}, {highest}] for vocab size {vocab_size}.")
        mask[valid_allowed_ids] = 0.0
        print(f"Mask created with {len(valid_allowed_ids)} allowed indices.")
    except IndexError as e:
        print(f"ERROR: Index error while creating mask. Vocab size: {vocab_size}. Error: {e}")
        problem_ids = [i for i in allowed if not isinstance(i, int) or i < 0 or i >= vocab_size]
        print(f"Problematic IDs in allowed set: {problem_ids}")
        raise
    except Exception as e:
        print(f"ERROR: Unexpected error creating mask: {e}")
        traceback.print_exc()
        raise

    try:
        return mask.to(torch.device(device))
    except Exception as e:
        print(f"Error moving mask to device '{device}': {e}")
        raise
# --- Global Variables for Loaded Model and Assets ---
# Model instance created by initialize_model_and_tokenizer; None until then.
inference_model = None
# Constrained-decoding mask produced by build_allowed_token_mask; None until then.
ALLOWED_MASK = None
# Device name; initialize_model_and_tokenizer switches it to "cuda" when available.
model_device = "cpu"
config = None # Added global config
# --- Initialization Function ---
def initialize_model_and_tokenizer():
    """Load the tokenizer, build the model, load its weights and build the decoding mask.

    Sets the module-level ``inference_model``, ``ALLOWED_MASK``,
    ``model_device`` and ``config`` (and, via ``load_tokenizer_data``, the
    tokenizer globals).

    Returns:
        ``True`` when every step succeeded, ``False`` otherwise. Failures are
        printed rather than raised.
    """
    global inference_model, ALLOWED_MASK, model_device, token2id, config
    print("Initializing model...")
    model_device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {model_device}")

    # The tokenizer comes first because the vocabulary size feeds the config.
    try:
        load_tokenizer_data(DICT_TXT_PATH)
        if not token2id:
            raise ValueError("Tokenizer loading failed to populate token2id dictionary.")
    except Exception as e:
        print(f"FATAL: Could not load tokenizer data. Cannot proceed. Error: {e}")
        return False

    # Model config (MUST match the fine-tuning config).
    try:
        config = GPTConfig(
            vocab_size=len(token2id),
            block_size=128,
            n_layer=6,
            n_head=6,
            n_embd=384,
            dropout=0.1,  # inactive once the model is in eval mode
            bias=True
        )
        print(f"Model Config: {config}")
    except Exception as e:
        print(f"FATAL: Error creating GPTConfig: {e}")
        return False

    # Instantiate the base model and its soft-prompt wrapper (on CPU).
    try:
        base_gpt = GPT(config)
        inference_model = GPTWithSoftPrompt(base_gpt, prompt_len=1)
    except Exception as e:
        print(f"FATAL: Error instantiating model: {e}")
        traceback.print_exc()
        return False

    print(f"Loading finetuned weights from: {FINETUNED_MODEL_PATH}")
    if not os.path.exists(FINETUNED_MODEL_PATH):
        print(f"ERROR: Model weights file not found at {FINETUNED_MODEL_PATH}")
        return False
    try:
        # NOTE(security): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be placed at FINETUNED_MODEL_PATH.
        state_dict = torch.load(FINETUNED_MODEL_PATH, map_location='cpu')
        # Drop the 'module.' prefix that DistributedDataParallel adds to keys.
        cleaned_state_dict = {
            (k[7:] if k.startswith('module.') else k): v
            for k, v in state_dict.items()
        }
        missing_keys, unexpected_keys = inference_model.load_state_dict(cleaned_state_dict, strict=False)
        if missing_keys:
            # Report the keys as returned. Probing each one with
            # get_parameter() raises AttributeError for buffer keys (such as
            # the attention mask "bias") and would abort start-up.
            print("Warning: Missing persistent keys during state dict load:", list(missing_keys))
        if unexpected_keys:
            print("Warning: Unexpected keys during state dict load:", unexpected_keys)
        print("Weights loaded successfully.")
    except Exception as e:
        print(f"Error loading state dict from {FINETUNED_MODEL_PATH}: {e}")
        print("Ensure the model architecture matches the saved checkpoint and the file is not corrupted.")
        traceback.print_exc()
        return False

    try:
        inference_model.to(model_device)
        inference_model.eval()
        print(f"Model moved to device: {model_device} and set to eval mode.")
    except Exception as e:
        print(f"Error moving model to device '{model_device}': {e}")
        traceback.print_exc()
        return False

    print("Building allowed token mask...")
    try:
        if config.vocab_size <= 0:
            raise ValueError("Vocabulary size must be positive to build mask.")
        device_obj = torch.device(model_device)
        ALLOWED_MASK = build_allowed_token_mask(config.vocab_size, device_obj)
        print("Allowed token mask created.")
    except Exception as e:
        print(f"ERROR: Failed to build allowed token mask: {e}")
        traceback.print_exc()
        return False
    return True
# --- Inference Function ---
def predict_hallmarks(abstract_text):
    """Predict hallmark labels for one abstract.

    Steps: BPE-encode the text, append the EOS ID, run constrained generation,
    cut the output at the first EOS, decode it, split on ``;``, ``,`` or ``|``
    and map each piece to a canonical hallmark name (duplicates removed, order
    of first appearance kept).

    Returns:
        A list of canonical labels (possibly empty). When something is not
        ready or an error occurs, a single-element list whose string starts
        with ``"Error:"`` is returned instead of raising.
    """
    global inference_model, ALLOWED_MASK, model_device, token2id, eos_id

    # Everything produced at start-up must be present.
    if inference_model is None:
        print("Error: Inference model is not loaded.")
        return ["Error: Model not loaded"]
    if ALLOWED_MASK is None:
        print("Error: Allowed mask is not built.")
        return ["Error: Mask not built"]
    if not token2id:
        print("Error: Tokenizer vocabulary not loaded.")
        return ["Error: Tokenizer not loaded"]

    # Coerce non-string input, then reject blank text.
    if not isinstance(abstract_text, str):
        print(f"Warning: Received non-string abstract text type: {type(abstract_text)}. Attempting conversion.")
        try:
            abstract_text = str(abstract_text)
        except Exception:
            return ["Error: Invalid input type"]
    if not abstract_text.strip():
        print("Warning: Received empty or whitespace-only abstract text.")
        return []

    input_ids = None  # stays None until tokenization succeeds (used in the OOM message)
    try:
        # --- 1. Preprocess and tokenize ---
        print("Tokenizing input abstract...")
        cleaned_abstract = " ".join(abstract_text.split())
        if not cleaned_abstract:
            print("Warning: Input abstract contains only whitespace after cleaning.")
            return []
        encoded_lines = bpe_encode_lines([cleaned_abstract])
        if not encoded_lines or not encoded_lines[0]:
            print("Warning: BPE encoding resulted in empty tokens.")
            return []
        token_ids, oov = tokens_to_ids(encoded_lines[0])
        if oov > 0:
            print(f"Info: Input contained {oov} OOV tokens.")
        input_ids = token_ids + [eos_id]
        input_tensor = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(model_device)

        # --- 2. Generate ---
        print("Generating predictions...")
        with torch.no_grad():
            generated_ids_tensor = inference_model.generate_labels(
                input_tensor,
                allowed_mask=ALLOWED_MASK,
                max_new_tokens=30,
                temperature=0.0
            )

        # --- 3. Decode and post-process ---
        print("Decoding and cleaning predictions...")
        if generated_ids_tensor is None or generated_ids_tensor.numel() == 0:
            print("Warning: Generation resulted in empty tensor.")
            generated_ids = []
        else:
            generated_ids = generated_ids_tensor[0].cpu().tolist()
        if not generated_ids:
            print("No tokens generated.")
            return []
        generated_tokens = ids_to_tokens(generated_ids)

        # Keep only what precedes the first EOS token, if there is one.
        eos_token_str = id2token.get(eos_id, "</s>")
        if eos_token_str in generated_tokens:
            generated_tokens = generated_tokens[:generated_tokens.index(eos_token_str)]

        generated_text = bpe_decode_tokens(generated_tokens).strip().lower()
        print(f"Raw generated text: '{generated_text}'")

        # Split a possible multi-label answer into its pieces.
        parts = []
        if generated_text:
            parts = [
                piece.strip()
                for piece in re.split(r'[;,|]\s*', generated_text)
                if piece.strip()
            ]
            if not parts:
                parts = [generated_text]

        # A dict keeps first-seen order while dropping duplicates.
        unique_labels = {}
        for piece in parts:
            label = to_canonical(piece)
            if label:
                unique_labels.setdefault(label, None)
        predicted_labels = list(unique_labels)
        print(f"Final predicted labels: {predicted_labels}")
        return predicted_labels

    except FileNotFoundError as fnf_err:
        print(f"ERROR during prediction (File Not Found - likely BPE related): {fnf_err}")
        traceback.print_exc()
        return ["Error: BPE file processing error"]
    except PermissionError as perm_err:
        print(f"ERROR during prediction (Permission Error - likely fastBPE): {perm_err}")
        traceback.print_exc()
        return ["Error: BPE execution permission"]
    except RuntimeError as run_err:
        if "CUDA out of memory" in str(run_err):
            print(f"ERROR: CUDA Out of Memory during prediction. Input length: {len(input_ids) if input_ids is not None else 'N/A'}")
            traceback.print_exc()
            return ["Error: Input too long (OOM)"]
        print(f"ERROR during prediction (PyTorch RuntimeError): {run_err}")
        traceback.print_exc()
        return ["Error: Model runtime error"]
    except Exception as e:
        print(f"ERROR during prediction (General Exception): {e}")
        traceback.print_exc()
        return ["Error: An unexpected error occurred"]
# --- Flask App ---
# WSGI application object; the route and hook decorators below register on it.
app = Flask(__name__)
# --- Load Model on Startup ---
# Set to True by ensure_model_loaded once initialize_model_and_tokenizer succeeds.
model_initialized = False
import threading  # kept next to its only user: the lazy-initialization lock below

# Serializes lazy model initialization so that concurrent first requests
# (when the server handles requests in multiple threads) load the model once.
# NOTE: this only protects threads inside one process; with several worker
# processes each process still performs its own load.
_model_init_lock = threading.Lock()


@app.before_request
def ensure_model_loaded():
    """Lazily initialize the model and tokenizer before a request is handled.

    Runs before every request. Once ``model_initialized`` is True it returns
    immediately. Otherwise it calls ``initialize_model_and_tokenizer()`` while
    holding a lock and sets the flag on success.

    A failed or raising initialization is logged but not propagated: the
    request continues, the flag stays False, and the next request retries.

    Returns:
        None, so Flask proceeds to the normal view function.
    """
    global model_initialized
    if model_initialized:
        return

    with _model_init_lock:
        # Re-check under the lock: another thread may have completed the
        # initialization while this one was waiting to acquire it.
        if model_initialized:
            return

        print("First request received, attempting to initialize model...")
        try:
            if initialize_model_and_tokenizer():
                model_initialized = True
                print("Model initialization successful.")
            else:
                print("FATAL: Model initialization failed during first request.")
                # Deliberately no raise: later requests retry, and /predict
                # reports the not-ready state itself.
        except Exception as init_err:
            print(f"FATAL: Exception during model initialization: {init_err}")
            traceback.print_exc()
# --- Routes ---
@app.route('/')
def home():
    """Serve the HTML frontend by rendering the ``index.html`` template."""
    # No model-readiness check is made here; the page is rendered regardless
    # of whether initialization succeeded.
    page = render_template('index.html')
    return page
@app.route('/predict', methods=['POST'])
def predict():
    """Handle a prediction request from the frontend.

    Expects a JSON object body containing a string ``abstract`` field.

    Returns:
        A JSON response, optionally paired with a status code:

        * ``{'predictions': [...]}`` on success (an empty list for a blank
          abstract);
        * ``{'error': ...}`` with 400 for a malformed body, a non-object
          body, or a missing / non-string ``abstract``;
        * ``{'error': ...}`` with 413 when the abstract exceeds the length cap;
        * ``{'error': ...}`` with 503 when the model is not initialized;
        * ``{'error': ...}`` with 500 when prediction fails internally.
    """
    # The flag is only read here, so no ``global`` declaration is required.
    if not model_initialized:
        print("Error: Model not initialized when /predict called.")
        # Service Unavailable: the client may retry later.
        return jsonify({'error': 'Model is not ready. Please try again later.'}), 503

    # Validate request format
    if not request.is_json:
        return jsonify({'error': 'Request must be JSON'}), 400

    # silent=True makes a malformed body come back as None instead of raising,
    # so the client always receives this API's JSON error shape.
    data = request.get_json(silent=True)
    # Valid JSON is not necessarily an object (null, list, string, number);
    # calling .get() on those would raise and surface as an unhandled 500.
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    abstract = data.get('abstract')

    # Validate input abstract. Order matters: test for absence first, then
    # type, then blankness, so falsy non-strings get the type error and an
    # empty string is handled the same way as whitespace-only input.
    if abstract is None:
        return jsonify({'error': 'Missing "abstract" field in JSON request'}), 400
    if not isinstance(abstract, str):
        return jsonify({'error': '"abstract" field must be a string'}), 400
    if not abstract.strip():
        print("Received empty abstract, returning empty prediction.")
        return jsonify({'predictions': []})

    MAX_ABSTRACT_LEN = 10000  # Maximum accepted abstract length, in characters
    if len(abstract) > MAX_ABSTRACT_LEN:
        print(f"Received overly long abstract ({len(abstract)} chars), rejecting.")
        # 413 Payload Too Large
        return jsonify({'error': f'Input abstract is too long (max {MAX_ABSTRACT_LEN} chars)'}), 413

    print("\n--- Received Prediction Request ---")
    print(f"Input Abstract (first 100 chars): {abstract[:100]}...")

    try:
        predictions = predict_hallmarks(abstract)
        print("--- Prediction Complete ---")

        # predict_hallmarks signals internal failures in-band, as a list
        # whose first element is a string starting with "Error:".
        failed = (
            isinstance(predictions, list)
            and len(predictions) > 0
            and isinstance(predictions[0], str)
            and predictions[0].startswith("Error:")
        )
        if failed:
            print(f"Internal error during prediction: {predictions[0]}")
            # Keep internal details out of the client-facing message.
            return jsonify({'error': 'An internal error occurred during prediction.'}), 500

        return jsonify({'predictions': predictions})
    except Exception as e:
        # Route-level boundary: log everything, answer with a generic 500.
        print("--- Prediction Failed Unexpectedly in Route ---")
        print(f"Error: {e}")
        traceback.print_exc()
        return jsonify({'error': 'An internal server error occurred.'}), 500
# --- Run the App ---
def _ensure_fastbpe_executable(bin_path):
    """Verify the fastBPE binary exists and is executable, fixing permissions if needed.

    Args:
        bin_path: Path to the fastBPE executable. A relative path is resolved
            with ``os.path.abspath`` (i.e. against the current working
            directory).

    Returns:
        True when a regular file exists at the path and is executable, or
        was successfully given mode 0o755. False when the file is missing,
        is not a regular file, or its permissions could not be changed.
        Diagnostics are printed in every failure case.
    """
    abs_path = os.path.abspath(bin_path)

    # isfile rather than exists: a directory at this path is not a usable binary.
    if not os.path.isfile(abs_path):
        print(f"ERROR: fastBPE binary not found at '{abs_path}'.")
        print("Please ensure fastBPE is compiled and the path is correct relative to app.py.")
        return False

    if os.access(abs_path, os.X_OK):
        return True

    print(f"ERROR: fastBPE binary at '{abs_path}' is not executable.")
    print("Attempting to make it executable with 'chmod +x'...")
    try:
        os.chmod(abs_path, 0o755)
    except OSError as e:
        print(f"ERROR: Failed to make fastBPE executable: {e}")
        # Point at the actual file instead of a hard-coded (stale) name.
        print(f"Please set execute permissions manually (e.g., 'chmod +x {abs_path}').")
        return False

    print(f"Successfully made '{abs_path}' executable.")
    return True


if __name__ == '__main__':
    # Initialize model eagerly when running the script directly, so a broken
    # model is reported at startup rather than on the first request.
    if not model_initialized:
        print("Running script directly, initializing model eagerly...")
        if not initialize_model_and_tokenizer():
            print("FATAL: Model initialization failed. Cannot start Flask server.")
            # SystemExit instead of exit(): exit() is an interactive helper
            # installed by the site module and is not always available.
            raise SystemExit(1)
        model_initialized = True
        print("Model initialization successful.")

    # Check fastBPE path validity before starting the server.
    if not _ensure_fastbpe_executable(FASTBPE_BIN_PATH):
        raise SystemExit(1)

    print("Starting Flask server...")
    # host='0.0.0.0' makes the server reachable from other machines on the
    # network; debug=False is the setting for production environments.
    app.run(host='0.0.0.0', port=5000, debug=False)