| | |
| |
|
| | from fastapi import FastAPI, Request, HTTPException, Response |
| | from fastapi.middleware.cors import CORSMiddleware |
| | import re, json |
| | import httpx |
| | import uuid |
| |
|
# ASGI application object; every route below is registered on it.
app = FastAPI()

# Fully permissive CORS policy for the proxy routes.
# NOTE(review): wildcard origins are combined with allow_credentials=True here —
# check the CORS middleware documentation for how credentialed requests are
# answered in that configuration and confirm this is intended for a key-forwarding proxy.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
| |
|
@app.get("/")
def greet_json():
    """Root endpoint: respond with the fixed JSON object {"Hello": "World!"}."""
    greeting = dict(Hello="World!")
    return greeting
| |
|
@app.get("/v1/{dest_url:path}")
async def gre_dest_url(dest_url: str):
    """Echo the captured path back to the caller.

    Shares the route pattern of the POST proxy, but a GET forwards nothing:
    it only returns the ``dest_url`` path segment it received.
    """
    echoed = dest_url
    return echoed
| |
|
# Value sent upstream in the x-goog-api-client header by proxy_url.
API_CLIENT: str = "genai-js/0.21.0"
# Model used for chat/completions requests whose JSON body names no model.
DEFAULT_MODEL: str = "gemini-1.5-pro-latest"
| |
|
async def transform_request(req: dict):
    """Hook for rewriting an outgoing request payload.

    Currently an identity transform: the very same object is returned.
    NOTE(review): nothing in the code shown here calls this hook, so chat
    requests are forwarded upstream in their original shape — confirm intended.
    """
    unchanged = req
    return unchanged
| |
|
# Gemini finishReason values translated to the OpenAI chat.completion vocabulary;
# values not listed here are passed through unchanged.
_FINISH_REASON_MAP = {
    "STOP": "stop",
    "MAX_TOKENS": "length",
    "SAFETY": "content_filter",
    "RECITATION": "content_filter",
    "BLOCKLIST": "content_filter",
    "PROHIBITED_CONTENT": "content_filter",
    "SPII": "content_filter",
}


def _candidate_text(content) -> str:
    """Flatten a candidate's ``content`` into the plain string OpenAI clients expect.

    A string is returned unchanged; a ``{"parts": [{"text": ...}, ...]}`` object
    has its text parts concatenated in order; anything else becomes "".
    """
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        parts = content.get("parts") or []
        return "".join(
            part["text"]
            for part in parts
            if isinstance(part, dict) and isinstance(part.get("text"), str)
        )
    return ""


async def process_completions_response(data: dict, model: str, id: str):
    """Reshape a Gemini generateContent reply into an OpenAI chat.completion JSON string.

    Args:
        data: decoded upstream JSON; ``candidates`` and ``usageMetadata`` are read.
        model: model name echoed in the ``model`` field.
        id: completion id echoed in the ``id`` field (parameter name kept for callers).

    Returns:
        The chat.completion object serialized with ``ensure_ascii=False``.
    """
    import time  # local import keeps this block independent of the module header

    if not isinstance(data, dict):
        data = {}

    choices = []
    for index, candidate in enumerate(data.get("candidates") or []):
        reason = candidate.get("finishReason") or "stop"
        choices.append({
            "finish_reason": _FINISH_REASON_MAP.get(reason, reason),
            "index": index,
            "message": {
                # Gemini returns a parts object; OpenAI clients expect a string.
                "content": _candidate_text(candidate.get("content")),
                "role": "assistant",
            },
        })

    usage = {}
    metadata = data.get("usageMetadata")
    if isinstance(metadata, dict):
        # candidatesTokenCount/promptTokenCount/totalTokenCount are the generateContent
        # fields; a bare tokenCount is still honoured as the completion count.
        completion_tokens = metadata.get("candidatesTokenCount", metadata.get("tokenCount", 0))
        prompt_tokens = metadata.get("promptTokenCount", 0)
        usage = {
            "completion_tokens": completion_tokens,
            "prompt_tokens": prompt_tokens,
            "total_tokens": metadata.get("totalTokenCount", prompt_tokens + completion_tokens),
        }

    response_data = {
        "id": id,
        "choices": choices,
        # Real creation time instead of a fixed timestamp.
        "created": int(time.time()),
        "model": model,
        "object": "chat.completion",
        "usage": usage,
    }
    return json.dumps(response_data, ensure_ascii=False)
