from app.backend.controllers.messages import register_message
from app.core.document_validator import path_is_valid
from app.core.response_parser import add_links
from app.backend.models.users import User
from app.settings import BASE_DIR
from app.backend.controllers.chats import (
    get_chat_with_messages,
    create_new_chat,
    update_title,
    list_user_chats
)
from app.backend.controllers.users import (
    extract_user_from_context,
    get_current_user,
    get_latest_chat,
    refresh_cookie,
    authorize_user,
    check_cookie,
    create_user
)
from app.core.utils import (
    construct_collection_name,
    create_collection,
    extend_context,
    initialize_rag,
    save_documents,
    protect_chat,
    TextHandler,
    PDFHandler,
)

from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi import (
    HTTPException,
    UploadFile,
    Request,
    Depends,
    FastAPI,
    Form,
    File,
)
from fastapi.responses import (
    StreamingResponse,
    RedirectResponse,
    FileResponse,
    JSONResponse,
)

from typing import Optional
import os


# <------------------------------------- API ------------------------------------->

api = FastAPI()
rag = initialize_rag()

origins = [
    "*",
]

# NOTE(review): a wildcard origin together with allow_credentials=True makes
# Starlette's CORSMiddleware reflect the caller's Origin on credentialed
# requests. In effect, any website can send cookie-bearing requests here and
# read the responses. Restrict `origins` to the real frontend host(s) in
# production.
api.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

api.mount(
    "/chats_storage",
    StaticFiles(directory=os.path.join(BASE_DIR, "chats_storage")),
    name="chats_storage",
)

api.mount(
    "/static",
    StaticFiles(directory=os.path.join(BASE_DIR, "app", "frontend", "static")),
    name="static",
)

templates = Jinja2Templates(
    directory=os.path.join(BASE_DIR, "app", "frontend", "templates")
)


# <--------------------------------- Middleware --------------------------------->


@api.middleware("http")
async def require_user(request: Request, call_next):
    """Attach a user to every request and keep its auth cookie in sync.

    Some requests are passed straight to the handler without a user being
    resolved. These are paths that, with surrounding slashes stripped, start
    with ``pdfs``, or that contain ``static/styles.css`` or ``favicon.ico``.

    For all other requests:
      * the user is looked up with ``get_current_user``, and a new one is
        made with ``create_user`` when none is found;
      * the user is stored on ``request.state.current_user`` so route
        handlers can read it;
      * after the handler runs, an existing user's cookie is refreshed
        (``refresh_cookie``), while a new user is authorized on the response
        (``authorize_user``).

    Exceptions from any of these steps propagate unchanged. The END banner is
    always printed.
    """
    print("&" * 40, "START MIDDLEWARE", "&" * 40)
    try:
        print(
            f"Path ----> {request.url.path}, Method ----> {request.method}, "
            f"Port ----> {request.url.port}\n"
        )

        stripped_path = request.url.path.strip("/")
        # NOTE(review): only styles.css is exempt among static assets; other
        # /static and /chats_storage requests made without a cookie will
        # create a new user. Confirm this is intended.
        if (
            stripped_path.startswith("pdfs")
            or "static/styles.css" in stripped_path
            or "favicon.ico" in stripped_path
        ):
            return await call_next(request)

        user = get_current_user(request)
        authorized = user is not None
        if not authorized:
            user = create_user()

        print(f"User in Context ----> {user.id}\n")
        request.state.current_user = user

        response = await call_next(request)

        if authorized:
            refresh_cookie(request=request, response=response)
        else:
            authorize_user(response, user)

        return response
    finally:
        print("&" * 40, "END MIDDLEWARE", "&" * 40, "\n\n")


# <--------------------------------- Common routes --------------------------------->


