from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from llama_cpp import Llama
from transformers import AutoTokenizer
import os
import json

app = FastAPI()
MODE = os.environ.get("MODE", "LLM")


class MockLLM:
    """Stand-in for the real model: returns a canned reply without loading weights."""

    def create_chat_completion(self, messages, max_tokens=512, temperature=0):
        return {
            "choices": [{
                "message": {"content": "[MOCKED RESPONSE] This is a reply"}
            }]
        }
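
# A minimal sketch of exercising the mock directly (hypothetical usage, not
# part of the API surface):
#   mock = MockLLM()
#   reply = mock.create_chat_completion([{"role": "user", "content": "hi"}])
#   reply["choices"][0]["message"]["content"]  # -> "[MOCKED RESPONSE] ..."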


print(f"Running in {MODE} mode")

if MODE == "MOCK":
    # llm = MockLLM()
    input_limit = 512
    context_length = 1024
    llm = Llama(model_path="./model/SILMA-9B-Instruct-v1.0-Q2_K_2.gguf",
                n_ctx=context_length, n_gpu_layers=10, n_patch=256)
else:
    input_limit = 2048
    context_length = 4096
    llm = Llama.from_pretrained(
        repo_id="bartowski/SILMA-9B-Instruct-v1.0-GGUF",
        filename="SILMA-9B-Instruct-v1.0-Q5_K_M.gguf",
        n_ctx=context_length,
        n_threads=2
    )
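
# Either branch exposes the same `llm` interface. To run locally (assuming
# this file is app.py, matching the uvicorn call at the bottom of the file):
#   MODE=MOCK python app.py                  # small local model, auto-reload
#   MODE=LLM uvicorn app:app --port 8000     # full model, launched externally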


class PromptRequest(BaseModel):
    prompt: str


class AnalyzeRequest(BaseModel):
    data: list


tokenizer = AutoTokenizer.from_pretrained("silma-ai/SILMA-9B-Instruct-v1.0")

# signal codes shared by every prompt and by the extension protocol
codes = """- "m0": regular reply
- "a0": request site chunked data for analysis
- "a1": analysis complete
- "e0": general error
"""
analysis_system_message = {
    "role": "system",
    "content": (
        """You are an assistant for an accessibility browser extension.
Your only task is to return a **valid JSON object** based on site chunk content, including a summary, an action list, and a section list.
The JSON must have this format:
{
"signal": string,
"message": string,  // (optional)
"summary": string
}
"""
        # "actions": [{ "id": string, "name": string }],  // use only existing HTML ids
        # "sections": [{ "id": string, "name": string }]
        # Where:
        # - actions is an array of JSON objects, each pairing an HTML element id
        #   with a suggested action name, e.g. { "id": "loginBtn", "name": "Press login" }.
        # - sections is an array of JSON objects, each pairing a section id with
        #   a suggested section name, e.g. { "id": "about-me", "name": "Header Section" }.
        + "Valid signal codes:\n" + codes +
        """
Rules:
1. Always return JSON, never plain text or explanations.
2. Do not include extra keys.
3. Do not escape JSON unnecessarily.
4. Use signal "a1" when analysis is complete.
5. For actions and sections, use strictly existing HTML ids; if an element has no id, ignore it.
6. If unsure, default to {"signal": "e0", "message": "I did not understand the request."}
7. Use "message" only when necessary, e.g. to describe an issue or concern."""
    )
}
final_analysis_message = {
    "role": "system",
    "content": (
        """You are an assistant that combines multiple partial website analyses
into one comprehensive final report. Return **only a valid JSON object**
in this format:
{
"signal": string,
"message": string,
"summary": string
}
"""
        # "actions": array,
        # "sections": array
        # Where:
        # - "actions" is an array of JSON objects, each pairing an HTML element id
        #   with a suggested action name, e.g. { "id": "loginBtn", "name": "Press login" }.
        # - "sections" is an array of JSON objects, each pairing a section id with
        #   a suggested section name, e.g. { "id": "about-me", "name": "Header Section" }.
        + "Valid signal codes:\n" + codes +
        """
Rules:
1. Always return JSON, never plain text or explanations.
2. Do not include extra keys.
3. Do not escape JSON unnecessarily.
4. Use signal "a1" when analysis is complete.
5. For actions and sections, use strictly existing HTML ids; if an element has no id, ignore it.
6. If unsure, default to {"signal": "e0", "message": "I did not understand the request."}
7. Use "message" only when necessary, e.g. to describe an issue or concern."""
    )
}

def count_tokens(text):
    """Number of tokens the SILMA tokenizer assigns to a string."""
    return len(tokenizer.encode(text))


def format_messages(messages):
    """Flatten a chat message list into the rough text the model will see."""
    formatted = ""
    for m in messages:
        formatted += f"{m['role'].upper()}: {m['content'].strip()}\n"
    return formatted


def compute_input_size(chunk_str, msg=analysis_system_message):
    """Token count of a system message plus one user chunk."""
    return count_tokens(format_messages(
        [msg, {"role": "user", "content": chunk_str}]))

def process_chunks(chunks, msg, limit):
    """Split site chunks until each one fits within the token limit.

    Oversized chunks are broken up by peeling off their largest elements
    first; an element that is itself too large is split recursively. A
    chunk that has no elements left and still exceeds the limit is dropped.
    """
    processed_chunks = []
    for chunk in chunks:
        if compute_input_size(chunk, msg) <= limit:
            processed_chunks.append(chunk)
            continue

        old_chunk = json.loads(chunk)
        elements = old_chunk.get('elements', [])
        # (token size, index) pairs, largest element first
        element_sizes = sorted(
            ((count_tokens(json.dumps(element)), i)
             for i, element in enumerate(elements)),
            reverse=True)

        removed = set()
        for _, element_index in element_sizes:
            element_json = json.dumps(elements[element_index])
            if compute_input_size(element_json, msg) <= limit:
                # The element fits on its own: emit it, tagged with the id
                # of the chunk it was peeled from.
                processed_chunks.append(json.dumps(
                    {**elements[element_index],
                     "parent_id": old_chunk.get('id', '')}))
            else:
                # Still too large on its own: split it recursively.
                processed_chunks.extend(
                    process_chunks([element_json], msg, limit))
            removed.add(element_index)

            # Re-check the chunk with the peeled-off elements removed; once
            # it fits, keep it and stop splitting.
            reduced_chunk = {**old_chunk, "elements": [
                e for i, e in enumerate(elements) if i not in removed]}
            if compute_input_size(json.dumps(reduced_chunk), msg) <= limit:
                processed_chunks.append(json.dumps(reduced_chunk))
                break

    return processed_chunks
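
# Illustrative walk-through on a hypothetical oversized chunk:
#   {"id": "page", "elements": [<huge nav>, <small footer>]}
# would be emitted as: the <huge nav> split recursively, the <small footer>
# re-tagged with {"parent_id": "page"}, and finally the stripped-down
# {"id": "page", "elements": []} wrapper once it fits under the limit.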


app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Routes


@app.get("/")
def api_home():
    return {'detail': 'Welcome to FastAPI TextGen Tutorial!'}


@app.post("/prompt")
def generate_text(request: PromptRequest):
    messages = [
        {
            "role": "system",
            "content": (
                    """You are an assistant for an accessibility browser extension. 
                    Your only task is to return a **valid JSON object** based on the user's request. 
                    The JSON must have this format:
                    { "signal": string, "message": string }
                    Valid signal codes:
                    """ + codes + """
                    Rules:
                    1. Always return JSON, never plain text.
                    2. Do not include extra keys.
                    3. Do not escape JSON unnecessarily.
                    4. Request chunking with signal "a0" if the user asks for analysis, summarization, or possible actions.
                    5. If unsure, default to {"signal": "m0", "message": "I did not understand the request."}"""
            )
        },
        {"role": "user", "content": request.prompt}
    ]

    token_count = count_tokens(format_messages(messages))
    if token_count > input_limit:
        return {"signal": "e0", "message": "Input exceeds token limit."}
    
    output = llm.create_chat_completion(
        messages=messages,
        max_tokens=1024,
        temperature=0
    )

    output_str = output["choices"][0]["message"]["content"]
    try:
        output_json = json.loads(output_str)
    except json.JSONDecodeError:
        output_json = {"signal": "m0", "message": output_str}

    return {"output": output_json}


@app.post("/analyze")
def analyze(request: AnalyzeRequest):
    analysis_results = []
    chunks = process_chunks(request.data, analysis_system_message, input_limit)
    if not chunks:
        return {"signal": "e0", "message": "No chunks."}

    # With many chunks, the per-chunk output budget becomes too small to be
    # useful; below roughly 90 tokens per chunk, combine the partial results
    # manually instead of asking the model to merge them.
    manual_combination = input_limit / len(chunks) < 90

    for chunk in chunks:
        output = llm.create_chat_completion(
            messages=[
                analysis_system_message,
                {"role": "user", "content": chunk}
            ],
            # Share the output budget across chunks when the model will
            # combine the partial results itself (max_tokens must be an int).
            max_tokens=input_limit if manual_combination
            else input_limit // len(chunks),
            temperature=0
        )
        output_str = output["choices"][0]["message"]["content"]
        try:
            output_json = json.loads(output_str)
        except json.JSONDecodeError:
            output_json = {"signal": "e0",
                           "message": "Invalid JSON parsing."}
            print("JSON parsing error:", output_str)

        analysis_results.append(output_json)

    # combine results

    if not manual_combination:
        combined_result = json.dumps(analysis_results)
        print("Combined result: ", combined_result)
        output = llm.create_chat_completion(
            messages=[
                final_analysis_message,
                {"role": "user", "content": combined_result}
            ],
            # the combined input might exceed the limit due to the system
            # message, so cap the output at the remaining context
            max_tokens=context_length -
            compute_input_size(json.dumps(analysis_results),
                               final_analysis_message),
            temperature=0)

        output_str = output["choices"][0]["message"]["content"]
        try:
            output_json = json.loads(output_str)
        except json.JSONDecodeError:
            output_json = {"signal": "e0",
                           "message": "Invalid JSON parsing." + output_str}
            print("JSON parsing error:", output_str)
    else:
        for result in analysis_results:
            if result.get("signal") != "a1":
                output_json = {"signal": "e0",
                               "message": "Chunk Analysis Failure."}
                return output_json
        output_json = {
            "signal": "a2",
            "message": "Analysis complete, results were combined manually.",
            "summary": " ".join([res.get("summary", "")
                                 for res in analysis_results]),
            "actions": [action for res in analysis_results
                        for action in res.get("actions", [])],
            "sections": [section for res in analysis_results
                         for section in res.get("sections", [])],
        }
    return output_json
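
# Example request; each item in "data" is a JSON-encoded site chunk produced
# by the extension (the chunk shape shown here is illustrative):
#   curl -X POST http://localhost:8000/analyze \
#        -H "Content-Type: application/json" \
#        -d '{"data": ["{\"id\": \"main\", \"elements\": []}"]}'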


if __name__ == "__main__" and MODE == "MOCK":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)