|
|
from app.backend.controllers.messages import register_message |
|
|
from app.core.document_validator import path_is_valid |
|
|
from app.core.response_parser import add_links |
|
|
from app.backend.models.users import User |
|
|
from app.settings import BASE_DIR |
|
|
from app.backend.controllers.chats import ( |
|
|
get_chat_with_messages, |
|
|
create_new_chat, |
|
|
update_title, |
|
|
list_user_chats |
|
|
) |
|
|
from app.backend.controllers.users import ( |
|
|
extract_user_from_context, |
|
|
get_current_user, |
|
|
get_latest_chat, |
|
|
refresh_cookie, |
|
|
authorize_user, |
|
|
check_cookie, |
|
|
create_user |
|
|
) |
|
|
from app.core.utils import ( |
|
|
construct_collection_name, |
|
|
create_collection, |
|
|
extend_context, |
|
|
initialize_rag, |
|
|
save_documents, |
|
|
protect_chat, |
|
|
TextHandler, |
|
|
PDFHandler, |
|
|
) |
|
|
|
|
|
from fastapi.templating import Jinja2Templates |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi import ( |
|
|
HTTPException, |
|
|
UploadFile, |
|
|
Request, |
|
|
Depends, |
|
|
FastAPI, |
|
|
Form, |
|
|
File, |
|
|
) |
|
|
from fastapi.responses import ( |
|
|
StreamingResponse, |
|
|
RedirectResponse, |
|
|
FileResponse, |
|
|
JSONResponse, |
|
|
) |
|
|
|
|
|
from typing import Optional |
|
|
import os |
|
|
|
|
|
|
|
|
# The application object every route and middleware in this module registers on.
api = FastAPI()

# One RAG pipeline, built at import time and shared by all request handlers.
rag = initialize_rag()

# CORS policy: wildcard origin list, credentials allowed, any method and header.
# NOTE(review): with a "*" origin list plus allow_credentials=True, Starlette's
# CORSMiddleware answers cookie-bearing requests by echoing the caller's Origin,
# so any site can make credentialed calls here. Narrow `origins` if the frontend
# is only ever served by this app — confirm before changing.
origins = ["*"]

_cors_options = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
api.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
|
# Directories served or rendered by this app, all rooted at BASE_DIR.
_chats_storage_dir = os.path.join(BASE_DIR, "chats_storage")
_frontend_dir = os.path.join(BASE_DIR, "app", "frontend")

# Stored chat documents are exposed as plain static files under /chats_storage.
# NOTE(review): this mount does no per-chat ownership check of its own; the
# `require_user` middleware identifies or creates a user but never checks which
# chat a file belongs to. Confirm this exposure is intended.
api.mount(
    "/chats_storage",
    StaticFiles(directory=_chats_storage_dir),
    name="chats_storage",
)

# Frontend static assets (e.g. static/styles.css) under /static.
api.mount(
    "/static",
    StaticFiles(directory=os.path.join(_frontend_dir, "static")),
    name="static",
)

# Jinja2 templates handed to TextHandler when a document is rendered.
templates = Jinja2Templates(directory=os.path.join(_frontend_dir, "templates"))
|
|
|
|
|
|
|
|
|
|
|
@api.middleware("http")
async def require_user(request: Request, call_next):
    """Attach a user to every request, creating a new one when none is found.

    Paths under ``pdfs`` and paths containing ``static/styles.css`` or
    ``favicon.ico`` bypass user handling entirely. For every other request:

    * ``get_current_user`` resolves the user. When it returns ``None``, a new
      user is made with ``create_user``.
    * The user is stored on ``request.state.current_user`` for the endpoints.
    * After the downstream handler has produced a response, ``refresh_cookie``
      is called for an existing user. A newly created user is passed to
      ``authorize_user`` together with the response (presumably to issue its
      auth cookie — confirm).

    Exceptions from the downstream handler propagate unchanged. The
    START/END debug banners are printed around every request, even a failing one.
    """
    print("&" * 40, "START MIDDLEWARE", "&" * 40)
    try:
        print(f"Path ----> {request.url.path}, Method ----> {request.method}, Port ----> {request.url.port}\n")

        stripped_path = request.url.path.strip("/")

        # Public assets are served without touching user state or cookies.
        if (
            stripped_path.startswith("pdfs")
            or "static/styles.css" in stripped_path
            or "favicon.ico" in stripped_path
        ):
            return await call_next(request)

        user = get_current_user(request)
        authorized = user is not None
        if not authorized:
            user = create_user()

        print(f"User in Context ----> {user.id}\n")

        request.state.current_user = user
        response = await call_next(request)

        if authorized:
            refresh_cookie(request=request, response=response)
        else:
            authorize_user(response, user)

        return response
    finally:
        # No except clause: errors propagate as-is, and this still prints the END banner.
        print("&" * 40, "END MIDDLEWARE", "&" * 40, "\n\n")