@api.post("/message_with_docs")
async def send_message(
    request: Request,
    files: list[UploadFile] = File(None),
    prompt: str = Form(...),
    chat_id: str = Form(None),
) -> StreamingResponse:
    """Store a user prompt (plus uploaded files) and stream the RAG answer.

    Steps:
      * registers ``prompt`` as a ``"user"`` message in ``chat_id``;
      * hands ``files`` to ``save_documents`` together with the new message id;
      * returns a ``text/event-stream`` response fed by
        ``rag.generate_response_stream``.

    Raises:
        HTTPException: 401 if ``chat_id`` is given and ``protect_chat`` denies
            the current user access to it. 500 if registering the message,
            saving the documents or building the stream fails. Errors raised
            later, while the stream is being consumed, are not covered here.
    """
    status = 200
    user = extract_user_from_context(request)
    print("-" * 100, "User ---->", user, "-" * 100, "\n\n")

    # Same ownership guard as show_chat: do not let a user post into (and
    # index documents for) somebody else's chat.
    if chat_id is not None and not protect_chat(user, chat_id):
        raise HTTPException(401, "You do not have rights to use this chat!")

    try:
        collection_name = construct_collection_name(user, chat_id)
        message_id = register_message(content=prompt, sender="user", chat_id=chat_id)
        await save_documents(
            collection_name,
            files=files,
            RAG=rag,
            user=user,
            chat_id=chat_id,
            message_id=message_id,
        )
        return StreamingResponse(
            rag.generate_response_stream(
                collection_name=collection_name, user_prompt=prompt, stream=True
            ),
            status,
            media_type="text/event-stream",
        )
    except Exception as e:
        # Previously the error was only printed and the handler returned None,
        # which reached the client as a successful "null" response.
        print(e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@api.post("/replace_message")
async def replace_message(request: Request):
    """Save an edited message from the client as a ``"system"`` message.

    Reads a JSON body with ``message`` (the new text, default ``""``) and
    ``chatId``. The text is also written to ``BASE_DIR/response.txt``. The
    response echoes it back as ``{"updated_message": ...}``.

    Raises:
        HTTPException: 401 if ``chatId`` is given and ``protect_chat`` denies
            the current user access to it.
    """
    data = await request.json()
    updated_message = data.get("message", "")
    chat_id = data.get("chatId")

    user = extract_user_from_context(request)
    if chat_id is not None and not protect_chat(user, chat_id):
        raise HTTPException(401, "You do not have rights to use this chat!")

    # NOTE(review): looks like a debugging dump. Every call overwrites this
    # file, and concurrent users clobber each other's content.
    with open(os.path.join(BASE_DIR, "response.txt"), "w", encoding="utf-8") as f:
        f.write(updated_message)

    register_message(content=updated_message, sender="system", chat_id=chat_id)
    return JSONResponse({"updated_message": updated_message})


@api.get("/viewer/{path:path}")
def show_document(
    request: Request,
    path: str,
    page: Optional[int] = 1,
    lines: Optional[str] = "1-1",
    start: Optional[int] = 0,
):
    """Serve or render a stored document for the viewer.

    ``path`` is resolved with ``os.path.realpath`` and must pass
    ``path_is_valid``. The extension (matched case-insensitively) selects how
    the file is returned:
      * ``pdf`` → raw file via ``FileResponse``;
      * ``txt``/``csv``/``md``/``json`` and ``docx``/``doc`` → rendered by
        ``TextHandler`` with the ``lines`` selector;
      * anything else → raw file via ``FileResponse``.

    ``page`` and ``start`` are currently only used in the debug output.

    Raises:
        HTTPException: 404 if the resolved path is not valid.
    """
    print(f"DEBUG: Show document with path: {path}, page: {page}, lines: {lines}, start: {start}")
    path = os.path.realpath(path)
    print(f"DEBUG: Real path: {path}")

    if not path_is_valid(path):
        # Must be raised: the original *returned* the exception object, so
        # the client never received a 404.
        raise HTTPException(status_code=404, detail="Document not found")

    ext = path.rsplit(".", 1)[-1].lower()

    if ext == "pdf":
        print("Open pdf file by path")
        return FileResponse(path=path)
    if ext in ("txt", "csv", "md", "json"):
        print("Open txt file by path")
        return TextHandler(request, path=path, lines=lines, templates=templates)
    if ext in ("docx", "doc"):
        # NOTE(review): Word files go through the plain-text handler too;
        # confirm TextHandler can actually read .docx/.doc content.
        return TextHandler(request, path=path, lines=lines, templates=templates)
    return FileResponse(path=path)


# <--------------------------------- Get --------------------------------->


@api.get("/list_chats")
def list_chats_for_user(request: Request):
    """Return ``{"chats": ...}`` with the chats of the current user."""
    user = extract_user_from_context(request)
    chats = list_user_chats(user.id)
    print(f"Chats for user {user.id}: {chats}")
    return JSONResponse({"chats": chats})


@api.get("/chats/{chat_id}")
def show_chat(request: Request, chat_id: str):
    """Return a chat with its messages as JSON after an ownership check.

    ``update_title`` is called on the chat before the data is returned.

    Raises:
        HTTPException: 401 if ``protect_chat`` denies access. 404 if
            ``get_chat_with_messages`` returns nothing for ``chat_id``.
    """
    user = extract_user_from_context(request)
    if not protect_chat(user, chat_id):
        raise HTTPException(401, "You do not have rights to use this chat!")

    chat_data = get_chat_with_messages(chat_id)
    print(f"DEBUG: Data for chat '{chat_id}' from get_chat_with_messages: {chat_data}")
    if not chat_data:
        raise HTTPException(status_code=404, detail=f"Chat with id {chat_id} not found.")

    update_title(chat_data["chat_id"])
    return JSONResponse(content=chat_data)


@api.get("/")
def last_user_chat(request: Request):
    """Redirect (303) to the user's latest chat, creating one if none exists.

    When a new chat is created, its vector collection is created as well, and
    the redirect goes to the ``url`` returned by ``create_new_chat``.

    Raises:
        HTTPException: 500 if creating the collection fails.
    """
    user = extract_user_from_context(request)
    chat = get_latest_chat(user)

    if chat is not None:
        return RedirectResponse(f"/chats/{chat.id}", status_code=303)

    print("new_chat")
    new_chat = create_new_chat("new chat", user)
    # NOTE(review): this reads the "chat_id" key, while create_chat reads "id"
    # from the same create_new_chat result. One of them is likely wrong; confirm.
    try:
        create_collection(user, new_chat.get("chat_id"), rag)
    except Exception as e:
        # HTTPException.detail must be JSON-serialisable, so pass the text.
        raise HTTPException(500, str(e)) from e
    return RedirectResponse(new_chat.get("url"), status_code=303)


# <--------------------------------- Post --------------------------------->


@api.post("/new_chat")
def create_chat(request: Request, title: Optional[str] = "new chat"):
    """Create a chat (and its vector collection) for the current user.

    Returns the dict produced by ``create_new_chat`` as JSON.

    Raises:
        HTTPException: 500 if no chat id came back or the collection could
            not be created.
    """
    user = extract_user_from_context(request)
    new_chat_data = create_new_chat(title, user)
    if not new_chat_data or not new_chat_data.get("id"):
        raise HTTPException(500, "New chat could not be created.")

    try:
        create_collection(user, new_chat_data["id"], rag)
    except Exception as e:
        raise HTTPException(500, str(e)) from e
    return JSONResponse(new_chat_data)


if __name__ == "__main__":
    pass