from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama
from transformers import AutoTokenizer
import os
import json
import logging
import requests

logger = logging.getLogger(__name__)

app = FastAPI()

# "MOCK" loads a small local quantized model with a reduced context window;
# any other value downloads the full model from the Hugging Face Hub.
MODE = os.environ.get("MODE", "LLM")

# Smallest output budget (in tokens) the model is given for one JSON reply.
# Below this, per-chunk results are combined locally instead of by the model.
MIN_OUTPUT_TOKENS = 90


class MockLLM:
    """Stand-in for ``Llama`` that always returns the same canned completion."""

    def create_chat_completion(self, messages, max_tokens=512, temperature=0):
        return {
            "choices": [{
                "message": {"content": "[MOCKED RESPONSE] This is a reply"}
            }]
        }


logger.info("Running in %s mode", MODE)

if MODE == "MOCK":
    # llm = MockLLM()
    input_limit = 512
    context_length = 1024
    llm = Llama(
        model_path="./model/SILMA-9B-Instruct-v1.0-Q2_K_2.gguf",
        n_ctx=context_length,
        n_gpu_layers=10,
        # Previously misspelled ``n_patch``, which is not a Llama parameter,
        # so the intended batch size was never applied.
        n_batch=256,
    )
else:
    input_limit = 2048
    context_length = 4096
    llm = Llama.from_pretrained(
        repo_id="bartowski/SILMA-9B-Instruct-v1.0-GGUF",
        filename="SILMA-9B-Instruct-v1.0-Q5_K_M.gguf",
        n_ctx=context_length,
        n_threads=2,
    )


class PromptRequest(BaseModel):
    prompt: str


class AnalyzeRequest(BaseModel):
    # Page chunks; presumably JSON-encoded strings from the extension
    # (oversized chunks are json.loads-ed) - TODO confirm with the client.
    data: list


tokenizer = AutoTokenizer.from_pretrained("silma-ai/SILMA-9B-Instruct-v1.0")

# Signal codes shared by every system prompt.
codes = """- "m0": regular reply
- "a0": request site chunked data for analysis
- "a1": analysis complete
- "e0": general error
"""

# Output rules shared by the chunk-analysis and final-combination prompts.
_analysis_rules = """Rules:
1. Always return JSON, never plain text or explanations.
2. Do not include extra keys.
3. Do not escape JSON unnecessarily.
4. Use signal "a1" when analysis is complete.
5. For actions and sections, use strictly existing HTML ids. In case of missing ids, ignore the element.
6. If unsure, default to {"signal": "e0", "message": "I did not understand the request."}
7. Use message only if necessary, like to describe issue or concern"""

# NOTE: the "actions" and "sections" keys are currently disabled in both
# analysis schemas. To re-enable them, add to the schema:
#   "actions": [{ "id": string, "name": string }],  // only existing HTML ids
#   "sections": [{ "id": string, "name": string }]
# and explain them:
#   - actions: HTML element ID plus a suggested action name,
#     e.g. { "id": "loginBtn", "name": "Press login" }.
#   - sections: section ID plus a suggested section name,
#     e.g. { "id": "about-me", "name": "Header Section" }.
analysis_system_message = {
    "role": "system",
    "content": (
        "You are an assistant for an accessibility browser extension. "
        "Your only task is to return a **valid JSON object** based on site "
        "chunks content including a summary, action list and a section list. "
        "The JSON must have this format:\n"
        "{\n"
        '  "signal": string,\n'
        '  "message": string, // (optional)\n'
        '  "summary": string\n'
        "}\n"
        "Valid signal codes:\n" + codes + _analysis_rules
    ),
}

final_analysis_message = {
    "role": "system",
    "content": (
        "You are an assistant that combines multiple partial website analyses "
        "into one comprehensive final report.\n"
        "Return **only a valid JSON object** in this format:\n"
        "{\n"
        '  "signal": string,\n'
        '  "message": string,\n'
        '  "summary": string\n'
        "}\n"
        "Valid signal codes:\n" + codes + _analysis_rules
    ),
}

# System prompt for free-form requests to /prompt.
prompt_system_message = {
    "role": "system",
    "content": (
        """You are an assistant for an accessibility browser extension.
Your only task is to return a **valid JSON object** based on the user's request.
The JSON must have this format:
{
  "signal": string,
  "message": string
}
Valid signal codes:
""" + codes + """
Rules:
1. Always return JSON, never plain text
2. Do not include extra keys.
3. Do not escape JSON unnecessarily.
4. Request chunking using valid signal if user asks for analysis, summarization, or possible actions.
5. If unsure, default to {"signal": "m0", "message": "I did not understand the request."}"""
    ),
}


