memorease committed on
Commit
5860888
·
verified ·
1 Parent(s): c134853

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -28
app.py CHANGED
@@ -1,47 +1,40 @@
1
  from flask import Flask, request, jsonify
2
- from gradio_client import Client
3
- import os
4
- import threading
5
 
6
  app = Flask(__name__)
7
 
8
- HF_TOKEN = os.environ.get("HF_TOKEN") # token env üzerinden alınıyor
9
- client = None # Global client cache
10
-
11
- # 🔥 Client preload – FLASK DIŞI başlatılıyor
12
- def preload_client():
13
- global client
14
- try:
15
- if client is None:
16
- print("[Startup] Preloading Client...")
17
- client = Client("memorease/flan5_memorease", hf_token=HF_TOKEN)
18
- print("[Startup] Client initialized.")
19
- except Exception as e:
20
- print(f"[Startup] Client preload failed: {e}")
21
-
22
- # ⏱️ Flask başlamadan önce preload başlasın
23
- threading.Thread(target=preload_client).start()
24
 
25
  @app.route("/ask", methods=["POST"])
26
  def ask_question():
27
- global client
28
  try:
29
- if client is None:
30
- client = Client("memorease/flan5_memorease", hf_token=HF_TOKEN)
31
-
32
  input_text = request.json.get("text")
33
  if not input_text:
34
  return jsonify({"error": "Missing 'text'"}), 400
35
 
36
- result = client.predict(input_text, api_name="/predict")
37
- return jsonify({"question": result})
 
 
 
 
 
 
 
 
 
38
  except Exception as e:
39
  return jsonify({"error": str(e)}), 500
40
 
41
  @app.route("/", methods=["GET"])
42
- def root_check():
43
  return jsonify({"status": "running"})
44
 
45
  if __name__ == "__main__":
46
- port = int(os.environ.get("PORT", 7860))
47
- app.run(host="0.0.0.0", port=port)
 
import os

import torch
from flask import Flask, request, jsonify
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
4
 
5
  app = Flask(__name__)
6
 
7
+ # Modeli ve tokenizer'ı direkt Hugging Face'ten yüklüyoruz
8
+ model_name = "memorease/memorease-flan-t5"
9
+ print("[Startup] Loading model...")
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
12
+ print("[Startup] Model loaded.")
 
 
 
 
 
 
 
 
 
 
13
 
14
@app.route("/ask", methods=["POST"])
def ask_question():
    """Generate a question about a memory text posted as JSON {"text": ...}.

    Returns:
        200 with {"question": ...} on success,
        400 when the body is not JSON or 'text' is missing/empty,
        500 (JSON error payload) on unexpected inference failures.
    """
    try:
        # silent=True: a malformed or non-JSON body yields None instead of
        # raising, so the client gets a 400 here rather than a 500 below.
        payload = request.get_json(silent=True) or {}
        input_text = payload.get("text")
        if not input_text:
            return jsonify({"error": "Missing 'text'"}), 400

        # Build the instruction prompt for the seq2seq model.
        prompt = f"Only generate a factual and relevant question about this memory: {input_text}"
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

        # Inference without gradient tracking (less memory, faster).
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=64)

        question = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return jsonify({"question": question})

    except Exception as e:
        # Last-resort guard so the endpoint always answers with JSON.
        return jsonify({"error": str(e)}), 500
35
@app.route("/", methods=["GET"])
def healthcheck():
    """Liveness probe: confirm the service is up and answering."""
    status_payload = {"status": "running"}
    return jsonify(status_payload)
38
 
39
if __name__ == "__main__":
    # Let the hosting platform override the port via the PORT env var
    # (the previous revision supported this); default stays 7860, the
    # Hugging Face Spaces convention. Bind 0.0.0.0 so the container is
    # reachable from outside.
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port)