"""
Proxy handler for Z.AI API requests, updated with simplified signature logic.
"""
import base64
import hashlib
import hmac
import json
import logging
import re
import time
import uuid
from typing import AsyncGenerator, Dict, Any, Tuple, List

import httpx
from fastapi import HTTPException
from fastapi.responses import StreamingResponse

from config import settings
from cookie_manager import cookie_manager
from models import ChatCompletionRequest, ChatCompletionResponse

logger = logging.getLogger(__name__)


class ProxyHandler:
    def __init__(self):
        self.client = httpx.AsyncClient(
            timeout=httpx.Timeout(60.0, read=300.0),
            limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
            http2=True,
        )
        # The primary secret key from the reference code.
        self.primary_secret = "junjie".encode('utf-8')

    async def aclose(self):
        if not self.client.is_closed:
            await self.client.aclose()
            
    def _get_timestamp_millis(self) -> int:
        return int(time.time() * 1000)

    def _parse_jwt_token(self, token: str) -> Dict[str, str]:
        """A simple JWT payload decoder to get user ID ('sub' claim)."""
        try:
            parts = token.split('.')
            if len(parts) != 3: return {"user_id": ""}
            payload_b64 = parts[1]
            payload_b64 += '=' * (-len(payload_b64) % 4) # Add padding if needed
            payload_json = base64.urlsafe_b64decode(payload_b64).decode('utf-8')
            payload = json.loads(payload_json)
            return {"user_id": payload.get("sub", "")}
        except Exception:
            # It's okay if this fails; we'll proceed with an empty user_id.
            return {"user_id": ""}

    def _generate_signature(self, e_payload: str, t_payload: str) -> Dict[str, Any]:
        """
        Generates the signature based on the logic from the reference JS code (work.js).
        This is a two-level HMAC-SHA256 process with Base64 encoding for the content.

        Args:
            e_payload (str): The simplified payload string (e.g., "requestId,...,timestamp,...").
            t_payload (str): The last message content.

        Returns:
            A dictionary with 'signature' and 'timestamp'.
        """
        timestamp_ms = self._get_timestamp_millis()

        # As per work.js, the last message content (t_payload) must be Base64
        # encoded before it is concatenated into the string to be signed.
        content_base64 = base64.b64encode(t_payload.encode('utf-8')).decode('utf-8')
        message_string = f"{e_payload}|{content_base64}|{timestamp_ms}"
        
        # Per the Python snippet and JS reference: n is a 5-minute bucket
        n = timestamp_ms // (5 * 60 * 1000) 
        
        # Intermediate key derivation
        msg1 = str(n).encode("utf-8")
        intermediate_key = hmac.new(self.primary_secret, msg1, hashlib.sha256).hexdigest()

        # Final signature
        msg2 = message_string.encode("utf-8")
        final_signature = hmac.new(intermediate_key.encode("utf-8"), msg2, hashlib.sha256).hexdigest()

        return {"signature": final_signature, "timestamp": timestamp_ms}

    def _clean_thinking_content(self, text: str) -> str:
        if not text: return ""
        cleaned_text = re.sub(r'<summary>.*?</summary>|<glm_block.*?</glm_block>|<[^>]*duration="[^"]*"[^>]*>', '', text, flags=re.DOTALL)
        cleaned_text = cleaned_text.replace("</thinking>", "").replace("<Full>", "").replace("</Full>", "")
        cleaned_text = re.sub(r'</?details[^>]*>', '', cleaned_text)
        cleaned_text = re.sub(r'^\s*>\s*(?!>)', '', cleaned_text, flags=re.MULTILINE)
        cleaned_text = cleaned_text.replace("Thinking…", "")
        return cleaned_text.strip()
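
    # Example (hypothetical stream fragment):
    #   '<summary>Thought for 2s</summary>> first step\nThinking…'
    # cleans to 'first step': summary/glm_block tags, leading '>' quote
    # markers, and the literal 'Thinking…' marker are all stripped.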

    def _clean_answer_content(self, text: str) -> str:
        if not text: return ""
        cleaned_text = re.sub(r'<glm_block.*?</glm_block>|<details[^>]*>.*?</details>|<summary>.*?</summary>', '', text, flags=re.DOTALL)
        return cleaned_text

    def _serialize_msgs(self, msgs) -> list:
        out = []
        for m in msgs:
            # Prefer Pydantic v2's model_dump(); fall back to v1's dict(), then plain dicts.
            if hasattr(m, "model_dump"): out.append(m.model_dump())
            elif hasattr(m, "dict"): out.append(m.dict())
            elif isinstance(m, dict): out.append(m)
            else: out.append({"role": getattr(m, "role", "user"), "content": getattr(m, "content", str(m))})
        return out
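
    # e.g. _serialize_msgs([Msg(role="user", content="hi"), {"role": "system", "content": "be brief"}])
    # -> [{"role": "user", "content": "hi"}, {"role": "system", "content": "be brief"}]
    # (Msg stands in for whatever Pydantic message model `models` defines.)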

    async def _prep_upstream(self, req: ChatCompletionRequest) -> Tuple[Dict[str, Any], Dict[str, str], str, str]:
        """Prepares the request body, headers, cookie, and URL for the upstream API."""
        ck = await cookie_manager.get_next_cookie()
        if not ck: raise HTTPException(503, "No available cookies")

        model = settings.UPSTREAM_MODEL if req.model == settings.MODEL_NAME else req.model
        chat_id = str(uuid.uuid4())
        request_id = str(uuid.uuid4())
        
        # --- Simplified signature payload logic ---
        # The JWT 'sub' claim is decoded for completeness, but the reference
        # code sends a fresh UUID as the payload user_id instead of the real
        # one, so we replicate that behaviour exactly.
        user_info = self._parse_jwt_token(ck)
        user_id = user_info.get("user_id", "")  # currently unused; kept for parity with the reference
        payload_user_id = str(uuid.uuid4())
        payload_request_id = str(uuid.uuid4())
        payload_timestamp = str(self._get_timestamp_millis())
        
        # e: The simplified payload for the signature
        e_payload = f"requestId,{payload_request_id},timestamp,{payload_timestamp},user_id,{payload_user_id}"

        # t: The last message content
        t_payload = ""
        if req.messages:
            last_message = req.messages[-1]
            if isinstance(last_message.content, str):
                t_payload = last_message.content

        # Generate the signature
        signature_data = self._generate_signature(e_payload, t_payload)
        signature = signature_data["signature"]
        signature_timestamp = signature_data["timestamp"]
        
        # The reference code sends these as URL parameters, not in the body.
        url_params = {
            "requestId": payload_request_id,
            "timestamp": payload_timestamp,
            "user_id": payload_user_id,
            "signature_timestamp": str(signature_timestamp)
        }
        
        # Construct the final URL with the query parameters above.
        # Note: the reference code builds this as f"{BASE_URL}/api/chat/completions"
        # with a mistyped BASE_URL; it should point at z.ai, which is what
        # settings.UPSTREAM_URL is expected to hold here.
        final_url = httpx.URL(settings.UPSTREAM_URL).copy_with(params=url_params)

        body = {
            "stream": True,
            "model": model,
            "messages": self._serialize_msgs(req.messages),
            "chat_id": chat_id,
            "id": request_id,
            "features": {
                "image_generation": False,
                "web_search": False,
                "auto_web_search": False,
                "preview_mode": False,
                "flags": [],
                "enable_thinking": True,
            }
        }
        
        headers = {
            "Accept": "*/*",
            "Accept-Language": "zh-CN",
            "Authorization": f"Bearer {ck}",
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Content-Type": "application/json",
            "Origin": "https://chat.z.ai",
            "Referer": "https://chat.z.ai/",
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/140.0.0.0 Safari/537.36",
            "X-FE-Version": "prod-fe-1.0.103",
            "X-Signature": signature,
        }
        
        return body, headers, ck, str(final_url)
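
    # The assembled request looks like (illustrative; host comes from settings.UPSTREAM_URL):
    #   POST <UPSTREAM_URL>?requestId=<uuid>&timestamp=<ms>&user_id=<uuid>&signature_timestamp=<ms>
    #   Authorization: Bearer <cookie>
    #   X-Signature: <hex HMAC from _generate_signature>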
        
    async def stream_proxy_response(self, req: ChatCompletionRequest) -> AsyncGenerator[str, None]:
        ck = None
        try:
            body, headers, ck, url = await self._prep_upstream(req)
            comp_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"
            think_open = False
            yielded_think_buffer = "" 
            current_raw_thinking = ""
            is_first_answer_chunk = True

            async def yield_delta(content_type: str, text: str):
                nonlocal think_open, yielded_think_buffer
                if content_type == "thinking" and settings.SHOW_THINK_TAGS:
                    if not think_open:
                        yield f"data: {json.dumps({'id': comp_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model, 'choices': [{'index': 0, 'delta': {'content': '<think>'}, 'finish_reason': None}]})}\n\n"
                        think_open = True
                    cleaned_full_text = self._clean_thinking_content(text)
                    delta_to_send = cleaned_full_text[len(yielded_think_buffer):]
                    if delta_to_send:
                        yield f"data: {json.dumps({'id': comp_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model, 'choices': [{'index': 0, 'delta': {'content': delta_to_send}, 'finish_reason': None}]})}\n\n"
                    yielded_think_buffer = cleaned_full_text
                elif content_type == "answer":
                    if think_open:
                        yield f"data: {json.dumps({'id': comp_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model, 'choices': [{'index': 0, 'delta': {'content': '</think>'}, 'finish_reason': None}]})}\n\n"
                        think_open = False
                    cleaned_text = self._clean_answer_content(text)
                    if cleaned_text:
                        yield f"data: {json.dumps({'id': comp_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model, 'choices': [{'index': 0, 'delta': {'content': cleaned_text}, 'finish_reason': None}]})}\n\n"

            async with self.client.stream("POST", url, json=body, headers=headers) as resp:
                if resp.status_code != 200:
                    await cookie_manager.mark_cookie_failed(ck)
                    err_body = await resp.aread()
                    err_msg = f"Error: {resp.status_code} - {err_body.decode(errors='ignore')}"
                    logger.error(f"Upstream error: {err_msg}")
                    err = {"id": comp_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": req.model, "choices": [{"index": 0, "delta": {"content": err_msg}, "finish_reason": "stop"}]}
                    yield f"data: {json.dumps(err)}\n\n"
                    yield "data: [DONE]\n\n"
                    return
                await cookie_manager.mark_cookie_success(ck)
                
                async for raw in resp.aiter_text():
                    for line in raw.strip().split('\n'):
                        line = line.strip()
                        if not line.startswith('data: '): continue
                        payload_str = line[6:]
                        # The reference code has a special 'done' phase, but the original Z.AI uses [DONE]
                        if payload_str == '[DONE]':
                            if think_open: 
                                yield f"data: {json.dumps({'id': comp_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model, 'choices': [{'index': 0, 'delta': {'content': '</think>'}, 'finish_reason': None}]})}\n\n"
                            yield f"data: {json.dumps({'id': comp_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': req.model, 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n"; 
                            yield "data: [DONE]\n\n"; 
                            return
                        try:
                            dat = json.loads(payload_str).get("data", {})
                        except (json.JSONDecodeError, AttributeError): continue
                        
                        phase = dat.get("phase")
                        content_chunk = dat.get("delta_content") or dat.get("edit_content")
                        if not content_chunk:
                            # Handle case where chunk is just usage info, etc.
                            if phase == 'other' and dat.get('usage'):
                                pass # In streaming, usage might come with the final chunk
                            else:
                                continue

                        if phase == "thinking":
                            current_raw_thinking = content_chunk if dat.get("edit_content") is not None else current_raw_thinking + content_chunk
                            async for item in yield_delta("thinking", current_raw_thinking):
                                yield item
                        elif phase == "answer":
                            content_to_process = content_chunk
                            if is_first_answer_chunk:
                                if '</details>' in content_to_process:
                                    parts = content_to_process.split('</details>', 1)
                                    content_to_process = parts[1] if len(parts) > 1 else ""
                                is_first_answer_chunk = False
                            if content_to_process:
                                async for item in yield_delta("answer", content_to_process):
                                    yield item
        except Exception:
            logger.exception("Stream error"); raise

    async def non_stream_proxy_response(self, req: ChatCompletionRequest) -> ChatCompletionResponse:
        """Non-streaming variant: same upstream request with stream=False, adapted to OpenAI's response shape."""
        ck = None
        try:
            body, headers, ck, url = await self._prep_upstream(req)
            # For non-stream, set stream to False in the body
            body["stream"] = False
            
            # httpx's AsyncClient.post() is simply awaited; it is not an async
            # context manager (that pattern only applies to .stream()).
            resp = await self.client.post(url, json=body, headers=headers)
            if resp.status_code != 200:
                await cookie_manager.mark_cookie_failed(ck)
                error_detail = resp.text  # httpx exposes .text as a property, not a coroutine
                logger.error(f"Upstream error: {resp.status_code} - {error_detail}")
                raise HTTPException(resp.status_code, f"Upstream error: {error_detail}")

            await cookie_manager.mark_cookie_success(ck)

            # The Z.AI non-stream response is a single JSON object
            response_data = resp.json()

            # Adapt Z.AI's response format to OpenAI's format
            final_content = ""
            finish_reason = "stop"  # default

            if "choices" in response_data and response_data["choices"]:
                first_choice = response_data["choices"][0]
                if "message" in first_choice and "content" in first_choice["message"]:
                    final_content = first_choice["message"]["content"]
                if "finish_reason" in first_choice:
                    finish_reason = first_choice["finish_reason"]

            return ChatCompletionResponse(
                id=response_data.get("id", f"chatcmpl-{uuid.uuid4().hex[:29]}"),
                created=int(time.time()),
                model=req.model,
                choices=[{"index": 0, "message": {"role": "assistant", "content": final_content}, "finish_reason": finish_reason}],
            )
        except Exception:
            logger.exception("Non-stream processing failed"); raise

    async def handle_chat_completion(self, req: ChatCompletionRequest):
        """Determines whether to stream or not and handles the request."""
        stream = bool(req.stream) if req.stream is not None else settings.DEFAULT_STREAM
        if stream:
            return StreamingResponse(self.stream_proxy_response(req), media_type="text/event-stream",
                                     headers={"Cache-Control": "no-cache", "Connection": "keep-alive"})
        return await self.non_stream_proxy_response(req)
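

if __name__ == "__main__":
    # Minimal local smoke test: a sketch, not part of the proxy itself.
    # It assumes ChatCompletionRequest accepts these fields and that
    # cookie_manager has at least one valid cookie loaded.
    import asyncio

    async def _demo():
        handler = ProxyHandler()
        try:
            req = ChatCompletionRequest(
                model=settings.MODEL_NAME,
                messages=[{"role": "user", "content": "Hello"}],
                stream=False,
            )
            print(await handler.non_stream_proxy_response(req))
        finally:
            await handler.aclose()

    asyncio.run(_demo())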