| |
|
# Upper bound, in seconds, for one upstream call. httpx defaults to 5 s, which
# long model generations routinely exceed; tune per deployment.
_UPSTREAM_TIMEOUT_SECONDS = 120.0


def _upstream_headers(request_headers) -> dict:
    """Build the outgoing header dict from the client's request headers.

    Drops content-length/host (httpx recomputes them), copies the token of a
    ``<scheme> <token>`` Authorization header into ``x-goog-api-key``, tags the
    call with ``x-goog-api-client`` and forces a single JSON content type.
    """
    headers = dict(request_headers)
    headers.pop("content-length", None)
    headers.pop("host", None)

    # ASGI delivers header names lower-cased, so a lookup of "Authorization"
    # never matched and the key was never extracted.
    auth = headers.get("authorization")
    api_key = None
    if auth:
        parts = auth.split(None, 1)
        # Only "<scheme> <token>" yields a key; a bare value used to raise IndexError.
        if len(parts) == 2:
            api_key = parts[1].strip()
    # NOTE(review): the client's Authorization header is still forwarded alongside
    # x-goog-api-key — confirm the upstream accepts both.

    headers["x-goog-api-client"] = API_CLIENT
    if api_key:
        headers["x-goog-api-key"] = api_key
    # Replace the client's lower-case content-type instead of sending two of them.
    headers.pop("content-type", None)
    headers["Content-Type"] = "application/json"
    return headers


def _chat_model_from_body(body: bytes) -> str:
    """Return the model named in a chat/completions JSON body, minus any "models/" prefix.

    Falls back to DEFAULT_MODEL when the body names no string model.
    Raises HTTPException(400) when the body is not valid UTF-8 JSON.
    """
    try:
        req_body = json.loads(body.decode("utf-8"))
    except ValueError as e:  # covers UnicodeDecodeError and JSONDecodeError
        raise HTTPException(status_code=400, detail=f"Invalid JSON request body: {e}") from e
    model = DEFAULT_MODEL
    if isinstance(req_body, dict) and isinstance(req_body.get("model"), str):
        model = req_body["model"]
    if model.startswith("models/"):
        model = model[len("models/"):]
    return model


