from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama
from transformers import AutoTokenizer
import os
import json

app = FastAPI()

# "MOCK" selects a lightweight local setup; any other value loads the full model.
MODE = os.environ.get("MODE", "LLM")


class MockLLM:
    """Drop-in stand-in for Llama that returns a canned completion."""

    def create_chat_completion(self, messages, max_tokens=512, temperature=0):
        return {
            "choices": [{
                "message": {"content": "[MOCKED RESPONSE] This is a reply"}
            }]
        }


print(f"Running in {MODE} mode")

if MODE == "MOCK":
    input_limit = 512
    context_length = 1024
    # Small local quantized model for development; fall back to the canned
    # MockLLM stub if the model file is not present.
    mock_model_path = "./model/SILMA-9B-Instruct-v1.0-Q2_K_2.gguf"
    if os.path.exists(mock_model_path):
        llm = Llama(model_path=mock_model_path,
                    n_ctx=context_length, n_gpu_layers=10, n_batch=256)
    else:
        llm = MockLLM()
else:
    input_limit = 2048
    context_length = 4096
    llm = Llama.from_pretrained(
        repo_id="bartowski/SILMA-9B-Instruct-v1.0-GGUF",
        filename="SILMA-9B-Instruct-v1.0-Q5_K_M.gguf",
        n_ctx=context_length,
        n_threads=2
    )
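
# Mode is selected via the environment (this file is assumed to be app.py,
# matching the uvicorn.run target below):
#   MODE=MOCK python app.py       # local stub / small quantized model
#   uvicorn app:app               # defaults to MODE=LLM, full model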


class PromptRequest(BaseModel):
    prompt: str


class AnalyzeRequest(BaseModel):
    data: list


tokenizer = AutoTokenizer.from_pretrained("silma-ai/SILMA-9B-Instruct-v1.0")


codes = """- "m0": regular reply
- "a0": request site chunked data for analysis
- "a1": analysis complete
- "e0": general error
"""

analysis_system_message = {
    "role": "system",
    "content": (
        """You are an assistant for an accessibility browser extension.
Your only task is to return a **valid JSON object** based on the site chunk content, including a summary, an action list, and a section list.
The JSON must have this format:
{
    "signal": string,
    "message": string,   // (optional)
    "summary": string,
    "actions": list,
    "sections": list
}
Valid signal codes:
""" + codes + """
Rules:
1. Always return JSON, never plain text or explanations.
2. Do not include extra keys.
3. Do not escape JSON unnecessarily.
4. Use signal "a1" when analysis is complete.
5. For actions and sections, use strictly existing HTML ids. If an element has no id, ignore it.
6. If unsure, default to {"signal": "e0", "message": "I did not understand the request."}
7. Use message only if necessary, e.g. to describe an issue or concern."""
    )
}

final_analysis_message = {
    "role": "system",
    "content": (
        """You are an assistant that combines multiple partial website analyses into one comprehensive final report.
Return **only a valid JSON object** in this format:
{
    "signal": string,
    "message": string,   // (optional)
    "summary": string,
    "actions": list,
    "sections": list
}
Valid signal codes:
""" + codes + """
Rules:
1. Always return JSON, never plain text or explanations.
2. Do not include extra keys.
3. Do not escape JSON unnecessarily.
4. Use signal "a1" when analysis is complete.
5. For actions and sections, use strictly existing HTML ids. If an element has no id, ignore it.
6. If unsure, default to {"signal": "e0", "message": "I did not understand the request."}
7. Use message only if necessary, e.g. to describe an issue or concern."""
    )
}


def count_tokens(text):
    return len(tokenizer.encode(text))


def format_messages(messages):
    """Flatten chat messages into the plain string used for token counting."""
    formatted = ""
    for m in messages:
        formatted += f"{m['role'].upper()}: {m['content'].strip()}\n"
    return formatted


def compute_input_size(chunk_str, msg=analysis_system_message):
    """Token count of a system message plus one user chunk."""
    return count_tokens(format_messages(
        [msg, {"role": "user", "content": chunk_str}]))


def process_chunks(chunks, msg, limit):
    """Recursively split serialized chunks until each fits within `limit` tokens."""
    processed_chunks = []
    for chunk in chunks:
        print("input chunk: ", chunk)
        if compute_input_size(chunk, msg) <= limit:
            print("chunk fits")
            processed_chunks.append(chunk)
            continue

        print("chunk exceeds limit")
        old_chunk = json.loads(chunk)
        elements = old_chunk.get('elements', [])
        if not elements:
            # Nothing left to split off; drop the oversized chunk.
            print("oversized chunk has no elements, dropping")
            continue

        # Pull out the largest element, emit it (or split it further),
        # then retry the chunk without it.
        element_sizes = [
            (count_tokens(json.dumps(element)), i)
            for i, element in enumerate(elements)
        ]
        element_sizes.sort(reverse=True)
        element_index = element_sizes[0][1]
        element = elements[element_index]

        if compute_input_size(json.dumps(element), msg) < limit:
            # The element fits on its own; keep a pointer to its parent chunk.
            processed_chunks.append(json.dumps(
                {**element, "parent_id": old_chunk.get('id', '')}))
        else:
            # The element itself is too large; split it recursively.
            processed_chunks.extend(
                process_chunks([json.dumps(element)], msg, limit))

        reduced_chunk = {**old_chunk,
                         "elements": elements[:element_index] + elements[element_index + 1:]}
        processed_chunks.extend(
            process_chunks([json.dumps(reduced_chunk)], msg, limit))

    print("processed_chunks final:", processed_chunks)
    return processed_chunks
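

# Illustrative sketch (hypothetical data): given an oversized chunk such as
#   json.dumps({"id": "main", "elements": [big_nav, small_footer]})
# process_chunks emits the largest element as its own chunk (tagged with
# "parent_id": "main") and then re-checks the remainder, recursing until
# every serialized piece fits within the token limit.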


app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
def api_home():
    return {'detail': 'Welcome to FastAPI TextGen Tutorial!'}


@app.post("/prompt")
def generate_text(request: PromptRequest):
    messages = [
        {
            "role": "system",
            "content": (
                """You are an assistant for an accessibility browser extension.
Your only task is to return a **valid JSON object** based on the user's request.
The JSON must have this format:
{ "signal": string, "message": string }
Valid signal codes:
""" + codes + """
Rules:
1. Always return JSON, never plain text.
2. Do not include extra keys.
3. Do not escape JSON unnecessarily.
4. Request chunking (signal "a0") if the user asks for analysis, summarization, or possible actions.
5. If unsure, default to {"signal": "m0", "message": "I did not understand the request."}"""
            )
        },
        {"role": "user", "content": request.prompt}
    ]

    token_count = count_tokens(format_messages(messages))
    if token_count > input_limit:
        return {"signal": "e0", "message": "Input exceeds token limit."}

    output = llm.create_chat_completion(
        messages=messages,
        # Cap generation so that prompt plus output stays inside the context window.
        max_tokens=min(1024, context_length - token_count),
        temperature=0
    )

    output_str = output["choices"][0]["message"]["content"]
    try:
        output_json = json.loads(output_str)
    except json.JSONDecodeError:
        # Fall back to wrapping the raw text as a regular reply.
        output_json = {"signal": "m0", "message": output_str}

    return {"output": output_json}
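

# Example request (assuming the server runs on localhost:8000):
#   curl -X POST http://localhost:8000/prompt \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "What can you do?"}'
# Response shape: {"output": {"signal": "m0", "message": "..."}}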


@app.post("/analyze")
def analyze(request: AnalyzeRequest):
    analysis_results = []
    chunks = process_chunks(request.data, analysis_system_message, input_limit)
    print("chunks: ", chunks)
    if not chunks:
        return {"signal": "e0", "message": "No chunks."}

    # Heuristic: if the per-chunk share of the output budget falls below
    # ~90 tokens, skip the final LLM combination pass and merge in code.
    manual_combination = input_limit / len(chunks) < 90

    for chunk in chunks:
        print("Analyzing chunk of size:", compute_input_size(
            chunk, analysis_system_message))
        output = llm.create_chat_completion(
            messages=[
                analysis_system_message,
                {"role": "user", "content": chunk}
            ],
            # Split the output budget across chunks when a combination pass
            # follows; otherwise give each chunk the full budget.
            max_tokens=input_limit // len(chunks)
            if not manual_combination else input_limit,
            temperature=0
        )
        output_str = output["choices"][0]["message"]["content"]
        try:
            output_json = json.loads(output_str)
        except json.JSONDecodeError:
            output_json = {"signal": "e0",
                           "message": "Invalid JSON parsing."}
            print("JSON parsing error:", output_str)

        analysis_results.append(output_json)

    if not manual_combination:
        combined_result = json.dumps(analysis_results)
        print("Combined result: ", combined_result)
        output = llm.create_chat_completion(
            messages=[
                final_analysis_message,
                {"role": "user", "content": combined_result}
            ],
            # Use whatever context remains after the combination prompt.
            max_tokens=context_length -
            compute_input_size(combined_result, final_analysis_message),
            temperature=0)
        output_str = output["choices"][0]["message"]["content"]
        try:
            output_json = json.loads(output_str)
        except json.JSONDecodeError:
            output_json = {"signal": "e0",
                           "message": "Invalid JSON parsing." + output_str}
            print("JSON parsing error:", output_str)
    else:
        for result in analysis_results:
            if result.get("signal") != "a1":
                return {"signal": "e0",
                        "message": "Chunk Analysis Failure."}
        # "a2" is emitted only here, server-side; the LLM never produces it.
        output_json = {
            "signal": "a2",
            "message": "Analysis complete, results were combined manually.",
            "summary": " ".join([res.get("summary", "")
                                 for res in analysis_results]),
            "actions": [action for res in analysis_results
                        for action in res.get("actions", [])],
            "sections": [section for res in analysis_results
                         for section in res.get("sections", [])],
        }
    return output_json
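

# Example request (hypothetical chunk; "data" is a list of serialized chunks):
#   curl -X POST http://localhost:8000/analyze \
#        -H "Content-Type: application/json" \
#        -d '{"data": ["{\"id\": \"main\", \"elements\": []}"]}'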


# Run the dev server directly only in MOCK mode; otherwise serve with an
# ASGI server (e.g. `uvicorn app:app`).
if __name__ == "__main__" and MODE == "MOCK":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)