Spaces:
Paused
Paused
| # | |
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # | |
| import json | |
| import os | |
| import re | |
| from datetime import datetime, timedelta | |
| from flask import request, Response | |
| from flask_login import login_required, current_user | |
| from api.db import FileType, ParserType, FileSource | |
| from api.db.db_models import APIToken, API4Conversation, Task, File | |
| from api.db.services import duplicate_name | |
| from api.db.services.api_service import APITokenService, API4ConversationService | |
| from api.db.services.dialog_service import DialogService, chat | |
| from api.db.services.document_service import DocumentService | |
| from api.db.services.file2document_service import File2DocumentService | |
| from api.db.services.file_service import FileService | |
| from api.db.services.knowledgebase_service import KnowledgebaseService | |
| from api.db.services.task_service import queue_tasks, TaskService | |
| from api.db.services.user_service import UserTenantService | |
| from api.settings import RetCode, retrievaler | |
| from api.utils import get_uuid, current_timestamp, datetime_format | |
| from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request | |
| from itsdangerous import URLSafeTimedSerializer | |
| from api.utils.file_utils import filename_type, thumbnail | |
| from rag.utils.minio_conn import MINIO | |
| from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService | |
| from agent.canvas import Canvas | |
| from functools import partial | |
| def generate_confirmation_token(tenent_id): | |
| serializer = URLSafeTimedSerializer(tenent_id) | |
| return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34] | |
| def new_token(): | |
| req = request.json | |
| try: | |
| tenants = UserTenantService.query(user_id=current_user.id) | |
| if not tenants: | |
| return get_data_error_result(retmsg="Tenant not found!") | |
| tenant_id = tenants[0].tenant_id | |
| obj = {"tenant_id": tenant_id, "token": generate_confirmation_token(tenant_id), | |
| "create_time": current_timestamp(), | |
| "create_date": datetime_format(datetime.now()), | |
| "update_time": None, | |
| "update_date": None | |
| } | |
| if req.get("canvas_id"): | |
| obj["dialog_id"] = req["canvas_id"] | |
| obj["source"] = "agent" | |
| else: | |
| obj["dialog_id"] = req["dialog_id"] | |
| if not APITokenService.save(**obj): | |
| return get_data_error_result(retmsg="Fail to new a dialog!") | |
| return get_json_result(data=obj) | |
| except Exception as e: | |
| return server_error_response(e) | |
| def token_list(): | |
| try: | |
| tenants = UserTenantService.query(user_id=current_user.id) | |
| if not tenants: | |
| return get_data_error_result(retmsg="Tenant not found!") | |
| id = request.args.get("dialog_id", request.args["canvas_id"]) | |
| objs = APITokenService.query(tenant_id=tenants[0].tenant_id, dialog_id=id) | |
| return get_json_result(data=[o.to_dict() for o in objs]) | |
| except Exception as e: | |
| return server_error_response(e) | |
| def rm(): | |
| req = request.json | |
| try: | |
| for token in req["tokens"]: | |
| APITokenService.filter_delete( | |
| [APIToken.tenant_id == req["tenant_id"], APIToken.token == token]) | |
| return get_json_result(data=True) | |
| except Exception as e: | |
| return server_error_response(e) | |
| def stats(): | |
| try: | |
| tenants = UserTenantService.query(user_id=current_user.id) | |
| if not tenants: | |
| return get_data_error_result(retmsg="Tenant not found!") | |
| objs = API4ConversationService.stats( | |
| tenants[0].tenant_id, | |
| request.args.get( | |
| "from_date", | |
| (datetime.now() - | |
| timedelta( | |
| days=7)).strftime("%Y-%m-%d 24:00:00")), | |
| request.args.get( | |
| "to_date", | |
| datetime.now().strftime("%Y-%m-%d %H:%M:%S")), | |
| "agent" if request.args.get("canvas_id") else None) | |
| res = { | |
| "pv": [(o["dt"], o["pv"]) for o in objs], | |
| "uv": [(o["dt"], o["uv"]) for o in objs], | |
| "speed": [(o["dt"], float(o["tokens"]) / (float(o["duration"] + 0.1))) for o in objs], | |
| "tokens": [(o["dt"], float(o["tokens"]) / 1000.) for o in objs], | |
| "round": [(o["dt"], o["round"]) for o in objs], | |
| "thumb_up": [(o["dt"], o["thumb_up"]) for o in objs] | |
| } | |
| return get_json_result(data=res) | |
| except Exception as e: | |
| return server_error_response(e) | |
| def set_conversation(): | |
| token = request.headers.get('Authorization').split()[1] | |
| objs = APIToken.query(token=token) | |
| if not objs: | |
| return get_json_result( | |
| data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) | |
| req = request.json | |
| try: | |
| if objs[0].source == "agent": | |
| e, c = UserCanvasService.get_by_id(objs[0].dialog_id) | |
| if not e: | |
| return server_error_response("canvas not found.") | |
| conv = { | |
| "id": get_uuid(), | |
| "dialog_id": c.id, | |
| "user_id": request.args.get("user_id", ""), | |
| "message": [{"role": "assistant", "content": "Hi there!"}], | |
| "source": "agent" | |
| } | |
| API4ConversationService.save(**conv) | |
| return get_json_result(data=conv) | |
| else: | |
| e, dia = DialogService.get_by_id(objs[0].dialog_id) | |
| if not e: | |
| return get_data_error_result(retmsg="Dialog not found") | |
| conv = { | |
| "id": get_uuid(), | |
| "dialog_id": dia.id, | |
| "user_id": request.args.get("user_id", ""), | |
| "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}] | |
| } | |
| API4ConversationService.save(**conv) | |
| return get_json_result(data=conv) | |
| except Exception as e: | |
| return server_error_response(e) | |
| def completion(): | |
| token = request.headers.get('Authorization').split()[1] | |
| objs = APIToken.query(token=token) | |
| if not objs: | |
| return get_json_result( | |
| data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) | |
| req = request.json | |
| e, conv = API4ConversationService.get_by_id(req["conversation_id"]) | |
| if not e: | |
| return get_data_error_result(retmsg="Conversation not found!") | |
| if "quote" not in req: req["quote"] = False | |
| msg = [] | |
| for m in req["messages"]: | |
| if m["role"] == "system": | |
| continue | |
| if m["role"] == "assistant" and not msg: | |
| continue | |
| msg.append({"role": m["role"], "content": m["content"]}) | |
| def fillin_conv(ans): | |
| nonlocal conv | |
| if not conv.reference: | |
| conv.reference.append(ans["reference"]) | |
| else: | |
| conv.reference[-1] = ans["reference"] | |
| conv.message[-1] = {"role": "assistant", "content": ans["answer"]} | |
| def rename_field(ans): | |
| reference = ans['reference'] | |
| if not isinstance(reference, dict): | |
| return | |
| for chunk_i in reference.get('chunks', []): | |
| if 'docnm_kwd' in chunk_i: | |
| chunk_i['doc_name'] = chunk_i['docnm_kwd'] | |
| chunk_i.pop('docnm_kwd') | |
| try: | |
| if conv.source == "agent": | |
| stream = req.get("stream", True) | |
| conv.message.append(msg[-1]) | |
| e, cvs = UserCanvasService.get_by_id(conv.dialog_id) | |
| if not e: | |
| return server_error_response("canvas not found.") | |
| del req["conversation_id"] | |
| del req["messages"] | |
| if not isinstance(cvs.dsl, str): | |
| cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) | |
| if not conv.reference: | |
| conv.reference = [] | |
| conv.message.append({"role": "assistant", "content": ""}) | |
| conv.reference.append({"chunks": [], "doc_aggs": []}) | |
| final_ans = {"reference": [], "content": ""} | |
| canvas = Canvas(cvs.dsl, objs[0].tenant_id) | |
| canvas.messages.append(msg[-1]) | |
| canvas.add_user_input(msg[-1]["content"]) | |
| answer = canvas.run(stream=stream) | |
| assert answer is not None, "Nothing. Is it over?" | |
| if stream: | |
| assert isinstance(answer, partial), "Nothing. Is it over?" | |
| def sse(): | |
| nonlocal answer, cvs, conv | |
| try: | |
| for ans in answer(): | |
| for k in ans.keys(): | |
| final_ans[k] = ans[k] | |
| ans = {"answer": ans["content"], "reference": ans.get("reference", [])} | |
| fillin_conv(ans) | |
| rename_field(ans) | |
| yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, | |
| ensure_ascii=False) + "\n\n" | |
| canvas.messages.append({"role": "assistant", "content": final_ans["content"]}) | |
| if final_ans.get("reference"): | |
| canvas.reference.append(final_ans["reference"]) | |
| cvs.dsl = json.loads(str(canvas)) | |
| API4ConversationService.append_message(conv.id, conv.to_dict()) | |
| except Exception as e: | |
| yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e), | |
| "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, | |
| ensure_ascii=False) + "\n\n" | |
| yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n" | |
| resp = Response(sse(), mimetype="text/event-stream") | |
| resp.headers.add_header("Cache-control", "no-cache") | |
| resp.headers.add_header("Connection", "keep-alive") | |
| resp.headers.add_header("X-Accel-Buffering", "no") | |
| resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") | |
| return resp | |
| final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else "" | |
| canvas.messages.append({"role": "assistant", "content": final_ans["content"]}) | |
| if final_ans.get("reference"): | |
| canvas.reference.append(final_ans["reference"]) | |
| cvs.dsl = json.loads(str(canvas)) | |
| result = None | |
| for ans in answer(): | |
| ans = {"answer": ans["content"], "reference": ans.get("reference", [])} | |
| result = ans | |
| fillin_conv(ans) | |
| API4ConversationService.append_message(conv.id, conv.to_dict()) | |
| break | |
| rename_field(result) | |
| return get_json_result(data=result) | |
| #******************For dialog****************** | |
| conv.message.append(msg[-1]) | |
| e, dia = DialogService.get_by_id(conv.dialog_id) | |
| if not e: | |
| return get_data_error_result(retmsg="Dialog not found!") | |
| del req["conversation_id"] | |
| del req["messages"] | |
| if not conv.reference: | |
| conv.reference = [] | |
| conv.message.append({"role": "assistant", "content": ""}) | |
| conv.reference.append({"chunks": [], "doc_aggs": []}) | |
| def stream(): | |
| nonlocal dia, msg, req, conv | |
| try: | |
| for ans in chat(dia, msg, True, **req): | |
| fillin_conv(ans) | |
| rename_field(ans) | |
| yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, | |
| ensure_ascii=False) + "\n\n" | |
| API4ConversationService.append_message(conv.id, conv.to_dict()) | |
| except Exception as e: | |
| yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e), | |
| "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, | |
| ensure_ascii=False) + "\n\n" | |
| yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n" | |
| if req.get("stream", True): | |
| resp = Response(stream(), mimetype="text/event-stream") | |
| resp.headers.add_header("Cache-control", "no-cache") | |
| resp.headers.add_header("Connection", "keep-alive") | |
| resp.headers.add_header("X-Accel-Buffering", "no") | |
| resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") | |
| return resp | |
| answer = None | |
| for ans in chat(dia, msg, **req): | |
| answer = ans | |
| fillin_conv(ans) | |
| API4ConversationService.append_message(conv.id, conv.to_dict()) | |
| break | |
| rename_field(answer) | |
| return get_json_result(data=answer) | |
| except Exception as e: | |
| return server_error_response(e) | |
| # @login_required | |
| def get(conversation_id): | |
| try: | |
| e, conv = API4ConversationService.get_by_id(conversation_id) | |
| if not e: | |
| return get_data_error_result(retmsg="Conversation not found!") | |
| conv = conv.to_dict() | |
| for referenct_i in conv['reference']: | |
| if referenct_i is None or len(referenct_i) == 0: | |
| continue | |
| for chunk_i in referenct_i['chunks']: | |
| if 'docnm_kwd' in chunk_i.keys(): | |
| chunk_i['doc_name'] = chunk_i['docnm_kwd'] | |
| chunk_i.pop('docnm_kwd') | |
| return get_json_result(data=conv) | |
| except Exception as e: | |
| return server_error_response(e) | |
| def upload(): | |
| token = request.headers.get('Authorization').split()[1] | |
| objs = APIToken.query(token=token) | |
| if not objs: | |
| return get_json_result( | |
| data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) | |
| kb_name = request.form.get("kb_name").strip() | |
| tenant_id = objs[0].tenant_id | |
| try: | |
| e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id) | |
| if not e: | |
| return get_data_error_result( | |
| retmsg="Can't find this knowledgebase!") | |
| kb_id = kb.id | |
| except Exception as e: | |
| return server_error_response(e) | |
| if 'file' not in request.files: | |
| return get_json_result( | |
| data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR) | |
| file = request.files['file'] | |
| if file.filename == '': | |
| return get_json_result( | |
| data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR) | |
| root_folder = FileService.get_root_folder(tenant_id) | |
| pf_id = root_folder["id"] | |
| FileService.init_knowledgebase_docs(pf_id, tenant_id) | |
| kb_root_folder = FileService.get_kb_folder(tenant_id) | |
| kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"]) | |
| try: | |
| if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)): | |
| return get_data_error_result( | |
| retmsg="Exceed the maximum file number of a free user!") | |
| filename = duplicate_name( | |
| DocumentService.query, | |
| name=file.filename, | |
| kb_id=kb_id) | |
| filetype = filename_type(filename) | |
| if not filetype: | |
| return get_data_error_result( | |
| retmsg="This type of file has not been supported yet!") | |
| location = filename | |
| while MINIO.obj_exist(kb_id, location): | |
| location += "_" | |
| blob = request.files['file'].read() | |
| MINIO.put(kb_id, location, blob) | |
| doc = { | |
| "id": get_uuid(), | |
| "kb_id": kb.id, | |
| "parser_id": kb.parser_id, | |
| "parser_config": kb.parser_config, | |
| "created_by": kb.tenant_id, | |
| "type": filetype, | |
| "name": filename, | |
| "location": location, | |
| "size": len(blob), | |
| "thumbnail": thumbnail(filename, blob) | |
| } | |
| form_data = request.form | |
| if "parser_id" in form_data.keys(): | |
| if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]: | |
| doc["parser_id"] = request.form.get("parser_id").strip() | |
| if doc["type"] == FileType.VISUAL: | |
| doc["parser_id"] = ParserType.PICTURE.value | |
| if doc["type"] == FileType.AURAL: | |
| doc["parser_id"] = ParserType.AUDIO.value | |
| if re.search(r"\.(ppt|pptx|pages)$", filename): | |
| doc["parser_id"] = ParserType.PRESENTATION.value | |
| doc_result = DocumentService.insert(doc) | |
| FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id) | |
| except Exception as e: | |
| return server_error_response(e) | |
| if "run" in form_data.keys(): | |
| if request.form.get("run").strip() == "1": | |
| try: | |
| info = {"run": 1, "progress": 0} | |
| info["progress_msg"] = "" | |
| info["chunk_num"] = 0 | |
| info["token_num"] = 0 | |
| DocumentService.update_by_id(doc["id"], info) | |
| # if str(req["run"]) == TaskStatus.CANCEL.value: | |
| tenant_id = DocumentService.get_tenant_id(doc["id"]) | |
| if not tenant_id: | |
| return get_data_error_result(retmsg="Tenant not found!") | |
| # e, doc = DocumentService.get_by_id(doc["id"]) | |
| TaskService.filter_delete([Task.doc_id == doc["id"]]) | |
| e, doc = DocumentService.get_by_id(doc["id"]) | |
| doc = doc.to_dict() | |
| doc["tenant_id"] = tenant_id | |
| bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"]) | |
| queue_tasks(doc, bucket, name) | |
| except Exception as e: | |
| return server_error_response(e) | |
| return get_json_result(data=doc_result.to_json()) | |
| # @login_required | |
| def list_chunks(): | |
| token = request.headers.get('Authorization').split()[1] | |
| objs = APIToken.query(token=token) | |
| if not objs: | |
| return get_json_result( | |
| data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) | |
| req = request.json | |
| try: | |
| if "doc_name" in req.keys(): | |
| tenant_id = DocumentService.get_tenant_id_by_name(req['doc_name']) | |
| doc_id = DocumentService.get_doc_id_by_doc_name(req['doc_name']) | |
| elif "doc_id" in req.keys(): | |
| tenant_id = DocumentService.get_tenant_id(req['doc_id']) | |
| doc_id = req['doc_id'] | |
| else: | |
| return get_json_result( | |
| data=False, retmsg="Can't find doc_name or doc_id" | |
| ) | |
| res = retrievaler.chunk_list(doc_id=doc_id, tenant_id=tenant_id) | |
| res = [ | |
| { | |
| "content": res_item["content_with_weight"], | |
| "doc_name": res_item["docnm_kwd"], | |
| "img_id": res_item["img_id"] | |
| } for res_item in res | |
| ] | |
| except Exception as e: | |
| return server_error_response(e) | |
| return get_json_result(data=res) | |
| # @login_required | |
| def list_kb_docs(): | |
| token = request.headers.get('Authorization').split()[1] | |
| objs = APIToken.query(token=token) | |
| if not objs: | |
| return get_json_result( | |
| data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) | |
| req = request.json | |
| tenant_id = objs[0].tenant_id | |
| kb_name = req.get("kb_name", "").strip() | |
| try: | |
| e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id) | |
| if not e: | |
| return get_data_error_result( | |
| retmsg="Can't find this knowledgebase!") | |
| kb_id = kb.id | |
| except Exception as e: | |
| return server_error_response(e) | |
| page_number = int(req.get("page", 1)) | |
| items_per_page = int(req.get("page_size", 15)) | |
| orderby = req.get("orderby", "create_time") | |
| desc = req.get("desc", True) | |
| keywords = req.get("keywords", "") | |
| try: | |
| docs, tol = DocumentService.get_by_kb_id( | |
| kb_id, page_number, items_per_page, orderby, desc, keywords) | |
| docs = [{"doc_id": doc['id'], "doc_name": doc['name']} for doc in docs] | |
| return get_json_result(data={"total": tol, "docs": docs}) | |
| except Exception as e: | |
| return server_error_response(e) | |
| # @login_required | |
| def document_rm(): | |
| token = request.headers.get('Authorization').split()[1] | |
| objs = APIToken.query(token=token) | |
| if not objs: | |
| return get_json_result( | |
| data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) | |
| tenant_id = objs[0].tenant_id | |
| req = request.json | |
| doc_ids = [] | |
| try: | |
| doc_ids = [DocumentService.get_doc_id_by_doc_name(doc_name) for doc_name in req.get("doc_names", [])] | |
| for doc_id in req.get("doc_ids", []): | |
| if doc_id not in doc_ids: | |
| doc_ids.append(doc_id) | |
| if not doc_ids: | |
| return get_json_result( | |
| data=False, retmsg="Can't find doc_names or doc_ids" | |
| ) | |
| except Exception as e: | |
| return server_error_response(e) | |
| root_folder = FileService.get_root_folder(tenant_id) | |
| pf_id = root_folder["id"] | |
| FileService.init_knowledgebase_docs(pf_id, tenant_id) | |
| errors = "" | |
| for doc_id in doc_ids: | |
| try: | |
| e, doc = DocumentService.get_by_id(doc_id) | |
| if not e: | |
| return get_data_error_result(retmsg="Document not found!") | |
| tenant_id = DocumentService.get_tenant_id(doc_id) | |
| if not tenant_id: | |
| return get_data_error_result(retmsg="Tenant not found!") | |
| b, n = File2DocumentService.get_minio_address(doc_id=doc_id) | |
| if not DocumentService.remove_document(doc, tenant_id): | |
| return get_data_error_result( | |
| retmsg="Database error (Document removal)!") | |
| f2d = File2DocumentService.get_by_document_id(doc_id) | |
| FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id]) | |
| File2DocumentService.delete_by_document_id(doc_id) | |
| MINIO.rm(b, n) | |
| except Exception as e: | |
| errors += str(e) | |
| if errors: | |
| return get_json_result(data=False, retmsg=errors, retcode=RetCode.SERVER_ERROR) | |
| return get_json_result(data=True) | |
| def completion_faq(): | |
| import base64 | |
| req = request.json | |
| token = req["Authorization"] | |
| objs = APIToken.query(token=token) | |
| if not objs: | |
| return get_json_result( | |
| data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) | |
| e, conv = API4ConversationService.get_by_id(req["conversation_id"]) | |
| if not e: | |
| return get_data_error_result(retmsg="Conversation not found!") | |
| if "quote" not in req: req["quote"] = True | |
| msg = [] | |
| msg.append({"role": "user", "content": req["word"]}) | |
| try: | |
| conv.message.append(msg[-1]) | |
| e, dia = DialogService.get_by_id(conv.dialog_id) | |
| if not e: | |
| return get_data_error_result(retmsg="Dialog not found!") | |
| del req["conversation_id"] | |
| if not conv.reference: | |
| conv.reference = [] | |
| conv.message.append({"role": "assistant", "content": ""}) | |
| conv.reference.append({"chunks": [], "doc_aggs": []}) | |
| def fillin_conv(ans): | |
| nonlocal conv | |
| if not conv.reference: | |
| conv.reference.append(ans["reference"]) | |
| else: | |
| conv.reference[-1] = ans["reference"] | |
| conv.message[-1] = {"role": "assistant", "content": ans["answer"]} | |
| data_type_picture = { | |
| "type": 3, | |
| "url": "base64 content" | |
| } | |
| data = [ | |
| { | |
| "type": 1, | |
| "content": "" | |
| } | |
| ] | |
| ans = "" | |
| for a in chat(dia, msg, stream=False, **req): | |
| ans = a | |
| break | |
| data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"]) | |
| fillin_conv(ans) | |
| API4ConversationService.append_message(conv.id, conv.to_dict()) | |
| chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])] | |
| for chunk_idx in chunk_idxs[:1]: | |
| if ans["reference"]["chunks"][chunk_idx]["img_id"]: | |
| try: | |
| bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-") | |
| response = MINIO.get(bkt, nm) | |
| data_type_picture["url"] = base64.b64encode(response).decode('utf-8') | |
| data.append(data_type_picture) | |
| break | |
| except Exception as e: | |
| return server_error_response(e) | |
| response = {"code": 200, "msg": "success", "data": data} | |
| return response | |
| except Exception as e: | |
| return server_error_response(e) | |
| def retrieval(): | |
| token = request.headers.get('Authorization').split()[1] | |
| objs = APIToken.query(token=token) | |
| if not objs: | |
| return get_json_result( | |
| data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) | |
| req = request.json | |
| kb_id = req.get("kb_id") | |
| doc_ids = req.get("doc_ids", []) | |
| question = req.get("question") | |
| page = int(req.get("page", 1)) | |
| size = int(req.get("size", 30)) | |
| similarity_threshold = float(req.get("similarity_threshold", 0.2)) | |
| vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) | |
| top = int(req.get("top_k", 1024)) | |
| try: | |
| e, kb = KnowledgebaseService.get_by_id(kb_id) | |
| if not e: | |
| return get_data_error_result(retmsg="Knowledgebase not found!") | |
| embd_mdl = TenantLLMService.model_instance( | |
| kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) | |
| rerank_mdl = None | |
| if req.get("rerank_id"): | |
| rerank_mdl = TenantLLMService.model_instance( | |
| kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]) | |
| if req.get("keyword", False): | |
| chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT) | |
| question += keyword_extraction(chat_mdl, question) | |
| ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, | |
| similarity_threshold, vector_similarity_weight, top, | |
| doc_ids, rerank_mdl=rerank_mdl) | |
| for c in ranks["chunks"]: | |
| if "vector" in c: | |
| del c["vector"] | |
| return get_json_result(data=ranks) | |
| except Exception as e: | |
| if str(e).find("not_found") > 0: | |
| return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!', | |
| retcode=RetCode.DATA_ERROR) | |
| return server_error_response(e) | |