async def _forward_upstream(url: str, body: bytes, headers: dict, model=None):
    """POST *body* to *url* and turn the upstream reply into the client response.

    With *model* set, a 200 JSON reply is reshaped by process_completions_response;
    otherwise it is passed through with a generated ``id``. A 200 reply that is not
    JSON yields an error dict. Any other status is relayed as a JSON body carrying
    the upstream status code. Raises HTTPException(500) if the request itself fails.
    """
    async with httpx.AsyncClient(timeout=_UPSTREAM_TIMEOUT_SECONDS) as client:
        try:
            response = await client.post(url, content=body, headers=headers)
        except httpx.RequestError as e:
            print(f"Request error: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    completion_id = f"chatcmpl-{uuid.uuid4()}"
    if response.status_code == 200:
        if "application/json" not in response.headers.get("content-type", ""):
            return {"error": "Response is not in JSON format", "id": completion_id}
        json_response = response.json()
        if model is not None:
            content = await process_completions_response(json_response, model, completion_id)
        else:
            if isinstance(json_response, dict):
                json_response["id"] = completion_id
            content = json.dumps(json_response, ensure_ascii=False)
        resp = Response(content=content, media_type="application/json")
    else:
        try:
            error_data = response.json()
        except ValueError:
            error_data = None
        if isinstance(error_data, dict):
            error_data["id"] = completion_id
        else:
            error_data = {"status_code": response.status_code, "detail": response.text, "id": completion_id}
        print(f"Error response: {error_data}")
        # Relay the upstream status instead of reporting errors as 200 OK.
        resp = Response(
            content=json.dumps(error_data, ensure_ascii=False),
            media_type="application/json",
            status_code=response.status_code,
        )
    resp.headers["Access-Control-Allow-Origin"] = "*"
    return resp


@app.post("/v1/{dest_url:path}")
async def proxy_url(dest_url: str, request: Request):
    """Forward a POST to the upstream URL encoded in the request path.

    The first "/" of *dest_url* becomes "://", so ``/v1/https/host/x`` targets
    ``https://host/x``. Paths ending in ``/chat/completions`` are sent to
    ``{prefix}/{model}:generateContent`` and the reply is converted with
    process_completions_response; every other path is passed through.
    """
    body = await request.body()
    headers = _upstream_headers(request.headers)
    # "https/host/path" -> "https://host/path"
    dest_url = dest_url.replace("/", "://", 1)

    if dest_url.endswith("/chat/completions"):
        model = _chat_model_from_body(body)
        task = "generateContent"
        # NOTE(review): only the trailing "completions" segment is replaced, so the
        # "/chat" segment stays in the URL — confirm this matches the upstream route.
        url = f"{dest_url.rsplit('/', 1)[0]}/{model}:{task}"
        # NOTE(review): the body is forwarded as received; transform_request is not applied.
        return await _forward_upstream(url, body, headers, model=model)

    return await _forward_upstream(dest_url, body, headers)
| | |
| |
|
| | from fastapi import FastAPI, Request, HTTPException, Response |
| | import re, json |
| | import httpx |
| | import uuid |
| |
|
# Application object for the routes that follow.
# NOTE(review): no CORSMiddleware is registered on this app in the code shown; the
# POST handler sets Access-Control-Allow-Origin on its Response objects instead — confirm intended.
app: FastAPI = FastAPI()
| |
|
@app.get("/")
def greet_json():
    """Root endpoint returning the constant JSON object {"Hello": "World!"}."""
    payload = {}
    payload["Hello"] = "World!"
    return payload
| |
|
@app.get("/v1/{dest_url:path}")
async def gre_dest_url(dest_url: str):
    """Return the captured ``dest_url`` path unchanged; a GET here forwards nothing."""
    captured_path = dest_url
    return captured_path
| |
|
@app.post("/v1/{dest_url:path}")
async def proxy_url(dest_url: str, request: Request):
    """Forward a POST verbatim to the URL encoded in the path and relay the reply.

    The first "/" of *dest_url* becomes "://", so ``/v1/https/host/x`` is sent to
    ``https://host/x``. Client headers are forwarded except content-length and
    host, and the User-Agent is overridden. A 200 JSON reply is returned with an
    added ``id``; a 200 non-JSON reply yields an error dict; any other status is
    relayed as JSON with the upstream status code.

    Raises:
        HTTPException: status 500 when the upstream request cannot be made.
    """
    body = await request.body()
    headers = dict(request.headers)
    # httpx recomputes these for the outgoing request.
    headers.pop("content-length", None)
    headers.pop("host", None)
    # Header names arrive lower-cased; drop the client's user-agent so the
    # override below is the only one sent rather than a duplicate.
    headers.pop("user-agent", None)
    headers["User-Agent"] = "PostmanRuntime/7.43.0"

    # "https/host/path" -> "https://host/path"
    dest_url = dest_url.replace("/", "://", 1)

    # Log the outgoing headers without leaking credentials or API keys.
    sensitive = {"authorization", "proxy-authorization", "cookie", "x-api-key", "api-key", "x-goog-api-key"}
    safe_headers = {
        name: ("<redacted>" if name.lower() in sensitive else value)
        for name, value in headers.items()
    }
    print(f"Request Headers: {safe_headers}")

    # httpx defaults to a 5 s timeout, which slow upstream responses easily exceed.
    async with httpx.AsyncClient(timeout=120.0) as client:
        try:
            response = await client.post(dest_url, content=body, headers=headers)
        except httpx.RequestError as e:
            raise HTTPException(status_code=500, detail=str(e)) from e

    response_id = f"chatcmpl-{uuid.uuid4()}"
    if response.status_code == 200:
        if "application/json" not in response.headers.get("content-type", ""):
            return {"error": "Response is not in JSON format", "id": response_id}
        json_response = response.json()
        if isinstance(json_response, dict):
            json_response["id"] = response_id
        resp = Response(content=json.dumps(json_response), media_type="application/json")
    else:
        try:
            error_data = response.json()
        except ValueError:
            error_data = None
        if isinstance(error_data, dict):
            error_data["id"] = response_id
        else:
            error_data = {"status_code": response.status_code, "detail": response.text, "id": response_id}
        # json.dumps (not str()) so the body is valid JSON; keep the upstream status.
        resp = Response(
            content=json.dumps(error_data),
            media_type="application/json",
            status_code=response.status_code,
        )
    resp.headers["Access-Control-Allow-Origin"] = "*"
    return resp
| |
|