import os

import torch
from flask import Blueprint, request, jsonify, render_template
from transformers import pipeline
from huggingface_hub import HfFolder

# Define the Blueprint
summarize_bp = Blueprint('summarize', __name__)

# Global cache to store the loaded model in memory.
# This prevents reloading the model on every single request.
MODEL_CACHE = {
    "model_name": None,
    "pipeline": None
}
def get_pipeline(model_name, task_type, hf_token):
    """
    Retrieves a pipeline from the cache, or loads it if it is new.
    """
    global MODEL_CACHE

    # If this model is already loaded, reuse it
    if MODEL_CACHE["model_name"] == model_name and MODEL_CACHE["pipeline"] is not None:
        return MODEL_CACHE["pipeline"]

    # Authentication: persist the token so gated/private models can be downloaded
    if hf_token:
        HfFolder.save_token(hf_token)

    # Determine device and precision (GPU with fp16 when available, otherwise CPU)
    device = 0 if torch.cuda.is_available() else -1
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    # Load the pipeline for the requested task
    print(f"Loading model: {model_name}...")
    if task_type == 'SEQ_2_SEQ_LM':  # Summarization
        pipe = pipeline("summarization", model=model_name, device=device)
    elif task_type == 'TOKEN_CLS':  # Token classification (e.g. NER)
        pipe = pipeline("token-classification", model=model_name,
                        aggregation_strategy="simple", device=device)
    else:  # Causal text generation
        pipe = pipeline("text-generation", model=model_name, torch_dtype=dtype, device=device)

    # Update the cache
    MODEL_CACHE["model_name"] = model_name
    MODEL_CACHE["pipeline"] = pipe
    return pipe
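
# Illustrative usage of the cache (not part of the original file): repeated
# calls with the same model name reuse the in-memory pipeline, while a new
# model name triggers a fresh load that replaces the cached one.
#
#   pipe_a = get_pipeline("facebook/bart-large-cnn", "SEQ_2_SEQ_LM", hf_token="")
#   pipe_b = get_pipeline("facebook/bart-large-cnn", "SEQ_2_SEQ_LM", hf_token="")   # cache hit, same object
#   pipe_c = get_pipeline("sshleifer/distilbart-cnn-12-6", "SEQ_2_SEQ_LM", hf_token="")  # reload, evicts the previous model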


def run_inference_logic(config):
    model_id = config['model_name']
    text = config['text']
    task_type = config['task_type']
    hf_token = config['hf_token']

    pipe = get_pipeline(model_id, task_type, hf_token)

    if task_type == 'TOKEN_CLS':
        results = pipe(text)
        # Replace each detected entity span with its label. Entities are
        # processed from the highest start offset downwards so that earlier
        # character offsets remain valid after each replacement.
        results_sorted = sorted(results, key=lambda x: x['start'], reverse=True)
        masked_list = list(text)
        for ent in results_sorted:
            masked_list[ent['start']:ent['end']] = list(f"<{ent['entity_group']}>")
        return {
            "masked_text": "".join(masked_list),
            "labels": [r['entity_group'] for r in results]
        }
    elif task_type == 'SEQ_2_SEQ_LM':
        out = pipe(
            text,
            max_length=512,
            min_length=30,
            do_sample=True,
            temperature=float(config['temp']),
            top_k=int(config['topk'])
        )
        return {"output": out[0]['summary_text']}
    else:
        out = pipe(text, max_new_tokens=1024)
        return {"output": out[0]['generated_text']}
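
# Worked example of the TOKEN_CLS masking above (illustrative, not part of the
# original file; assumes a standard NER model whose aggregated output uses
# groups such as PER and LOC):
#
#   text    = "Maria moved to Berlin in 2019."
#   results = [{"entity_group": "PER", "start": 0,  "end": 5},
#              {"entity_group": "LOC", "start": 15, "end": 21}]
#   masked  = "<PER> moved to <LOC> in 2019."
#
# Because "Berlin" (start 15) is replaced before "Maria" (start 0), the earlier
# offsets are still correct when their turn comes.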

# --- Routes ---
# NOTE: the route decorators below are assumed; the original file does not show
# how these views are registered on the blueprint.

@summarize_bp.route('/')
def index():
    """Renders the UI."""
    return render_template('inference.html')


@summarize_bp.route('/api/summarize', methods=['POST'])
def api_summarize():
    """API endpoint that handles the AJAX request from the UI."""
    data = request.get_json()
    if not data or 'text' not in data:
        return jsonify({"error": "No text provided"}), 400

    hf_token = os.environ.get("HF_TOKEN", "")
    config = {
        "text": data['text'],
        "model_name": data.get('model_name', "facebook/bart-large-cnn"),
        "hf_token": data.get('hf_token', hf_token),
        "temp": data.get('temp', '0.7'),
        "topk": data.get('topk', '50'),
        # Forced for this specific summarization UI,
        # but the backend logic supports other task types.
        "task_type": "SEQ_2_SEQ_LM"
    }

    try:
        result = run_inference_logic(config)
        return jsonify(result)
    except Exception as e:
        print(f"Error: {e}")
        return jsonify({"error": str(e)}), 500
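

# --- Local smoke test (illustrative) ---
# A minimal sketch of how this blueprint might be wired up and exercised,
# assuming the '/api/summarize' path registered above and a 'templates'
# folder containing inference.html. Not part of the original file, and it
# triggers a real model download/inference, so it needs network access.
if __name__ == "__main__":
    from flask import Flask

    app = Flask(__name__)
    app.register_blueprint(summarize_bp)

    # Use Flask's test client to hit the endpoint without starting a server.
    with app.test_client() as client:
        resp = client.post(
            "/api/summarize",
            json={"text": "Long article text to summarize...", "temp": "0.7", "topk": "50"},
        )
        print(resp.status_code, resp.get_json())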