import os
import torch
from flask import Blueprint, request, jsonify, render_template
from transformers import pipeline
from huggingface_hub import HfFolder
# Define the Blueprint
summarize_bp = Blueprint('summarize', __name__)
# Global cache to store the loaded model in memory
# This prevents reloading the model on every single request
MODEL_CACHE = {
    "model_name": None,
    "task_type": None,
    "pipeline": None
}
def get_pipeline(model_name, task_type, hf_token):
"""
Retrieves a pipeline from cache or loads it if it's new.
"""
global MODEL_CACHE
    # If this model is already loaded for the same task, reuse it
    if (MODEL_CACHE["model_name"] == model_name
            and MODEL_CACHE["task_type"] == task_type
            and MODEL_CACHE["pipeline"] is not None):
        return MODEL_CACHE["pipeline"]
# Authentication
if hf_token:
HfFolder.save_token(hf_token)
# Determine device
device = 0 if torch.cuda.is_available() else -1
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Load Pipeline
print(f"Loading model: {model_name}...")
if task_type == 'SEQ_2_SEQ_LM': # Summarization
pipe = pipeline("summarization", model=model_name, device=device)
    elif task_type == 'TOKEN_CLS':  # Token classification (e.g. NER)
        pipe = pipeline("token-classification", model=model_name, aggregation_strategy="simple", device=device)
else:
pipe = pipeline("text-generation", model=model_name, torch_dtype=dtype, device=device)
    # Update cache
    MODEL_CACHE["model_name"] = model_name
    MODEL_CACHE["task_type"] = task_type
    MODEL_CACHE["pipeline"] = pipe
return pipe
def run_inference_logic(config):
model_id = config['model_name']
text = config['text']
task_type = config['task_type']
hf_token = config['hf_token']
pipe = get_pipeline(model_id, task_type, hf_token)
if task_type == 'TOKEN_CLS':
results = pipe(text)
        # Replace spans from the end of the text first so earlier character
        # offsets stay valid while the labels are spliced in
        results_sorted = sorted(results, key=lambda x: x['start'], reverse=True)
masked_list = list(text)
for ent in results_sorted:
masked_list[ent['start']:ent['end']] = list(f"<{ent['entity_group']}>")
return {
"masked_text": "".join(masked_list),
"labels": [r['entity_group'] for r in results]
}
elif task_type == 'SEQ_2_SEQ_LM':
out = pipe(
text,
max_length=512,
min_length=30,
do_sample=True,
temperature=float(config['temp']),
top_k=int(config['topk'])
)
return {"output": out[0]['summary_text']}
else:
out = pipe(text, max_new_tokens=1024)
return {"output": out[0]['generated_text']}
# --- Routes ---
@summarize_bp.route('/', methods=['GET'])
def index():
"""Renders the UI."""
return render_template('inference.html')
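
# Example call against the endpoint below, sketched with Flask's test client.
# It assumes the blueprint is registered on a main app (see the sketch at the
# bottom of this file); only "text" is required, the rest fall back to defaults.
#
#     with app.test_client() as client:
#         resp = client.post("/api/summarize", json={"text": "Long article text..."})
#         print(resp.get_json())   # {"output": "..."} on success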
@summarize_bp.route('/api/summarize', methods=['POST'])
def api_summarize():
"""API Endpoint to handle the AJAX request from the UI."""
data = request.get_json()
if not data or 'text' not in data:
return jsonify({"error": "No text provided"}), 400
hf_token = os.environ.get("HF_TOKEN","")
config = {
"text": data['text'],
"model_name": data.get('model_name', "facebook/bart-large-cnn"),
"hf_token": data.get('hf_token', hf_token),
"temp": data.get('temp','0.7'),
"topk": data.get('topk','50'),
# We force this for the specific summarization UI,
# but the backend logic supports others.
"task_type": "SEQ_2_SEQ_LM"
}
try:
result = run_inference_logic(config)
return jsonify(result)
except Exception as e:
print(f"Error: {e}")
return jsonify({"error": str(e)}), 500