|
|
|
|
|
|
|
|
|
|
|
@api.post("/message_with_docs")
async def send_message(
    request: Request,
    files: list[UploadFile] = File(None),
    prompt: str = Form(...),
    chat_id: str = Form(None),
) -> StreamingResponse:
    """Store a user prompt and optional uploads, then stream the RAG answer.

    Steps:
      1. Resolve the current user and, when ``chat_id`` is given, check that
         the user may use that chat (the same ``protect_chat`` gate as
         GET /chats/{chat_id}).
      2. Register ``prompt`` as a "user" message in the chat.
      3. Hand ``files`` to ``save_documents`` for the chat's collection.
      4. Return ``rag.generate_response_stream`` wrapped as ``text/event-stream``.

    Raises:
        HTTPException(401): ``protect_chat`` rejects the user for ``chat_id``.
        HTTPException(500): any other failure while preparing the response.
            Previously such failures were only printed, and the client got a
            ``null`` body with status 200.
    """
    try:
        user = extract_user_from_context(request)
        print("-" * 100, "User ---->", user, "-" * 100, "\n\n")

        # Without this gate any client could post into, and query the documents
        # of, another user's chat.
        if chat_id is not None and not protect_chat(user, chat_id):
            raise HTTPException(401, "You do not have rights to use this chat!")

        collection_name = construct_collection_name(user, chat_id)

        message_id = register_message(content=prompt, sender="user", chat_id=chat_id)

        await save_documents(
            collection_name, files=files, RAG=rag, user=user, chat_id=chat_id, message_id=message_id
        )

        # NOTE(review): if generate_response_stream is a lazy generator, errors it
        # raises while streaming happen after this function returns and are not
        # caught by this try block.
        return StreamingResponse(
            rag.generate_response_stream(
                collection_name=collection_name, user_prompt=prompt, stream=True
            ),
            status_code=200,
            media_type="text/event-stream",
        )
    except HTTPException:
        raise
    except Exception as e:
        print(e)
        raise HTTPException(status_code=500, detail="Failed to process the message.") from e
|
|
|
|
|
|
|
|
@api.post("/replace_message")
async def replace_message(request: Request):
    """Persist the final text of an answer as a "system" message in a chat.

    Expects a JSON object body with ``message`` (text, defaults to "") and
    ``chatId``. The text is also written to ``BASE_DIR/response.txt``, which is
    overwritten on every call.

    Returns:
        JSONResponse ``{"updated_message": <text>}``.

    Raises:
        HTTPException(400): the body is not a JSON object.
        HTTPException(401): ``protect_chat`` rejects the user for ``chatId``.
    """
    data = await request.json()
    if not isinstance(data, dict):
        raise HTTPException(400, "Request body must be a JSON object.")

    updated_message = data.get("message", "")
    if updated_message is None:
        # An explicit JSON null would otherwise crash f.write below.
        updated_message = ""
    chat_id = data.get("chatId")

    # Same ownership gate as GET /chats/{chat_id}. Without it any client could
    # inject "system" messages into another user's chat.
    user = extract_user_from_context(request)
    if chat_id is not None and not protect_chat(user, chat_id):
        raise HTTPException(401, "You do not have rights to use this chat!")

    # Explicit UTF-8: the platform default encoding can fail on non-ASCII text.
    # NOTE(review): one shared file for all requests, so concurrent calls
    # overwrite each other. Confirm this debug dump is still needed.
    with open(os.path.join(BASE_DIR, "response.txt"), "w", encoding="utf-8") as f:
        f.write(updated_message)

    register_message(content=updated_message, sender="system", chat_id=chat_id)

    return JSONResponse({"updated_message": updated_message})
|
|
|
|
|
|
|
|
@api.get("/viewer/{path:path}")
def show_document(
    request: Request,
    path: str,
    page: Optional[int] = 1,
    lines: Optional[str] = "1-1",
    start: Optional[int] = 0,
):
    """Serve a stored document for the in-browser viewer.

    Args:
        request: Incoming request, forwarded to ``TextHandler``.
        path: Document path from the URL. It is resolved with
            ``os.path.realpath`` and must pass ``path_is_valid``.
        page: Only logged here.
        lines: Line-range string forwarded to ``TextHandler`` for text-like files.
        start: Only logged here.

    Returns:
        ``FileResponse`` for PDFs and unrecognised types. For txt/csv/md/json
        and docx/doc, whatever ``TextHandler`` returns.

    Raises:
        HTTPException(404): ``path_is_valid`` rejects the resolved path.
    """
    print(f"DEBUG: Show document with path: {path}, page: {page}, lines: {lines}, start: {start}")

    path = os.path.realpath(path)

    print(f"DEBUG: Real path: {path}")

    if not path_is_valid(path):
        # Must be raised: a *returned* HTTPException is not turned into a 404.
        raise HTTPException(status_code=404, detail="Document not found")

    # splitext inspects only the final path component, so dots in directory
    # names and extension-less files are handled. lower() lets ".PDF" match "pdf".
    ext = os.path.splitext(path)[1][1:].lower()

    if ext == "pdf":
        print("Open pdf file by path")
        return FileResponse(path=path)

    if ext in ("txt", "csv", "md", "json"):
        print("Open txt file by path")
        return TextHandler(request, path=path, lines=lines, templates=templates)

    if ext in ("docx", "doc"):
        # NOTE(review): Word files go through the same TextHandler as plain
        # text. Confirm TextHandler can actually read .docx/.doc content.
        return TextHandler(request, path=path, lines=lines, templates=templates)

    return FileResponse(path=path)
