import torch
import transformers
import tensorflow as tf
from pathlib import Path
from transformers import PreTrainedModel, AutoModelForCausalLM
from huggingface_hub import hf_hub_download


class SafeGenerationModel(PreTrainedModel):
    """
    Model-agnostic toxicity-filter wrapper.

    Instantiates the correct backbone for ANY causal-LM config, then
    intercepts generate() to filter toxic prompts & completions.
    """

    # "base_model" is a read-only property on PreTrainedModel, so the backbone
    # is stored under "backbone" and exposed via base_model_prefix instead.
    base_model_prefix = "backbone"

    # ------------------------------------------------------------------
    # A. Standard constructor
    # ------------------------------------------------------------------
    def __init__(self, config, *model_args, **model_kwargs):
        super().__init__(config)

        # 1) Dynamically build the *real* model class that matches this config
        self.backbone = AutoModelForCausalLM.from_config(
            config, trust_remote_code=True
        )

        # 2) Lazy-load toxicity classifier (loaded on first use)
        self._toxicity_model = None
        self.toxicity_threshold = 0.6

        # 3) Tokenizer (needed for prompt/output decoding)
        try:
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                config.name_or_path, trust_remote_code=True
            )
        except Exception:
            self.tokenizer = None

    # ------------------------------------------------------------------
    # B. Internal helpers
    # ------------------------------------------------------------------
    @property
    def toxicity_model(self):
        if self._toxicity_model is None:
            path = hf_hub_download(
                repo_id=self.config.name_or_path, filename="toxic.keras"
            )
            self._toxicity_model = tf.keras.models.load_model(path, compile=False)
        return self._toxicity_model

    def _is_toxic(self, text: str) -> bool:
        if not text.strip():
            return False
        score = float(self.toxicity_model.predict([text])[0, 0])
        return score >= self.toxicity_threshold

    def _safe_ids(self, msg: str, length=None):
        if self.tokenizer is None:
            raise RuntimeError("Tokenizer missing; cannot build safe reply.")
        ids = self.tokenizer(msg, return_tensors="pt")["input_ids"][0]
        if length is not None:
            pad = (
                self.config.eos_token_id
                if self.config.eos_token_id is not None
                else (self.config.pad_token_id or 0)
            )
            if ids.size(0) < length:
                ids = torch.cat(
                    [ids, torch.full((length - ids.size(0),), pad, dtype=torch.long)],
                    dim=0,
                )
            else:
                ids = ids[:length]
        return ids.to(self.device)

    # ------------------------------------------------------------------
    # C. Forward simply proxies to backbone
    # ------------------------------------------------------------------
    def forward(self, *args, **kwargs):
        return self.backbone(*args, **kwargs)

    # ------------------------------------------------------------------
    # D. generate() override with toxicity checks
    # ------------------------------------------------------------------
    def generate(self, *args, **kwargs):
        SAFE_MSG = "Response is toxic, please be kind to yourself and others."

        # ---------- 1. Check prompt ----------
        prompt_text = None
        if "input_ids" in kwargs and self.tokenizer is not None:
            prompt_text = self.tokenizer.decode(
                kwargs["input_ids"][0], skip_special_tokens=True
            )
        elif args and self.tokenizer is not None:
            prompt_text = self.tokenizer.decode(
                args[0][0], skip_special_tokens=True
            )
        if prompt_text and self._is_toxic(prompt_text):
            return self._safe_ids(SAFE_MSG).unsqueeze(0)

        # ---------- 2. Normal generation ----------
        outputs = self.backbone.generate(*args, **kwargs)
        if self.tokenizer is None:
            return outputs  # cannot decode → skip toxicity check

        # Assumes generate() returned a plain LongTensor of token ids.
        outputs_cpu = outputs.detach().cpu()
        safe = []
        for seq in outputs_cpu:
            txt = self.tokenizer.decode(seq, skip_special_tokens=True)
            if self._is_toxic(txt):
                # Keep the replacement on CPU so it stacks with outputs_cpu.
                safe.append(self._safe_ids(SAFE_MSG, length=seq.size(0)).cpu())
            else:
                safe.append(seq)
        return torch.stack(safe, dim=0).to(self.device)
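

# ----------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the wrapper). The repo id
# "your-org/safe-model" is hypothetical: it is assumed to host a causal-LM
# checkpoint plus the "toxic.keras" classifier the wrapper downloads lazily.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoConfig, AutoTokenizer

    repo_id = "your-org/safe-model"  # hypothetical repo id, swap in your own

    # The wrapper declares no config_class of its own, so the config is
    # loaded explicitly and handed to from_pretrained().
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    model = SafeGenerationModel.from_pretrained(repo_id, config=config)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    inputs = tokenizer("Say something encouraging.", return_tensors="pt")

    # Both the prompt and each generated sequence pass through the toxicity
    # classifier; anything flagged comes back as the canned safe reply.
    output_ids = model.generate(input_ids=inputs["input_ids"], max_new_tokens=40)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))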