import asyncio
import logging
from datetime import datetime
from typing import Dict, List, Optional, Union

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
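
# Assumed runtime dependencies (not pinned in the original source):
#   pip install torch transformers gradio
# The EXAONE checkpoint below is downloaded from the Hugging Face Hub on first run.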


class ContentAnalyzer:
    """Flags sensitive-content triggers in scripts using an LLM as a classifier."""

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self.batch_size = 2  # Small batches leave headroom for long reasoning traces
        self.max_thinking_time = 30  # Hard cap in seconds on generation per batch
        self.trigger_categories = {
            "Violence": {
                "mapped_name": "Violence",
                "description": "Physical force, aggression, or actions causing harm to living beings or property."
            },
            "Death": {
                "mapped_name": "Death References",
                "description": "Direct or implied loss of life, mortality discussions, or death-related events."
            },
            "Substance_Use": {
                "mapped_name": "Substance Use",
                "description": "Usage or discussion of drugs, alcohol, or addictive substances."
            },
            "Gore": {
                "mapped_name": "Gore",
                "description": "Graphic depictions of injuries, blood, or severe bodily harm."
            },
            "Sexual_Content": {
                "mapped_name": "Sexual Content",
                "description": "Sexual activity, intimacy, or explicit sexual references."
            },
            "Sexual_Abuse": {
                "mapped_name": "Sexual Abuse",
                "description": "Non-consensual sexual acts, exploitation, or sexual violence."
            },
            "Self_Harm": {
                "mapped_name": "Self-Harm",
                "description": "Self-inflicted injury, suicidal thoughts, or destructive behaviors."
            },
            "Mental_Health": {
                "mapped_name": "Mental Health Issues",
                "description": "Psychological distress, mental disorders, or emotional trauma."
            }
        }
        logger.info(f"Initialized analyzer with device: {self.device}")
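
    # Hypothetical example of extending the category table; new entries just
    # need the same two keys:
    #   "Animal_Harm": {"mapped_name": "Animal Harm",
    #                   "description": "Cruelty toward or death of animals."}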

    async def load_model(self, progress=None) -> None:
        """Load the model and tokenizer with progress updates."""
        try:
            if progress:
                progress(0.1, "Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                "LGAI-EXAONE/EXAONE-Deep-2.4B",
                use_fast=True,
                trust_remote_code=True
            )
            # Decoder-only models need a pad token and left padding for batched
            # generation; right padding corrupts the completions that follow
            # the shorter prompts in a batch.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.padding_side = "left"
            if progress:
                progress(0.3, "Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                "LGAI-EXAONE/EXAONE-Deep-2.4B",
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto",
                trust_remote_code=True
            )
            self.model.eval()  # Inference only, on CPU or GPU
            if self.device == "cuda":
                torch.cuda.empty_cache()
            if progress:
                progress(0.5, "Model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise
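
    # Rough sizing note: fp16 stores about 2 bytes per parameter, so the
    # 2.4B-parameter checkpoint needs on the order of 5 GB of GPU memory
    # before activation and KV-cache overhead.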

    def _chunk_text(self, text: str, chunk_size: int = 300, overlap: int = 50) -> List[str]:
        """Split text into overlapping word-level chunks.

        Defaults are sized so a chunk fits the 512-token prompt window used at
        tokenization time; the original 20,000-word chunks were silently
        truncated to that window, discarding almost all of their text.
        """
        words = text.split()
        if not words:
            return [""]
        step = max(1, chunk_size - overlap)
        chunks = []
        for i in range(0, len(words), step):
            chunks.append(' '.join(words[i:i + chunk_size]))
        return chunks
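
    # Overlap illustration (hypothetical 10-word text, chunk_size=4,
    # overlap=1, so step=3): chunks cover words 0-3, 3-6, 6-9, and 9; each
    # boundary word lands in two chunks and no text is dropped.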

    def _validate_response(self, response: str) -> str:
        """Extract a YES/NO/MAYBE verdict from the model's completion.

        EXAONE-Deep emits a reasoning trace before its final answer, so the
        verdict is rarely the first word; scan the whole completion and keep
        the last valid answer seen, defaulting to NO.
        """
        valid_responses = {"YES", "NO", "MAYBE"}
        verdict = "NO"
        for word in response.upper().split():
            token = word.strip('."\'!,:;')
            if token in valid_responses:
                verdict = token
        return verdict
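
    # e.g. "...the scene implies a fight. Final answer: YES." -> "YES";
    # an empty or off-format completion falls back to "NO".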

    def _generate_outputs(self, inputs):
        """Run generation under torch.no_grad().

        This is a blocking call; callers run it via asyncio.to_thread so that
        asyncio.wait_for can actually enforce a timeout. Awaiting a coroutine
        that never yields would make the timeout a no-op.
        """
        with torch.no_grad():
            return self.model.generate(
                **inputs,
                max_new_tokens=500,
                temperature=0.3,   # Lower temperature for more focused responses
                top_p=0.95,        # Slightly higher to keep valid answer tokens in play
                top_k=10,          # Restrict sampling to the most likely tokens
                pad_token_id=self.tokenizer.pad_token_id,
                do_sample=True     # Keep sampling for slight variation
            )

    async def analyze_chunks_batch(
        self,
        chunks: List[str],
        progress: Optional[gr.Progress] = None,
        current_progress: float = 0,
        progress_step: float = 0
    ) -> Dict[str, float]:
        """Score every chunk against every trigger category, in batches."""
        all_triggers = {}
        for info in self.trigger_categories.values():
            mapped_name = info["mapped_name"]
            description = info["description"]
            for i in range(0, len(chunks), self.batch_size):
                batch_chunks = chunks[i:i + self.batch_size]
                prompts = [
                    f"Analyze text for {mapped_name}. Definition: {description}. "
                    f"Content: \"{chunk}\". Answer YES/NO/MAYBE based on clear evidence."
                    for chunk in batch_chunks
                ]
                try:
                    inputs = self.tokenizer(
                        prompts,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=512
                    ).to(self.device)
                    # Run the blocking generate() in a worker thread so the
                    # timeout can fire (the thread itself is not cancelled).
                    outputs = await asyncio.wait_for(
                        asyncio.to_thread(self._generate_outputs, inputs),
                        timeout=self.max_thinking_time
                    )
                    # Decode only the newly generated tokens; decoding the full
                    # sequence would prepend the prompt to every answer and the
                    # validator would read prompt text instead of the verdict.
                    prompt_length = inputs["input_ids"].shape[1]
                    responses = [
                        self.tokenizer.decode(output[prompt_length:], skip_special_tokens=True)
                        for output in outputs
                    ]
                    for response in responses:
                        validated_response = self._validate_response(response)
                        if validated_response == "YES":
                            all_triggers[mapped_name] = all_triggers.get(mapped_name, 0) + 1
                        elif validated_response == "MAYBE":
                            all_triggers[mapped_name] = all_triggers.get(mapped_name, 0) + 0.5
                except asyncio.TimeoutError:
                    logger.error(f"Timeout processing batch for {mapped_name}")
                    continue
                except Exception as e:
                    logger.error(f"Error processing batch for {mapped_name}: {str(e)}")
                    continue
                if progress:
                    current_progress += progress_step
                    progress(min(current_progress, 0.9), f"Analyzing {mapped_name}...")
        return all_triggers

    async def analyze_script(self, script: str, progress: Optional[gr.Progress] = None) -> List[str]:
        """Analyze the entire script and return the detected trigger names."""
        if not self.model or not self.tokenizer:
            await self.load_model(progress)
        chunks = self._chunk_text(script)
        identified_triggers = await self.analyze_chunks_batch(
            chunks,
            progress,
            current_progress=0.5,
            progress_step=0.4 / (len(chunks) * len(self.trigger_categories))
        )
        if progress:
            progress(0.95, "Finalizing results...")
        # Report a trigger only if it was flagged in at least 10% of chunks
        # (and at least once); MAYBE answers count as half a hit.
        final_triggers = []
        chunk_threshold = max(1, len(chunks) * 0.1)
        for mapped_name, count in identified_triggers.items():
            if count >= chunk_threshold:
                final_triggers.append(mapped_name)
        return final_triggers if final_triggers else ["None"]
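
    # Worked example: a 12-chunk script gives chunk_threshold = max(1, 1.2) = 1.2,
    # so a category is reported after two YES hits (2.0) or a YES plus a MAYBE
    # (1.5), but not after a single YES (1.0).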


async def analyze_content(
    script: str,
    progress: Optional[gr.Progress] = None
) -> Dict[str, Union[List[str], str]]:
    """Main analysis function for the Gradio interface.

    Note: with a default of None, Gradio does not inject its live progress
    tracker; use progress=gr.Progress() as the default to enable it.
    """
    logger.info("Starting content analysis")
    analyzer = ContentAnalyzer()
    try:
        triggers = await analyzer.analyze_script(script, progress)
        if progress:
            progress(1.0, "Analysis complete!")
        result = {
            "detected_triggers": triggers,
            "confidence": "High - Content detected" if triggers != ["None"] else "High - No concerning content detected",
            "model": "LGAI-EXAONE/EXAONE-Deep-2.4B",
            "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }
        logger.info(f"Analysis complete: {result}")
        return result
    except Exception as e:
        logger.error(f"Analysis error: {str(e)}")
        return {
            "detected_triggers": ["Error occurred during analysis"],
            "confidence": "Error",
            "model": "LGAI-EXAONE/EXAONE-Deep-2.4B",
            "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "error": str(e)
        }
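
# Quick standalone check without the UI (hypothetical snippet; downloads the
# model on first use):
#   result = asyncio.run(analyze_content("He raised the knife and lunged."))
#   print(result["detected_triggers"])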


if __name__ == "__main__":
    iface = gr.Interface(
        fn=analyze_content,  # Gradio runs async handlers natively
        inputs=gr.Textbox(lines=8, label="Input Text"),
        outputs=gr.JSON(),
        title="Content Trigger Analysis",
        description="Analyze text content for sensitive topics and trigger warnings"
    )
    iface.launch()