|
|
|
|
|
|
|
|
|
|
|
@api.get("/list_chats")
def list_chats_for_user(request: Request):
    """Return the requesting user's chats.

    The response body is ``{"chats": <result of list_user_chats(user.id)>}``.
    """
    current = extract_user_from_context(request)
    user_chats = list_user_chats(current.id)

    print(f"Chats for user {current.id}: {user_chats}")

    payload = {"chats": user_chats}
    return JSONResponse(payload)
|
|
|
|
|
|
|
|
@api.get("/chats/{chat_id}")
def show_chat(request: Request, chat_id: str):
    """Return a chat with its messages, if the requesting user may access it.

    Args:
        request: Incoming request; the user is read from its context.
        chat_id: Identifier of the chat to load.

    Returns:
        JSONResponse containing the data from ``get_chat_with_messages``.

    Raises:
        HTTPException(401): ``protect_chat`` rejects the user for this chat.
        HTTPException(404): ``get_chat_with_messages`` returned no data.
    """
    user = extract_user_from_context(request)

    if not protect_chat(user, chat_id):
        # Fixed typo in the user-facing message ("Yod" -> "You").
        raise HTTPException(401, "You do not have rights to use this chat!")

    chat_data = get_chat_with_messages(chat_id)

    print(f"DEBUG: Data for chat '{chat_id}' from get_chat_with_messages: {chat_data}")

    if not chat_data:
        raise HTTPException(status_code=404, detail=f"Chat with id {chat_id} not found.")

    # NOTE(review): update_title runs after chat_data was fetched and receives
    # only the id. Any title it changes is therefore presumably not reflected
    # in this response. Confirm that is acceptable.
    update_title(chat_data["chat_id"])

    return JSONResponse(content=chat_data)
|
|
|
|
|
|
|
|
@api.get("/")
def last_user_chat(request: Request):
    """Redirect the user to their latest chat, creating one if they have none.

    Returns:
        A 303 RedirectResponse, to ``/chats/<id>`` for an existing chat or to
        the ``url`` reported by ``create_new_chat`` for a new one.

    Raises:
        HTTPException(500): the new chat has no URL, or its collection
            could not be created.
    """
    user = extract_user_from_context(request)
    chat = get_latest_chat(user)

    if chat is not None:
        return RedirectResponse(f"/chats/{chat.id}", status_code=303)

    print("new_chat")
    new_chat = create_new_chat("new chat", user)

    url = new_chat.get("url")
    if not url:
        # A missing URL would otherwise produce a redirect to the literal "None".
        raise HTTPException(500, "New chat could not be created.")

    # POST /new_chat reads the id from the "id" key. Fall back to it so the
    # collection is never created for a None chat id.
    new_chat_id = new_chat.get("chat_id")
    if new_chat_id is None:
        new_chat_id = new_chat.get("id")

    try:
        create_collection(user, new_chat_id, rag)
    except Exception as e:
        # HTTPException.detail must be JSON-serialisable: pass the message, not
        # the exception object.
        raise HTTPException(500, str(e)) from e

    return RedirectResponse(url, status_code=303)
|
|
|
|
|
|
|
|
|
|
|
@api.post("/new_chat")
def create_chat(request: Request, title: Optional[str] = "new chat"):
    """Create a chat titled ``title`` for the current user, plus its collection.

    Returns:
        JSONResponse wrapping the mapping returned by ``create_new_chat``.

    Raises:
        HTTPException(500): that mapping has no truthy ``"id"``.
    """
    owner = extract_user_from_context(request)
    chat_info = create_new_chat(title, owner)

    chat_id = chat_info.get("id")
    if not chat_id:
        raise HTTPException(500, "New chat could not be created.")

    create_collection(owner, chat_id, rag)
    return JSONResponse(chat_info)
|
|
|
|
|
if __name__ == "__main__":
    # Deliberately a no-op. This module only defines the `api` app, which is
    # presumably started by an external ASGI server (e.g. uvicorn) — confirm.
    ...
|
|
|