def count_tokens(text):
    """Return how many tokens ``text`` encodes to with the SILMA tokenizer."""
    return len(tokenizer.encode(text))


def format_messages(messages):
    """Render chat messages as ``ROLE: content`` lines for token counting.

    This only approximates the model's real chat template. It is used to
    estimate prompt size; the model itself receives the message list.
    """
    return "".join(
        f"{m['role'].upper()}: {m['content'].strip()}\n" for m in messages
    )


def compute_input_size(chunk_str, msg=analysis_system_message):
    """Estimate the token count of ``msg`` followed by ``chunk_str`` as the user turn."""
    return count_tokens(format_messages(
        [msg, {"role": "user", "content": chunk_str}]))


def _fits(chunk_str, msg, limit):
    """Return True when ``chunk_str`` sent with ``msg`` stays within ``limit`` tokens."""
    return compute_input_size(chunk_str, msg) <= limit


def _parse_json_object(text):
    """Parse model output as a JSON object, tolerating a Markdown code fence.

    Returns the parsed dict, or None when the text is not a JSON object.
    """
    stripped = text.strip()
    if stripped.startswith("```"):
        # Drop the opening fence line (may carry a tag such as ```json) and
        # a trailing closing fence.
        _, _, stripped = stripped.partition("\n")
        stripped = stripped.strip()
        if stripped.endswith("```"):
            stripped = stripped[:-3]
    try:
        parsed = json.loads(stripped)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


def process_chunks(chunks, msg, limit):
    """Split chunks so that each one, sent with ``msg``, fits in ``limit`` tokens.

    Chunks that already fit are passed through unchanged. An oversized chunk
    must be a JSON object with an ``elements`` list. Its largest elements are
    peeled off one at a time, each tagged with the chunk's ``id`` as
    ``parent_id``, until the remaining chunk fits. Elements too large on their
    own are split recursively. Content that cannot be split further is
    dropped with a warning.

    Args:
        chunks: iterable of chunks; strings are used as-is, anything else is
            JSON-encoded first.
        msg: system message the chunks will be sent with.
        limit: maximum input size in tokens.

    Returns:
        List of chunk strings, each within ``limit``.
    """
    processed_chunks = []
    for chunk in chunks:
        if not isinstance(chunk, str):
            chunk = json.dumps(chunk)
        if _fits(chunk, msg, limit):
            processed_chunks.append(chunk)
        else:
            processed_chunks.extend(_split_chunk(chunk, msg, limit))
    logger.debug("processed chunks: %s", processed_chunks)
    return processed_chunks


def _split_chunk(chunk, msg, limit):
    """Break one oversized JSON chunk into pieces that each fit in ``limit``."""
    try:
        parsed = json.loads(chunk)
    except json.JSONDecodeError:
        parsed = None
    elements = parsed.get("elements") if isinstance(parsed, dict) else None
    if not isinstance(elements, list) or not elements:
        logger.warning("Dropping oversized chunk that has no elements to split")
        return []

    parent_id = parsed.get("id", "")
    # (size, index) pairs, largest first; equal sizes go to the later element.
    by_size = sorted(
        ((count_tokens(json.dumps(element)), index)
         for index, element in enumerate(elements)),
        reverse=True,
    )

    pieces = []
    removed = set()
    for _, index in by_size:
        element = elements[index]
        # Always take the element out of the parent, even when it must be
        # split itself, so it is never emitted twice.
        removed.add(index)
        standalone = json.dumps(
            {**element, "parent_id": parent_id}
            if isinstance(element, dict) else element
        )
        if _fits(standalone, msg, limit):
            pieces.append(standalone)
        else:
            pieces.extend(process_chunks([json.dumps(element)], msg, limit))

        reduced = json.dumps({
            **parsed,
            "elements": [e for i, e in enumerate(elements) if i not in removed],
        })
        if _fits(reduced, msg, limit):
            pieces.append(reduced)
            break
    else:
        # Even with every element removed, the chunk's own fields are too big.
        logger.warning(
            "Dropping chunk %r: too large even without its elements", parent_id)
    return pieces


