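"""FastAPI backend for the LAPOR AI report-analysis service.

Exposes three endpoints for analyzing citizen reports (free-text description plus
photo evidence):

  * POST /api/analyze/local  - local Ollama model only
  * POST /api/analyze/gemini - the first configured Gemini model only
  * POST /api/analyze        - main endpoint; the local path is currently disabled,
                               so it falls back through the configured Gemini models

A minimal Gradio landing page is mounted at "/" when the module is run directly.
"""
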
import os
import io
import base64
import json
import time
import subprocess
import threading  # type: ignore  # only used by the commented-out setup_ollama() helper
import shutil  # type: ignore  # only used by the commented-out setup_ollama() helper
from typing import List, Any, Dict, Union, Optional

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse, Response
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from PIL import Image
from dotenv import load_dotenv
import ollama
import spaces # type: ignore
import gradio as gr
import google.generativeai as genai

load_dotenv()

ALLOWED_ORIGINS_RAW: Optional[str] = os.getenv("ALLOWED_ORIGINS")
MODEL_NAME: Optional[str] = os.getenv("MODEL_NAME")
GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY")
GEMINI_MODELS_RAW: Optional[str] = os.getenv("GEMINI_MODELS")

SERVICE_MAP_STR = os.getenv("SERVICE_CODES_MAP", "{}")
SERVICE_MAP = json.loads(SERVICE_MAP_STR)
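# SERVICE_MAP maps each allowed service_code to the category name the AI must return
# (see validate_ai_output). SERVICE_CODES_MAP is read as a JSON object from the
# environment, e.g. '{"PU-01": "Infrastruktur"}' (illustrative values only).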

GEMINI_SYSTEM_INSTRUCTION = os.getenv("GEMINI_SYSTEM_INSTRUCTION", "{}")

ALLOWED_ORIGINS: List[str] = (
    ["*"]
    if not ALLOWED_ORIGINS_RAW or ALLOWED_ORIGINS_RAW.strip() == "*"
    else [origin.strip() for origin in ALLOWED_ORIGINS_RAW.split(",")]
)
GEMINI_MODEL_LIST: List[str] = [model.strip() for model in GEMINI_MODELS_RAW.split(',')] if GEMINI_MODELS_RAW else []

print(f"ALLOWED_ORIGINS: {ALLOWED_ORIGINS}")
print(f"LOCAL_MODEL_NAME: {MODEL_NAME}")
print(f"GEMINI_MODELS: {GEMINI_MODEL_LIST}")
print(f"GEMINI_API_KEY: {GEMINI_API_KEY}")

# def setup_ollama():
#     print("Checking Ollama setup...")
#     if not shutil.which("ollama"):
#         print("Ollama not found. Installing...")
#         subprocess.run("curl -fsSL https://ollama.com/install.sh | sh", shell=True, check=True)
    
#     def run_server():
#         print("Starting Ollama Serve...")
#         subprocess.Popen(["ollama", "serve"])
    
#     t = threading.Thread(target=run_server, daemon=True)
#     t.start()
    
#     print("Waiting for Ollama to spin up...")
#     time.sleep(5)
    
#     print(f"Pulling Model: {MODEL_NAME}...")
#     try:
#         subprocess.run(["ollama", "pull", MODEL_NAME], check=True) # type: ignore
#         print("Model pulled successfully.")
#     except Exception as e:
#         print(f"Error pulling model: {e}")

# setup_ollama()

if GEMINI_API_KEY:
    try:
        genai.configure(api_key=GEMINI_API_KEY) # type: ignore
        print("Gemini client configured successfully.")
    except Exception as e:
        raise EnvironmentError(f"Error configuring Gemini: {e}")
else:
    raise EnvironmentError("Warning: GEMINI_API_KEY not found. The /api/analyze/gemini endpoint and fallback will be unavailable.")

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def process_image_to_base64(image_bytes: bytes) -> Union[str, None]:
    """Converts image bytes to a base64 encoded string."""
    try:
        img = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        buffered = io.BytesIO()
        img.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')
    except Exception as e:
        print(f"Error processing image: {e}")
        return None

async def process_uploaded_files(images: List[UploadFile]) -> Dict[str, List[Any]]:
    """Reads uploaded files and converts them to bytes and base64 strings."""
    if not images:
        raise HTTPException(status_code=400, detail="Wajib melampirkan minimal 1 foto bukti.")

    image_bytes_list: List[bytes] = []
    base64_images: List[str] = []
    
    for img_file in images:
        content = await img_file.read()
        if len(content) > 0:
            image_bytes_list.append(content)
            b64 = process_image_to_base64(content)
            if b64:
                base64_images.append(b64)

    if not base64_images:
        raise HTTPException(status_code=400, detail="File gambar tidak valid atau corrupt.")

    return {"bytes": image_bytes_list, "b64": base64_images}

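# validate_ai_output() expects a flat JSON object like the following (values are
# illustrative; valid service_code/category pairs come from SERVICE_MAP):
#   {
#     "title": "Jalan berlubang di Jl. Merdeka",
#     "category": "Infrastruktur",   # must equal SERVICE_MAP[service_code]
#     "priority": "high",            # one of: high, medium, low
#     "service_code": "PU-01"        # must be a key in SERVICE_MAP
#   }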
def validate_ai_output(ai_content: Dict[str, Any]) -> Dict[str, Any]:
    """Validates the JSON output from an AI model against the required structure and values."""
    required_keys = ["title", "category", "priority", "service_code"]
    missing_keys = [key for key in required_keys if key not in ai_content]
    if missing_keys:
        raise ValueError(f"Missing keys in AI JSON response: {', '.join(missing_keys)}")

    if not str(ai_content.get("title", "")).strip():
        raise ValueError("AI returned an empty title")

    service_code = ai_content["service_code"]
    if service_code not in SERVICE_MAP:
        raise ValueError(f"Invalid service_code '{service_code}'. Not found in service map.")

    expected_category = SERVICE_MAP[service_code]
    if ai_content["category"] != expected_category:
        raise ValueError(f"Category mismatch for code {service_code}. Got '{ai_content['category']}', expected '{expected_category}'")

    priority = str(ai_content["priority"]).lower()
    if priority not in ['high', 'medium', 'low']:
        raise ValueError(f"Invalid priority value: '{priority}'")
    
    ai_content["priority"] = priority
    return ai_content

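# @spaces.GPU requests a Hugging Face ZeroGPU slot for this call; duration=60 is the
# expected upper bound, in seconds, for a single local inference run.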
@spaces.GPU(duration=60)
def run_local_inference(report_text: str, base64_images: List[str]) -> Dict[str, Any]:
    """Runs inference using the local Ollama model."""
    print("Starting Local GPU Inference...")
    try:
        ollama.show(MODEL_NAME) # type: ignore
    except Exception:
        print("Model not found in GPU context, pulling again...")
        subprocess.run(["ollama", "pull", MODEL_NAME], check=True) # type: ignore

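    # format='json' constrains Ollama to emit valid JSON; the low temperature keeps
    # the classification output close to deterministic.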
    response = ollama.chat( # type: ignore
        model=MODEL_NAME, # type: ignore
        messages=[{
            'role': 'user',
            'content': report_text,
            'images': base64_images,
        }],
        format='json',
        options={'temperature': 0.1}
    )
    return response # type: ignore

def run_gemini_inference(report_text: str, image_bytes_list: List[bytes], model_name: str) -> Dict[str, Any]:
    """Runs inference using the Google Gemini model."""
    print(f"Starting Gemini Inference with model: {model_name}...")
    if not GEMINI_API_KEY:
        raise ConnectionError("GEMINI_API_KEY is not configured.")

    model = genai.GenerativeModel(model_name, system_instruction=GEMINI_SYSTEM_INSTRUCTION) # type: ignore
    pil_images = [Image.open(io.BytesIO(content)) for content in image_bytes_list]
    
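    # response_mime_type="application/json" asks Gemini for raw JSON so that
    # json.loads() below can parse response.text directly.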
    response = model.generate_content([report_text, *pil_images], generation_config={"response_mime_type": "application/json"}) # type: ignore
    
    ai_content = json.loads(response.text)
    return ai_content

@app.get("/")
def health_check():
    return Response("Python Backend is running.")

@app.post("/api/analyze/local")
async def analyze_local(report: str = Form(...), images: List[UploadFile] = File(...)): # type: ignore
    """Endpoint to analyze a report using only the local Ollama model."""
    if not report or len(report) < 10:
        raise HTTPException(status_code=400, detail="Deskripsi laporan wajib diisi minimal 10 karakter.")
    
    processed_images = await process_uploaded_files(images)
    base64_images = processed_images["b64"]
    
    try:
        response_raw = run_local_inference(report, base64_images)
        if 'message' not in response_raw or 'content' not in response_raw['message']:
            raise ValueError("Empty or invalid response structure from local AI")

        ai_content = validate_ai_output(json.loads(response_raw['message']['content']))

        return { # type: ignore
            "status": "success",
            "data": ai_content,
            "meta": {
                "model": MODEL_NAME,
                'processing_time_sec': (response_raw.get("total_duration", 0)) / 1e9,
                "images_analyzed": len(base64_images),
            }
        }
    except Exception as e:
        print(f"Local analysis failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Local AI Failed: {str(e)}")

@app.post("/api/analyze/gemini")
async def analyze_gemini(report: str = Form(...), images: List[UploadFile] = File(...)): # type: ignore
    """Endpoint to analyze a report using only the Gemini model."""
    if not report or len(report) < 10:
        raise HTTPException(status_code=400, detail="Deskripsi laporan wajib diisi minimal 10 karakter.")
        
    processed_images = await process_uploaded_files(images)
    image_bytes_list = processed_images["bytes"]

    if not GEMINI_MODEL_LIST:
        raise HTTPException(status_code=501, detail="No Gemini models configured in the environment.")

    primary_gemini_model = GEMINI_MODEL_LIST[0]
    
    try:
        start_time = time.time()
        ai_content = validate_ai_output(run_gemini_inference(report, image_bytes_list, primary_gemini_model))
        end_time = time.time()

        return { # type: ignore
            "status": "success",
            "data": ai_content,
            "meta": {
                "model": primary_gemini_model,
                'processing_time_sec': end_time - start_time,
                "images_analyzed": len(image_bytes_list),
            }
        }
    except Exception as e:
        print(f"Gemini analysis failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Gemini AI Failed: {str(e)}")

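# Example call (hypothetical file name; host and port match the uvicorn.run() call below,
# adjust to your deployment):
#   curl -X POST http://localhost:7860/api/analyze \
#        -F "report=Jalan berlubang besar di depan pasar, membahayakan pengendara." \
#        -F "images=@bukti1.jpg"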
@app.post("/api/analyze")
async def analyze_with_fallback(report: str = Form(...), images: List[UploadFile] = File(...)): # type: ignore
    """
    Main analysis endpoint. Tries the local model up to 3 times.
    If it fails, it falls back to the Gemini model.
    """
    if not report or len(report) < 10:
        raise HTTPException(status_code=400, detail="Deskripsi laporan wajib diisi minimal 10 karakter.")
        
    processed_images = await process_uploaded_files(images)
    base64_images = processed_images["b64"] # type: ignore
    image_bytes_list = processed_images["bytes"]
    
    last_local_exception = None
    last_gemini_exception = None
    
    # max_local_retries = 3 # type: ignore
    # for attempt in range(max_local_retries):
    #     try:
    #         print(f"Attempting Local AI Analysis... ({attempt + 1}/{max_local_retries})")
    #         response_raw = run_local_inference(report, base64_images)

    #         if 'message' not in response_raw or 'content' not in response_raw['message']:
    #             raise ValueError("Empty response structure from local AI")

    #         ai_content = validate_ai_output(json.loads(response_raw['message']['content']))
            
    #         response = { # type: ignore
    #             "status": "success",
    #             "data": ai_content,
    #             "meta": {
    #                 "model": MODEL_NAME,
    #                 'processing_time_sec': (response_raw.get("total_duration", 0)) / 1e9,
    #                 "images_analyzed": len(base64_images),
    #                 "source": "local",
    #                 "attempts": attempt + 1
    #             }
    #         }
            
    #         print("Local AI Success")
    #         print(json.dumps(response, indent=2, ensure_ascii=True))
            
    #         return response # type: ignore
    #     except Exception as e:
    #         print(f"Local AI Attempt {attempt + 1} failed: {str(e)}")
    #         last_local_exception = e
    #         time.sleep(1)

    # print(f"Local model failed. Falling back to Gemini models.")
    
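    # NOTE: the local-model retry loop above is currently commented out, so this
    # endpoint goes straight to the Gemini fallback chain and last_local_exception
    # stays None.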
    if not GEMINI_MODEL_LIST:
        return JSONResponse(
            status_code=500,
            content={
                "status": "error",
                "message": "Local AI failed and no Gemini models are configured for fallback.",
                "local_model_error": str(last_local_exception),
            }
        )
        
    print(f"Gemini fallback models (in order): {GEMINI_MODEL_LIST}")
        
    # Try each configured Gemini model up to 3 times, in the order given.
    for model_name in [m for m in GEMINI_MODEL_LIST for _ in range(3)]:
        try:
            start_time = time.time()
            ai_content = validate_ai_output(run_gemini_inference(report, image_bytes_list, model_name))
            end_time = time.time()
            
            response = { # type: ignore
                "status": "success",
                "data": ai_content,
                "meta": {
                    "model": model_name,
                    'processing_time_sec': end_time - start_time,
                    "images_analyzed": len(image_bytes_list),
                    "source": "gemini_fallback"
                }
            }
            
            print(f"Gemini AI Fallback Success with model {model_name}")
            print(json.dumps(response, indent=2, ensure_ascii=True))
            
            return response # type: ignore
        except Exception as e:
            print(f"Gemini AI Fallback with model {model_name} failed: {str(e)}")
            last_gemini_exception = e
            continue
    
    return JSONResponse(
        status_code=500,
        content={
            "status": "error", 
            "message": "All AI models (Local and Gemini fallbacks) failed to process the request.",
            "local_model_error": str(last_local_exception),
            "last_gemini_model_error": str(last_gemini_exception)
        }
    )

if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown("# LAPOR AI API Backend")
        gr.Markdown(
            "This space hosts the API endpoints for analyzing citizen reports. "
            "The primary endpoint is `/api/analyze` which uses a local model with a Gemini fallback."
        )
        gr.Markdown(f"**Local Model:** `{MODEL_NAME}`")
        gr.Markdown(f"**Fallback Models (in order):** `{', '.join(GEMINI_MODEL_LIST)}`")
    
    app = gr.mount_gradio_app(app, demo, path="/") # type: ignore
    
    uvicorn.run(app, host="0.0.0.0", port=7860)