# silmaQ5 / app.py
# Author: odai0 · commit af0f57d ("cleaning")
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama
from transformers import AutoTokenizer
import os
import json
import requests
# FastAPI application object; middleware and routes are attached to it below.
app = FastAPI()
# Deployment mode read from the environment (default "LLM"). "MOCK" selects the
# local small-model configuration; any other value loads the Hub-hosted GGUF.
MODE = os.getenv("MODE", "LLM")
class MockLLM:
    """Stand-in for ``Llama`` that answers every request with a fixed reply.

    Only ``create_chat_completion`` is provided. Its arguments are accepted
    but ignored. The returned dict has the ``choices[0]["message"]["content"]``
    shape that the routes read.
    """

    def create_chat_completion(self, messages, max_tokens=512, temperature=0):
        canned_reply = "[MOCKED RESPONSE] This is a reply"
        first_choice = {"message": {"content": canned_reply}}
        return {"choices": [first_choice]}
print(f"Running in {MODE} mode")
if MODE == "MOCK":
    # MOCK configuration: a small, heavily quantized GGUF loaded from ./model
    # with a reduced prompt budget. Swap in MockLLM() to skip loading a model.
    # llm = MockLLM()
    input_limit = 512  # max prompt tokens the routes accept
    context_length = 1024  # n_ctx the model is loaded with
    llm = Llama(
        model_path="./model/SILMA-9B-Instruct-v1.0-Q2_K_2.gguf",
        n_ctx=context_length,
        n_gpu_layers=10,
        # Was `n_patch=256`. Llama has no such parameter (llama-cpp-python
        # swallows unknown keyword arguments), so the batch size never applied.
        n_batch=256,
    )
else:
    input_limit = 2048  # max prompt tokens the routes accept
    context_length = 4096  # n_ctx the model is loaded with
    llm = Llama.from_pretrained(
        repo_id="bartowski/SILMA-9B-Instruct-v1.0-GGUF",
        filename="SILMA-9B-Instruct-v1.0-Q5_K_M.gguf",
        n_ctx=context_length,
        n_threads=2,
    )
class PromptRequest(BaseModel):
    """Request body for ``POST /prompt``: the user's free-form text."""

    prompt: str
class AnalyzeRequest(BaseModel):
    """Request body for ``POST /analyze``: serialized page chunks.

    Every item must be a string. ``process_chunks`` token-counts each item as
    message content and, when it is over budget, decodes it with
    ``json.loads``. In that case it is expected to be a JSON object, optionally
    with ``id`` and ``elements`` keys. Typing the list as ``list[str]`` rejects
    other item types with a 422 up front, instead of a 500 deep in processing.
    """

    data: list[str]