# NOTE(review): wildcard origins combined with allow_credentials=True -
# confirm this is intended; consider restricting allow_origins to the
# extension's origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Routes


@app.get("/")
def api_home():
    return {'detail': 'Welcome to FastAPI TextGen Tutorial!'}


@app.post("/prompt")
def generate_text(request: PromptRequest):
    """Answer a free-form prompt with a ``{"signal", "message"}`` object.

    Returns ``{"output": {...}}`` on success. When the prompt exceeds
    ``input_limit`` the error object is returned at the top level instead
    (no ``output`` wrapper) - NOTE(review): clients must handle both shapes.
    """
    messages = [
        prompt_system_message,
        {"role": "user", "content": request.prompt},
    ]
    if count_tokens(format_messages(messages)) > input_limit:
        return {"signal": "e0", "message": "Input exceeds token limit."}

    output = llm.create_chat_completion(
        messages=messages,
        max_tokens=1024,
        temperature=0,
    )
    output_str = output["choices"][0]["message"]["content"]
    output_json = _parse_json_object(output_str)
    if output_json is None:
        # Treat non-JSON text as a regular reply.
        output_json = {"signal": "m0", "message": output_str}
    return {"output": output_json}


def _combine_manually(analysis_results):
    """Concatenate per-chunk results locally; e0 if any chunk did not finish."""
    if any(result.get("signal") != "a1" for result in analysis_results):
        return {"signal": "e0", "message": "Chunk Analysis Failure."}
    return {
        "signal": "a2",
        "message": "Analysis complete, results were combined manually.",
        "summary": " ".join(res.get("summary", "") for res in analysis_results),
        "actions": [action for res in analysis_results
                    for action in res.get("actions", [])],
        "sections": [section for res in analysis_results
                     for section in res.get("sections", [])],
    }


@app.post("/analyze")
def analyze(request: AnalyzeRequest):
    """Analyze page chunks and return a single combined report.

    Each chunk is analyzed separately. When every chunk can get at least
    ``MIN_OUTPUT_TOKENS`` of output within ``input_limit``, the per-chunk
    results are merged by the model. Otherwise, or when the merge prompt would
    leave less than ``MIN_OUTPUT_TOKENS`` of context, they are concatenated
    locally (signal ``"a2"``).
    """
    chunks = process_chunks(request.data, analysis_system_message, input_limit)
    logger.debug("chunks: %s", chunks)
    if not chunks:
        return {"signal": "e0", "message": "No chunks."}

    # Budget per chunk so that all chunk outputs together fit as the merge
    # request's input. max_tokens must be an int.
    per_chunk_tokens = input_limit // len(chunks)
    manual_combination = per_chunk_tokens < MIN_OUTPUT_TOKENS
    max_chunk_tokens = input_limit if manual_combination else per_chunk_tokens

    analysis_results = []
    for chunk in chunks:
        logger.debug("Analyzing chunk of size: %d",
                     compute_input_size(chunk, analysis_system_message))
        output = llm.create_chat_completion(
            messages=[
                analysis_system_message,
                {"role": "user", "content": chunk},
            ],
            max_tokens=max_chunk_tokens,
            temperature=0,
        )
        output_str = output["choices"][0]["message"]["content"]
        output_json = _parse_json_object(output_str)
        if output_json is None:
            logger.warning("JSON parsing error: %s", output_str)
            output_json = {"signal": "e0", "message": "Invalid JSON parsing."}
        analysis_results.append(output_json)

    if not manual_combination:
        combined_result = json.dumps(analysis_results)
        merge_budget = context_length - compute_input_size(
            combined_result, final_analysis_message)
        # The system message is added on top of the chunk outputs, so the
        # merge prompt may not leave room for a reply; combine locally then.
        if merge_budget >= MIN_OUTPUT_TOKENS:
            logger.debug("Combined result: %s", combined_result)
            output = llm.create_chat_completion(
                messages=[
                    final_analysis_message,
                    {"role": "user", "content": combined_result},
                ],
                max_tokens=merge_budget,
                temperature=0,
            )
            output_str = output["choices"][0]["message"]["content"]
            output_json = _parse_json_object(output_str)
            if output_json is None:
                logger.warning("JSON parsing error: %s", output_str)
                output_json = {"signal": "e0",
                               "message": "Invalid JSON parsing." + output_str}
            return output_json

    return _combine_manually(analysis_results)


if __name__ == "__main__" and MODE == "MOCK":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)