import os
import torch
from flask import Blueprint, request, jsonify, render_template
from transformers import pipeline
from huggingface_hub import HfFolder

# Define the Blueprint
summarize_bp = Blueprint('summarize', __name__)

# Global cache to store the loaded model in memory
# This prevents reloading the model on every single request
MODEL_CACHE = {
    "model_name": None,
    "pipeline": None
}

def get_pipeline(model_name, task_type, hf_token):
    """
    Retrieves a pipeline from cache or loads it if it's new.
    """
    global MODEL_CACHE
    
    # If we already have this model loaded, return it
    if MODEL_CACHE["model_name"] == model_name and MODEL_CACHE["pipeline"] is not None:
        return MODEL_CACHE["pipeline"]

    # Persist the Hugging Face token so gated/private models can be downloaded
    if hf_token:
        HfFolder.save_token(hf_token)

    # Select device and precision: GPU with fp16 when available, otherwise CPU with fp32
    use_cuda = torch.cuda.is_available()
    device = 0 if use_cuda else -1
    dtype = torch.float16 if use_cuda else torch.float32

    # Load the requested pipeline for the given task
    print(f"Loading model: {model_name}...")
    if task_type == 'SEQ_2_SEQ_LM':  # Summarization
        pipe = pipeline("summarization", model=model_name, device=device)
    elif task_type == 'TOKEN_CLS':  # Token classification / NER
        pipe = pipeline("token-classification", model=model_name, aggregation_strategy="simple", device=device)
    else:  # Causal text generation
        pipe = pipeline("text-generation", model=model_name, torch_dtype=dtype, device=device)
    
    # Update Cache
    MODEL_CACHE["model_name"] = model_name
    MODEL_CACHE["pipeline"] = pipe
    
    return pipe

def run_inference_logic(config):
    """Dispatches the request to the right pipeline and post-processes the output."""
    model_id = config['model_name']
    text = config['text']
    task_type = config['task_type']
    hf_token = config['hf_token']
    pipe = get_pipeline(model_id, task_type, hf_token)

    if task_type == 'TOKEN_CLS':
        results = pipe(text)
        # Replace entity spans from the end of the text backwards so that
        # earlier character offsets are not shifted by the substitutions.
        results_sorted = sorted(results, key=lambda x: x['start'], reverse=True)
        masked_list = list(text)
        for ent in results_sorted:
            masked_list[ent['start']:ent['end']] = list(f"<{ent['entity_group']}>")
        return {
            "masked_text": "".join(masked_list),
            "labels": [r['entity_group'] for r in results]
        }

    elif task_type == 'SEQ_2_SEQ_LM':
        out = pipe(
            text,
            max_length=512,
            min_length=30,
            do_sample=True,
            temperature=float(config['temp']),
            top_k=int(config['topk'])
        )
        return {"output": out[0]['summary_text']}

    else:
        out = pipe(text, max_new_tokens=1024)
        return {"output": out[0]['generated_text']}

# --- Routes ---

@summarize_bp.route('/', methods=['GET'])
def index():
    """Renders the UI."""
    return render_template('inference.html')

@summarize_bp.route('/api/summarize', methods=['POST'])
def api_summarize():
    """API Endpoint to handle the AJAX request from the UI."""
    data = request.get_json()
    
    if not data or 'text' not in data:
        return jsonify({"error": "No text provided"}), 400
    hf_token = os.environ.get("HF_TOKEN", "")
    config = {
        "text": data['text'],
        "model_name": data.get('model_name', "facebook/bart-large-cnn"),
        "hf_token": data.get('hf_token', hf_token),
        "temp": data.get('temp', '0.7'),
        "topk": data.get('topk', '50'),
        # The summarization UI always uses SEQ_2_SEQ_LM,
        # but the backend logic supports the other task types as well.
        "task_type": "SEQ_2_SEQ_LM"
    }

    try:
        result = run_inference_logic(config)
        return jsonify(result)
    except Exception as e:
        print(f"Error: {e}")
        return jsonify({"error": str(e)}), 500