broadfield-dev committed
Commit 890fbe9 · verified · 1 Parent(s): 2d4edb9

Create blueprints/inference.py

Files changed (1)
  1. blueprints/inference.py +106 -0
blueprints/inference.py ADDED
@@ -0,0 +1,106 @@
+ import os
+ import torch
+ from flask import Blueprint, request, jsonify, render_template
+ from transformers import pipeline
+ from huggingface_hub import HfFolder
+
+ # Define the Blueprint
+ inference_bp = Blueprint('inference', __name__)
+
+ # Global cache to store the loaded model in memory
+ # This prevents reloading the model on every single request
+ MODEL_CACHE = {
+     "model_name": None,
+     "pipeline": None
+ }
+
+ def get_pipeline(model_name, task_type):
+     """
+     Retrieves a pipeline from cache or loads it if it's new.
+     """
+     global MODEL_CACHE
+
+     # If we already have this model loaded, return it
+     if MODEL_CACHE["model_name"] == model_name and MODEL_CACHE["pipeline"] is not None:
+         return MODEL_CACHE["pipeline"]
+
+     # Authentication
+     hf_token = os.getenv("HF_TOKEN")
+     if hf_token:
+         HfFolder.save_token(hf_token)
+
+     # Determine device
+     device = 0 if torch.cuda.is_available() else -1
+     dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+     # Load Pipeline
+     print(f"Loading model: {model_name}...")
+     if task_type == 'SEQ_2_SEQ_LM':  # Summarization
+         pipe = pipeline("summarization", model=model_name, device=device)
+     elif task_type == 'TOKEN_CLS':
+         pipe = pipeline("token-classification", model=model_name, aggregation_strategy="simple", device=device)
+     else:
+         pipe = pipeline("text-generation", model=model_name, torch_dtype=dtype, device=device)
+
+     # Update Cache
+     MODEL_CACHE["model_name"] = model_name
+     MODEL_CACHE["pipeline"] = pipe
+
+     return pipe
+
+ def run_inference_logic(config):
+     model_id = config['model_name']
+     text = config['text']
+     task_type = config['task_type']
+
+     pipe = get_pipeline(model_id, task_type)
+
+     if task_type == 'TOKEN_CLS':
+         results = pipe(text)
+         # Replace entities from the end of the text backwards, so the
+         # character offsets of earlier entities stay valid while we splice.
+         results_sorted = sorted(results, key=lambda x: x['start'], reverse=True)
+         masked_list = list(text)
+         for ent in results_sorted:
+             masked_list[ent['start']:ent['end']] = list(f"<{ent['entity_group']}>")
+         return {
+             "masked_text": "".join(masked_list),
+             "labels": [r['entity_group'] for r in results]
+         }
+
+     elif task_type == 'SEQ_2_SEQ_LM':
+         # Summarization specific args
+         out = pipe(text, max_length=512, min_length=30, do_sample=False)
+         return {"output": out[0]['summary_text']}
+
+     else:
+         out = pipe(text, max_new_tokens=1024)
+         return {"output": out[0]['generated_text']}
+
+ # --- Routes ---
+
+ @inference_bp.route('/', methods=['GET'])
+ def index():
+     """Renders the UI."""
+     return render_template('index.html')
+
+ @inference_bp.route('/api/summarize', methods=['POST'])
+ def api_summarize():
+     """API Endpoint to handle the AJAX request from the UI."""
+     data = request.get_json()
+
+     if not data or 'text' not in data:
+         return jsonify({"error": "No text provided"}), 400
+
+     config = {
+         "text": data['text'],
+         "model_name": data.get('model_name', "facebook/bart-large-cnn"),
+         # We force this for the specific summarization UI,
+         # but the backend logic supports others.
+         "task_type": "SEQ_2_SEQ_LM"
+     }
+
+     try:
+         result = run_inference_logic(config)
+         return jsonify(result)
+     except Exception as e:
+         print(f"Error: {e}")
+         return jsonify({"error": str(e)}), 500
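
For context, a minimal sketch of how this blueprint could be wired into an application. The app.py filename, host/port, and the templates/index.html file backing the index route are assumptions for illustration, not part of this commit:

# app.py: hypothetical entry point, assuming the file above sits at blueprints/inference.py
from flask import Flask
from blueprints.inference import inference_bp

app = Flask(__name__)  # render_template('index.html') expects templates/index.html to exist
app.register_blueprint(inference_bp)

if __name__ == "__main__":
    # The model is loaded lazily on the first request (see MODEL_CACHE above),
    # so startup is fast but the first POST blocks while the model loads.
    app.run(host="0.0.0.0", port=5000)

With the server running, /api/summarize accepts a JSON POST. A hypothetical client call (the local URL and input text are placeholders):

import requests

resp = requests.post(
    "http://localhost:5000/api/summarize",
    json={"text": "Long article text goes here...", "model_name": "facebook/bart-large-cnn"},
)
print(resp.json())  # {"output": "..."} on success, {"error": "..."} with status 400/500 otherwise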