import os

import torch
from flask import Blueprint, request, jsonify, render_template
from transformers import pipeline
from huggingface_hub import HfFolder

# Define the Blueprint
summarize_bp = Blueprint('summarize', __name__)

# Global cache to store the loaded model in memory.
# This prevents reloading the model on every single request.
MODEL_CACHE = {
    "model_name": None,
    "pipeline": None
}


def get_pipeline(model_name, task_type, hf_token):
    """Retrieves a pipeline from the cache, or loads and caches it if it's new."""
    global MODEL_CACHE

    # If we already have this model loaded, return it.
    if MODEL_CACHE["model_name"] == model_name and MODEL_CACHE["pipeline"] is not None:
        return MODEL_CACHE["pipeline"]

    # Authentication (needed for gated or private models).
    if hf_token:
        HfFolder.save_token(hf_token)

    # Determine device: first CUDA GPU if available, otherwise CPU.
    device = 0 if torch.cuda.is_available() else -1
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    # Load the pipeline for the requested task.
    print(f"Loading model: {model_name}...")
    if task_type == 'SEQ_2_SEQ_LM':
        # Summarization
        pipe = pipeline("summarization", model=model_name, device=device)
    elif task_type == 'TOKEN_CLS':
        pipe = pipeline(
            "token-classification",
            model=model_name,
            aggregation_strategy="simple",
            device=device
        )
    else:
        pipe = pipeline("text-generation", model=model_name, torch_dtype=dtype, device=device)

    # Update the cache.
    MODEL_CACHE["model_name"] = model_name
    MODEL_CACHE["pipeline"] = pipe
    return pipe


def run_inference_logic(config):
    model_id = config['model_name']
    text = config['text']
    task_type = config['task_type']
    hf_token = config['hf_token']

    pipe = get_pipeline(model_id, task_type, hf_token)

    if task_type == 'TOKEN_CLS':
        results = pipe(text)
        # Replace each detected entity span with its label tag, working backwards
        # through the text so earlier character offsets stay valid while splicing.
        results_sorted = sorted(results, key=lambda x: x['start'], reverse=True)
        masked_list = list(text)
        for ent in results_sorted:
            masked_list[ent['start']:ent['end']] = list(f"<{ent['entity_group']}>")
        return {
            "masked_text": "".join(masked_list),
            "labels": [r['entity_group'] for r in results]
        }
    elif task_type == 'SEQ_2_SEQ_LM':
        out = pipe(
            text,
            max_length=512,
            min_length=30,
            do_sample=True,
            temperature=float(config['temp']),
            top_k=int(config['topk'])
        )
        return {"output": out[0]['summary_text']}
    else:
        out = pipe(text, max_new_tokens=1024)
        return {"output": out[0]['generated_text']}


# --- Routes ---

@summarize_bp.route('/', methods=['GET'])
def index():
    """Renders the UI."""
    return render_template('inference.html')


@summarize_bp.route('/api/summarize', methods=['POST'])
def api_summarize():
    """API endpoint to handle the AJAX request from the UI."""
    data = request.get_json(silent=True)
    if not data or 'text' not in data:
        return jsonify({"error": "No text provided"}), 400

    hf_token = os.environ.get("HF_TOKEN", "")

    config = {
        "text": data['text'],
        "model_name": data.get('model_name', "facebook/bart-large-cnn"),
        "hf_token": data.get('hf_token', hf_token),
        "temp": data.get('temp', '0.7'),
        "topk": data.get('topk', '50'),
        # We force this for the specific summarization UI,
        # but the backend logic supports others.
        "task_type": "SEQ_2_SEQ_LM"
    }

    try:
        result = run_inference_logic(config)
        return jsonify(result)
    except Exception as e:
        print(f"Error: {e}")
        return jsonify({"error": str(e)}), 500
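

# --- Usage sketch (illustrative only) ---
# A minimal sketch of how this blueprint might be mounted in a Flask app and the
# JSON endpoint exercised with Flask's built-in test client. The URL prefix and
# the sample payload below are assumptions for illustration, not part of the
# blueprint itself; running it will download the default model on first use.
if __name__ == "__main__":
    from flask import Flask

    app = Flask(__name__)
    # Mount the blueprint at an assumed prefix; adjust to match the real app.
    app.register_blueprint(summarize_bp, url_prefix="/summarize")

    # Exercise the API without starting a server.
    with app.test_client() as client:
        resp = client.post(
            "/summarize/api/summarize",
            json={
                "text": "Long article text to be summarized goes here...",
                "model_name": "facebook/bart-large-cnn",
                "temp": "0.7",
                "topk": "50",
            },
        )
        print(resp.status_code, resp.get_json())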