# Hugging Face tokenizer for the SILMA model, used by count_tokens to measure
# prompt sizes against input_limit / context_length.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="silma-ai/SILMA-9B-Instruct-v1.0",
)
# Signal codes the model may emit. This text is interpolated into every
# system prompt. ("a2" is never requested from the model; /analyze emits it
# itself when it merges per-chunk results manually.)
# The previous literal carried stray quote characters at line starts and ends,
# left over from string concatenation, which leaked into the prompt.
codes = """- "m0": regular reply
- "a0": request site chunked data for analysis
- "a1": analysis complete
- "e0": general error
"""
# System prompt used to analyze a single site chunk (process_chunks / analyze).
# The previous literal mixed triple quotes with leftover concatenation quotes:
# it never closed the JSON schema's "{" and glued "Valid signal codes:" onto
# the "summary" line.
# NOTE(review): the "actions"/"sections" keys are disabled in the schema. They
# were previously specified as
#   "actions":  [{ "id": string, "name": string }]  (only existing HTML ids)
#   "sections": [{ "id": string, "name": string }]
# e.g. { "id": "loginBtn", "name": "Press login" } and
#      { "id": "about-me", "name": "Header Section" }.
# The intro sentence and rule 5 still mention them, so decide whether to
# re-enable them or drop those mentions.
analysis_system_message = {
    "role": "system",
    "content": (
        "You are an assistant for an accessibility browser extension.\n"
        "Your only task is to return a **valid JSON object** based on site "
        "chunks content including a summary, action list and a section list.\n"
        "The JSON must have this format:\n"
        "{\n"
        '"signal": string,\n'
        '"message": string, // (optional)\n'
        '"summary": string\n'
        "}\n"
        "Valid signal codes:\n"
        + codes
        + "\nRules:\n"
        "1. Always return JSON, never plain text or explanations.\n"
        "2. Do not include extra keys.\n"
        "3. Do not escape JSON unnecessarily.\n"
        '4. Use signal "a1" when analysis is complete.\n'
        "5. For actions and sections, use strictly existing HTML ids. "
        "In case of missing ids, ignore the element.\n"
        '6. If unsure, default to {"signal": "e0", "message": '
        '"I did not understand the request."}\n'
        "7. Use message only if necessary, like to describe issue or concern"
    ),
}
# System prompt used by /analyze to merge the per-chunk results into one
# report. The previous literal dropped the opening quotes of "message" and
# "summary", had a stray quote after "format:", and glued the closing "}"
# onto "Valid signal codes:".
# NOTE(review): "actions": array and "sections": array are disabled in this
# schema too, although rule 5 still mentions them.
final_analysis_message = {
    "role": "system",
    "content": (
        "You are an assistant that combines multiple partial website analyses\n"
        "into one comprehensive final report. Return **only a valid JSON object**\n"
        "in this format:\n"
        '{ "signal": string,\n'
        '  "message": string,\n'
        '  "summary": string }\n'
        "Valid signal codes:\n"
        + codes
        + "Rules:\n"
        "1. Always return JSON, never plain text or explanations.\n"
        "2. Do not include extra keys.\n"
        "3. Do not escape JSON unnecessarily.\n"
        '4. Use signal "a1" when analysis is complete.\n'
        "5. For actions and sections, use strictly existing HTML ids. "
        "In case of missing ids, ignore the element.\n"
        '6. If unsure, default to {"signal": "e0", "message": '
        '"I did not understand the request."}\n'
        "7. Use message only if necessary, like to describe issue or concern"
    ),
}
def count_tokens(str):
    """Return how many token ids ``tokenizer.encode`` yields for the text.

    The parameter keeps its original name (which shadows the builtin ``str``)
    so that keyword callers keep working.
    """
    encoded_ids = tokenizer.encode(str)
    return len(encoded_ids)
def format_messages(messages):
    """Render chat messages as plain ``ROLE: content`` lines.

    Each message contributes one newline-terminated line: its role
    upper-cased, a colon and a space, then its content with surrounding
    whitespace stripped. An empty sequence renders as ``""``.
    """
    return "".join(
        f"{message['role'].upper()}: {message['content'].strip()}\n"
        for message in messages
    )
def compute_input_size(chunk_str, msg=analysis_system_message):
    """Estimate the prompt size, in tokens, of sending ``chunk_str`` with ``msg``.

    The system message ``msg`` and a user message carrying ``chunk_str`` are
    rendered with ``format_messages`` and measured with ``count_tokens``. This
    plain rendering is not the model's chat template, so the result is an
    approximation.
    """
    conversation = [msg, {"role": "user", "content": chunk_str}]
    rendered = format_messages(conversation)
    return count_tokens(rendered)
def process_chunks(chunks, msg, limit):
    """Split serialized chunks so each one fits within ``limit`` prompt tokens.

    A chunk whose prompt size (``compute_input_size(chunk, msg)``) is at most
    ``limit`` passes through unchanged. An oversized chunk is decoded as a JSON
    object and its ``elements`` are peeled off largest-first. Each peeled
    element becomes its own chunk tagged with ``parent_id``, and is split
    recursively if it is still too large. Peeling stops once the remaining
    chunk fits.

    Args:
        chunks: Iterable of JSON-encoded chunk strings.
        msg: System message the chunks will be sent with.
        limit: Maximum prompt size in tokens.

    Returns:
        List of JSON-encoded chunk strings, each within ``limit``. Content that
        cannot be brought under ``limit`` (e.g. an oversized element with no
        nested ``elements``) is dropped, with a printed warning.

    Raises:
        json.JSONDecodeError: If an oversized chunk is not valid JSON.
    """
    processed_chunks = []
    for chunk in chunks:
        if compute_input_size(chunk, msg) <= limit:
            processed_chunks.append(chunk)
        else:
            print("chunk exceeds limit")
            processed_chunks.extend(_split_oversized_chunk(chunk, msg, limit))
    print("processed_chunks final:", processed_chunks)
    return processed_chunks


