fugthchat committed
Commit 462067d · verified · 1 Parent(s): d034cb0

Update app.py

Files changed (1)
  app.py  +49 -34
app.py CHANGED
@@ -4,48 +4,63 @@ import os
 
 app = Flask(__name__)
 
-MODEL_PATH = "./model.gguf"
-MODEL_URL = "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat.Q4_K_M.gguf"
-
-# Download if missing
-if not os.path.exists(MODEL_PATH):
-    os.system(f"wget -O {MODEL_PATH} {MODEL_URL}")
-
-# Load model (optimized for CPU Spaces)
-llm = Llama(
-    model_path=MODEL_PATH,
-    n_threads=4,
-    n_ctx=2048,
-    use_mlock=False,
-)
-
-@app.route('/')
-def index():
-    return jsonify({
-        "message": "FugthDes Story Generator Active",
-        "model": "TinyLlama GGUF (CPU)"
-    })
-
-@app.route('/generate', methods=['POST'])
+MODEL_URLS = {
+    "light": "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q3_K_S.gguf",
+    "medium": "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q4_K_M.gguf",
+    "heavy": "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q5_0.gguf"
+}
+
+MODEL_PATHS = {
+    k: f"{k}.gguf" for k in MODEL_URLS
+}
+
+current_model = None
+llm = None
+
+def ensure_model(model_choice):
+    global llm, current_model
+    model_path = MODEL_PATHS[model_choice]
+    url = MODEL_URLS[model_choice]
+
+    if not os.path.exists(model_path):
+        print(f"Downloading {model_choice} model...")
+        os.system(f"wget -O {model_path} {url}")
+
+    if current_model != model_choice:
+        print(f"Loading {model_choice} model...")
+        llm = Llama(model_path=model_path, n_ctx=2048, n_threads=4, use_mlock=False)
+        current_model = model_choice
+    return llm
+
+
+@app.route("/status")
+def status():
+    return jsonify({"status": "ok" if llm else "not_loaded", "model": current_model})
+
+
+@app.route("/generate", methods=["POST"])
 def generate():
-    data = request.get_json()
+    data = request.get_json(force=True)
+    model_choice = data.get("model_choice", "light")
     prompt = data.get("prompt", "")
-    feedback = data.get("feedback", "")
     story_memory = data.get("story_memory", "")
+    feedback = data.get("feedback", "")
+
+    llm = ensure_model(model_choice)
 
-    final_prompt = story_memory + "\n\n" + prompt
+    full_prompt = story_memory + "\n\n" + prompt
     if feedback:
-        final_prompt += f"\n\nUser feedback: {feedback}\nContinue or refine story naturally."
+        full_prompt += f"\n\nUser feedback: {feedback}\n"
+
+    result = llm(full_prompt, max_tokens=512, temperature=0.8)
+    text = result["choices"][0]["text"].strip()
+    return jsonify({"response": text})
 
-    print("Prompt received:", final_prompt[:250])
 
-    output = llm(final_prompt, max_tokens=512, temperature=0.8, top_p=0.9)
-    response_text = output["choices"][0]["text"].strip()
+@app.route("/")
+def root():
+    return "StableLM Zephyr GGUF API running!"
 
-    return jsonify({
-        "response": response_text,
-        "tokens_used": output["usage"]["total_tokens"]
-    })
 
 if __name__ == "__main__":
     app.run(host="0.0.0.0", port=7860)
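
A minimal client sketch for the updated routes (illustrative, not part of the commit): it assumes the requests package and a placeholder base URL for a local run; a deployed Space would use its own URL, and the payload values are examples only.

# Illustrative client sketch -- not part of this commit.
# BASE_URL is an assumption; substitute the URL of the running Space.
import requests

BASE_URL = "http://localhost:7860"  # assumed local run on the port used by app.run()

# /status reports whether a model is loaded and which quantization is active.
print(requests.get(f"{BASE_URL}/status").json())

# /generate accepts an optional "model_choice" ("light", "medium", or "heavy",
# defaulting to "light") plus "prompt", "story_memory", and "feedback" strings,
# and returns JSON of the form {"response": "<generated text>"}.
payload = {
    "model_choice": "light",
    "prompt": "Write the opening scene of a short mystery story.",
    "story_memory": "",
    "feedback": "",
}
reply = requests.post(f"{BASE_URL}/generate", json=payload, timeout=600)
print(reply.json()["response"])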