def _split_oversized_chunk(chunk, msg, limit):
    """Break one over-limit JSON chunk into pieces that each fit ``limit``."""
    parent = json.loads(chunk)
    elements = parent.get("elements", [])
    parent_id = parent.get("id", "")
    # Peel the biggest elements first, so the fewest pieces are split off.
    order = sorted(
        range(len(elements)),
        key=lambda i: count_tokens(json.dumps(elements[i])),
        reverse=True,
    )
    pieces = []
    removed = set()
    for index in order:
        # Every peeled element leaves the parent, whether or not it fits on
        # its own. Previously an element too big to fit alone stayed in the
        # parent and was re-split at every recursion level, duplicating output.
        element_chunk = json.dumps({**elements[index], "parent_id": parent_id})
        pieces.extend(process_chunks([element_chunk], msg, limit))
        removed.add(index)
        reduced_chunk = json.dumps({
            **parent,
            "elements": [el for i, el in enumerate(elements) if i not in removed],
        })
        if compute_input_size(reduced_chunk, msg) <= limit:
            pieces.append(reduced_chunk)
            return pieces
    print("chunk cannot be reduced below limit, dropping remainder:", parent_id)
    return pieces
# Let the browser extension call this API from any page origin.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True may let
# credentialed cross-origin requests through from any site. If the extension
# never sends cookies or auth headers, allow_credentials=False is the safer
# choice; confirm before changing.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# ---- Routes ----
@app.get("/")
def api_home():
    """Root endpoint: return a fixed welcome payload."""
    welcome = "Welcome to FastAPI TextGen Tutorial!"
    return {"detail": welcome}
@app.post("/prompt")
def generate_text(request: PromptRequest):
    """Answer a free-form prompt with a ``{"signal", "message"}`` JSON object.

    Returns ``{"output": {...}}`` holding the model's JSON reply. A reply that
    is not a JSON object is wrapped as ``{"signal": "m0", "message": <raw
    reply>}``; this covers plain text and also bare JSON scalars or arrays,
    which were previously returned unwrapped. If the formatted prompt exceeds
    ``input_limit`` tokens, the model is not called and a bare
    ``{"signal": "e0", ...}`` object (not wrapped in ``"output"``) is returned.
    """
    system_prompt = (
        "You are an assistant for an accessibility browser extension.\n"
        "Your only task is to return a **valid JSON object** based on the "
        "user's request.\n"
        "The JSON must have this format:\n"
        '{ "signal": string, "message": string }\n'
        "Valid signal codes:\n"
        + codes
        + "\nRules:\n"
        "1. Always return JSON, never plain text\n"
        "2. Do not include extra keys.\n"
        "3. Do not escape JSON unnecessarily.\n"
        "4. Request chunking using valid signal if user asks for analysis, "
        "summarization, or possible actions.\n"
        '5. If unsure, default to {"signal": "m0", "message": '
        '"I did not understand the request."}'
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": request.prompt},
    ]
    token_count = count_tokens(format_messages(messages))
    if token_count > input_limit:
        return {"signal": "e0", "message": "Input exceeds token limit."}
    output = llm.create_chat_completion(
        messages=messages,
        max_tokens=1024,
        temperature=0,
    )
    output_str = output["choices"][0]["message"]["content"]
    try:
        output_json = json.loads(output_str)
    except json.JSONDecodeError:
        output_json = None
    if not isinstance(output_json, dict):
        # Plain text, or valid JSON that is not an object, is a regular reply.
        output_json = {"signal": "m0", "message": output_str}
    return {"output": output_json}
def _analysis_completion(system_message, user_content, max_tokens):
    """Send one system + user exchange to ``llm`` and return the reply text."""
    output = llm.create_chat_completion(
        messages=[system_message, {"role": "user", "content": user_content}],
        max_tokens=max_tokens,
        temperature=0,
    )
    return output["choices"][0]["message"]["content"]


def _parse_json_object(text):
    """Return ``text`` decoded as a JSON object, or None if it is not one."""
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


def _combine_results_manually(analysis_results):
    """Merge per-chunk result dicts by concatenation, without an LLM call.

    Returns an ``e0`` failure unless every chunk reported signal ``a1``.
    Missing or null ``summary``/``actions``/``sections`` values count as empty.
    """
    if any(result.get("signal") != "a1" for result in analysis_results):
        return {"signal": "e0", "message": "Chunk Analysis Failure."}
    return {
        "signal": "a2",
        "message": "Analysis complete, results were combined manually.",
        "summary": " ".join(
            str(result.get("summary") or "") for result in analysis_results
        ),
        "actions": [
            action
            for result in analysis_results
            for action in result.get("actions") or []
        ],
        "sections": [
            section
            for result in analysis_results
            for section in result.get("sections") or []
        ],
    }


@app.post("/analyze")
def analyze(request: AnalyzeRequest):
    """Analyze page chunks with the LLM and merge the per-chunk results.

    Chunks are first split to fit ``input_limit``, then each is analyzed
    separately with ``analysis_system_message``. The per-chunk JSON replies are
    merged in one of two ways:

    * By a second LLM call using ``final_analysis_message``.
    * By concatenation (signal ``a2``) when each chunk's share of the budget
      would fall below 90 tokens, or when the merge prompt leaves no room in
      ``context_length``.

    Returns an ``e0`` object in three cases: there are no chunks, manual
    merging finds a chunk that did not report ``a1``, or the final reply is
    not a JSON object.
    """
    chunks = process_chunks(request.data, analysis_system_message, input_limit)
    print("chunks: ", chunks)
    if not chunks:
        return {"signal": "e0", "message": "No chunks."}

    # Below 90 tokens per chunk, skip the LLM merge and let every chunk use the
    # full budget. Otherwise each chunk gets an equal integer share; llama-cpp
    # takes max_tokens as an int, whereas `/` previously produced a float.
    manual_combination = input_limit / len(chunks) < 90
    max_tokens = input_limit if manual_combination else input_limit // len(chunks)

    analysis_results = []
    for chunk in chunks:
        print("Analyzing chunk of size:",
              compute_input_size(chunk, analysis_system_message))
        output_str = _analysis_completion(analysis_system_message, chunk, max_tokens)
        output_json = _parse_json_object(output_str)
        if output_json is None:
            print("JSON parsing error:", output_str)
            output_json = {"signal": "e0", "message": "Invalid JSON parsing."}
        analysis_results.append(output_json)

    if manual_combination:
        return _combine_results_manually(analysis_results)

    combined_result = json.dumps(analysis_results)
    print("Combined result: ", combined_result)
    # Token budget left for the merged reply. compute_input_size only
    # approximates the real chat-template size, so this is an estimate.
    remaining_tokens = context_length - compute_input_size(
        combined_result, final_analysis_message
    )
    if remaining_tokens <= 0:
        print("Combined results do not fit the context; combining manually")
        return _combine_results_manually(analysis_results)
    output_str = _analysis_completion(
        final_analysis_message, combined_result, remaining_tokens
    )
    output_json = _parse_json_object(output_str)
    if output_json is None:
        print("JSON parsing error:", output_str)
        output_json = {"signal": "e0",
                       "message": "Invalid JSON parsing." + output_str}
    return output_json
if __name__ == "__main__":
    # Running the file directly starts a dev server (with auto-reload) only in
    # MOCK mode; in any other mode no server is started here.
    if MODE == "MOCK":
        import uvicorn

        uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)