akseljoonas HF Staff committed on
Commit
72bac94
·
1 Parent(s): 9615e37

inference token

Browse files
agent/context_manager/manager.py CHANGED
@@ -165,10 +165,12 @@ class ContextManager:
165
  )
166
  )
167
 
 
168
  response = await acompletion(
169
  model=model_name,
170
  messages=messages_to_summarize,
171
  max_completion_tokens=self.compact_size,
 
172
  )
173
  summarized_message = Message(
174
  role="assistant", content=response.choices[0].message.content
 
165
  )
166
  )
167
 
168
+ api_key = os.environ.get("INFERENCE_TOKEN")
169
  response = await acompletion(
170
  model=model_name,
171
  messages=messages_to_summarize,
172
  max_completion_tokens=self.compact_size,
173
+ **({'api_key': api_key} if api_key and model_name.startswith('huggingface/') else {}),
174
  )
175
  summarized_message = Message(
176
  role="assistant", content=response.choices[0].message.content
agent/core/agent_loop.py CHANGED
@@ -5,8 +5,9 @@ Main agent implementation with integrated tool system and MCP support
5
  import asyncio
6
  import json
7
  import logging
 
8
 
9
- from litellm import ChatCompletionMessageToolCall, Message, ModelResponse, acompletion
10
  from lmnr import observe
11
 
12
  from agent.config import Config
@@ -17,6 +18,9 @@ from agent.tools.jobs_tool import CPU_FLAVORS
17
  logger = logging.getLogger(__name__)
18
 
19
  ToolCall = ChatCompletionMessageToolCall
 
 
 
20
 
21
 
22
  def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]:
@@ -41,7 +45,9 @@ def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]:
41
  return True, None
42
 
43
 
44
- def _needs_approval(tool_name: str, tool_args: dict, config: Config | None = None) -> bool:
 
 
45
  """Check if a tool call requires user approval before execution."""
46
  # Yolo mode: skip all approvals
47
  if config and config.yolo_mode:
@@ -56,19 +62,24 @@ def _needs_approval(tool_name: str, tool_args: dict, config: Config | None = Non
56
  operation = tool_args.get("operation", "")
57
  if operation not in ["run", "uv", "scheduled run", "scheduled uv"]:
58
  return False
59
-
60
  # Check if this is a CPU-only job
61
  # hardware_flavor is at top level of tool_args, not nested in args
62
- hardware_flavor = tool_args.get("hardware_flavor") or tool_args.get("flavor") or tool_args.get("hardware") or "cpu-basic"
 
 
 
 
 
63
  is_cpu_job = hardware_flavor in CPU_FLAVORS
64
-
65
  if is_cpu_job:
66
  if config and not config.confirm_cpu_jobs:
67
  return False
68
  return True
69
-
70
  return True
71
-
72
  # Check for file upload operations (hf_private_repos or other tools)
73
  if tool_name == "hf_private_repos":
74
  operation = tool_args.get("operation", "")
@@ -89,7 +100,13 @@ def _needs_approval(tool_name: str, tool_args: dict, config: Config | None = Non
89
  # hf_repo_git: destructive operations require approval
90
  if tool_name == "hf_repo_git":
91
  operation = tool_args.get("operation", "")
92
- if operation in ["delete_branch", "delete_tag", "merge_pr", "create_repo", "update_repo"]:
 
 
 
 
 
 
93
  return True
94
 
95
  return False
@@ -140,6 +157,12 @@ class Handlers:
140
  tool_choice="auto",
141
  stream=True,
142
  stream_options={"include_usage": True},
 
 
 
 
 
 
143
  )
144
 
145
  full_content = ""
@@ -180,13 +203,13 @@ class Handlers:
180
  tool_calls_acc[idx]["id"] = tc_delta.id
181
  if tc_delta.function:
182
  if tc_delta.function.name:
183
- tool_calls_acc[idx]["function"][
184
- "name"
185
- ] += tc_delta.function.name
186
  if tc_delta.function.arguments:
187
- tool_calls_acc[idx]["function"][
188
- "arguments"
189
- ] += tc_delta.function.arguments
190
 
191
  # Capture usage from the final chunk
192
  if hasattr(chunk, "usage") and chunk.usage:
@@ -219,9 +242,7 @@ class Handlers:
219
  if not tool_calls:
220
  if content:
221
  assistant_msg = Message(role="assistant", content=content)
222
- session.context_manager.add_message(
223
- assistant_msg, token_count
224
- )
225
  final_response = content
226
  break
227
 
 
5
  import asyncio
6
  import json
7
  import logging
8
+ import os
9
 
10
+ from litellm import ChatCompletionMessageToolCall, Message, acompletion
11
  from lmnr import observe
12
 
13
  from agent.config import Config
 
18
  logger = logging.getLogger(__name__)
19
 
20
  ToolCall = ChatCompletionMessageToolCall
21
+ # Explicit inference token — needed because litellm checks HF_TOKEN before
22
+ # HUGGINGFACE_API_KEY, and HF_TOKEN (used for Hub ops) may lack inference permissions.
23
+ _INFERENCE_API_KEY = os.environ.get("INFERENCE_TOKEN")
24
 
25
 
26
  def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]:
 
45
  return True, None
46
 
47
 
48
+ def _needs_approval(
49
+ tool_name: str, tool_args: dict, config: Config | None = None
50
+ ) -> bool:
51
  """Check if a tool call requires user approval before execution."""
52
  # Yolo mode: skip all approvals
53
  if config and config.yolo_mode:
 
62
  operation = tool_args.get("operation", "")
63
  if operation not in ["run", "uv", "scheduled run", "scheduled uv"]:
64
  return False
65
+
66
  # Check if this is a CPU-only job
67
  # hardware_flavor is at top level of tool_args, not nested in args
68
+ hardware_flavor = (
69
+ tool_args.get("hardware_flavor")
70
+ or tool_args.get("flavor")
71
+ or tool_args.get("hardware")
72
+ or "cpu-basic"
73
+ )
74
  is_cpu_job = hardware_flavor in CPU_FLAVORS
75
+
76
  if is_cpu_job:
77
  if config and not config.confirm_cpu_jobs:
78
  return False
79
  return True
80
+
81
  return True
82
+
83
  # Check for file upload operations (hf_private_repos or other tools)
84
  if tool_name == "hf_private_repos":
85
  operation = tool_args.get("operation", "")
 
100
  # hf_repo_git: destructive operations require approval
101
  if tool_name == "hf_repo_git":
102
  operation = tool_args.get("operation", "")
103
+ if operation in [
104
+ "delete_branch",
105
+ "delete_tag",
106
+ "merge_pr",
107
+ "create_repo",
108
+ "update_repo",
109
+ ]:
110
  return True
111
 
112
  return False
 
157
  tool_choice="auto",
158
  stream=True,
159
  stream_options={"include_usage": True},
160
+ **(
161
+ {"api_key": _INFERENCE_API_KEY}
162
+ if _INFERENCE_API_KEY
163
+ and session.config.model_name.startswith("huggingface/")
164
+ else {}
165
+ ),
166
  )
167
 
168
  full_content = ""
 
203
  tool_calls_acc[idx]["id"] = tc_delta.id
204
  if tc_delta.function:
205
  if tc_delta.function.name:
206
+ tool_calls_acc[idx]["function"]["name"] += (
207
+ tc_delta.function.name
208
+ )
209
  if tc_delta.function.arguments:
210
+ tool_calls_acc[idx]["function"]["arguments"] += (
211
+ tc_delta.function.arguments
212
+ )
213
 
214
  # Capture usage from the final chunk
215
  if hasattr(chunk, "usage") and chunk.usage:
 
242
  if not tool_calls:
243
  if content:
244
  assistant_msg = Message(role="assistant", content=content)
245
+ session.context_manager.add_message(assistant_msg, token_count)
 
 
246
  final_response = content
247
  break
248
 
backend/routes/agent.py CHANGED
@@ -5,13 +5,19 @@ dependency. In dev mode (no OAUTH_CLIENT_ID), auth is bypassed automatically.
5
  """
6
 
7
  import logging
 
8
  from typing import Any
9
 
10
- from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket, WebSocketDisconnect
11
-
12
  from dependencies import get_current_user, get_ws_user
 
 
 
 
 
 
 
 
13
  from litellm import acompletion
14
-
15
  from models import (
16
  ApprovalRequest,
17
  HealthResponse,
@@ -27,6 +33,31 @@ logger = logging.getLogger(__name__)
27
 
28
  router = APIRouter(prefix="/api", tags=["agent"])
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  def _check_session_access(session_id: str, user: dict[str, Any]) -> None:
32
  """Verify the user has access to the given session. Raises 403 or 404."""
@@ -58,21 +89,37 @@ async def llm_health_check() -> LLMHealthResponse:
58
  - timeout / network → provider unreachable
59
  """
60
  model = session_manager.config.model_name
 
 
 
 
61
  try:
62
  await acompletion(
63
  model=model,
64
  messages=[{"role": "user", "content": "hi"}],
65
  max_tokens=1,
66
  timeout=10,
 
67
  )
68
  return LLMHealthResponse(status="ok", model=model)
69
  except Exception as e:
70
  err_str = str(e).lower()
71
  error_type = "unknown"
72
 
73
- if "401" in err_str or "auth" in err_str or "invalid" in err_str or "api key" in err_str:
 
 
 
 
 
74
  error_type = "auth"
75
- elif "402" in err_str or "credit" in err_str or "quota" in err_str or "insufficient" in err_str or "billing" in err_str:
 
 
 
 
 
 
76
  error_type = "credits"
77
  elif "429" in err_str or "rate" in err_str:
78
  error_type = "rate_limit"
@@ -88,14 +135,6 @@ async def llm_health_check() -> LLMHealthResponse:
88
  )
89
 
90
 
91
- AVAILABLE_MODELS = [
92
- {"id": "huggingface/novita/MiniMaxAI/MiniMax-M2.1", "label": "MiniMax M2.1", "provider": "huggingface", "recommended": True},
93
- {"id": "anthropic/claude-opus-4-5-20251101", "label": "Claude Opus 4.5", "provider": "anthropic", "recommended": True},
94
- {"id": "huggingface/novita/moonshotai/Kimi-K2.5", "label": "Kimi K2.5", "provider": "huggingface"},
95
- {"id": "huggingface/novita/zai-org/GLM-5", "label": "GLM 5", "provider": "huggingface"},
96
- ]
97
-
98
-
99
  @router.get("/config/model")
100
  async def get_model() -> dict:
101
  """Get current model and available models. No auth required."""
@@ -106,9 +145,7 @@ async def get_model() -> dict:
106
 
107
 
108
  @router.post("/config/model")
109
- async def set_model(
110
- body: dict, user: dict = Depends(get_current_user)
111
- ) -> dict:
112
  """Set the LLM model. Applies to new conversations."""
113
  model_id = body.get("model")
114
  if not model_id:
@@ -127,6 +164,10 @@ async def generate_title(
127
  ) -> dict:
128
  """Generate a short title for a chat session based on the first user message."""
129
  model = session_manager.config.model_name
 
 
 
 
130
  try:
131
  response = await acompletion(
132
  model=model,
@@ -144,6 +185,7 @@ async def generate_title(
144
  max_tokens=20,
145
  temperature=0.3,
146
  timeout=8,
 
147
  )
148
  title = response.choices[0].message.content.strip().strip('"').strip("'")
149
  # Safety: cap at 50 chars
@@ -259,9 +301,7 @@ async def interrupt_session(
259
 
260
 
261
  @router.post("/undo/{session_id}")
262
- async def undo_session(
263
- session_id: str, user: dict = Depends(get_current_user)
264
- ) -> dict:
265
  """Undo the last turn in a session."""
266
  _check_session_access(session_id, user)
267
  success = await session_manager.undo(session_id)
@@ -312,7 +352,9 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str) -> None:
312
  # Authenticate the WebSocket connection
313
  user = await get_ws_user(websocket)
314
  if not user:
315
- logger.warning(f"WebSocket rejected: authentication failed for session {session_id}")
 
 
316
  await websocket.accept()
317
  await websocket.close(code=4001, reason="Authentication required")
318
  return
@@ -340,10 +382,12 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str) -> None:
340
  # knows the session is alive. The original ready event from _run_session
341
  # fires before the WS is connected and is always lost.
342
  try:
343
- await websocket.send_json({
344
- "event_type": "ready",
345
- "data": {"message": "Agent initialized"},
346
- })
 
 
347
  except Exception as e:
348
  logger.error(f"Failed to send ready event for session {session_id}: {e}")
349
 
 
5
  """
6
 
7
  import logging
8
+ import os
9
  from typing import Any
10
 
 
 
11
  from dependencies import get_current_user, get_ws_user
12
+ from fastapi import (
13
+ APIRouter,
14
+ Depends,
15
+ HTTPException,
16
+ Request,
17
+ WebSocket,
18
+ WebSocketDisconnect,
19
+ )
20
  from litellm import acompletion
 
21
  from models import (
22
  ApprovalRequest,
23
  HealthResponse,
 
33
 
34
  router = APIRouter(prefix="/api", tags=["agent"])
35
 
36
+ AVAILABLE_MODELS = [
37
+ {
38
+ "id": "huggingface/novita/MiniMaxAI/MiniMax-M2.1",
39
+ "label": "MiniMax M2.1",
40
+ "provider": "huggingface",
41
+ "recommended": True,
42
+ },
43
+ {
44
+ "id": "anthropic/claude-opus-4-5-20251101",
45
+ "label": "Claude Opus 4.5",
46
+ "provider": "anthropic",
47
+ "recommended": True,
48
+ },
49
+ {
50
+ "id": "huggingface/novita/moonshotai/Kimi-K2.5",
51
+ "label": "Kimi K2.5",
52
+ "provider": "huggingface",
53
+ },
54
+ {
55
+ "id": "huggingface/novita/zai-org/GLM-5",
56
+ "label": "GLM 5",
57
+ "provider": "huggingface",
58
+ },
59
+ ]
60
+
61
 
62
  def _check_session_access(session_id: str, user: dict[str, Any]) -> None:
63
  """Verify the user has access to the given session. Raises 403 or 404."""
 
89
  - timeout / network → provider unreachable
90
  """
91
  model = session_manager.config.model_name
92
+ hf_key = os.environ.get("INFERENCE_TOKEN")
93
+ api_key_kw = (
94
+ {"api_key": hf_key} if hf_key and model.startswith("huggingface/") else {}
95
+ )
96
  try:
97
  await acompletion(
98
  model=model,
99
  messages=[{"role": "user", "content": "hi"}],
100
  max_tokens=1,
101
  timeout=10,
102
+ **api_key_kw,
103
  )
104
  return LLMHealthResponse(status="ok", model=model)
105
  except Exception as e:
106
  err_str = str(e).lower()
107
  error_type = "unknown"
108
 
109
+ if (
110
+ "401" in err_str
111
+ or "auth" in err_str
112
+ or "invalid" in err_str
113
+ or "api key" in err_str
114
+ ):
115
  error_type = "auth"
116
+ elif (
117
+ "402" in err_str
118
+ or "credit" in err_str
119
+ or "quota" in err_str
120
+ or "insufficient" in err_str
121
+ or "billing" in err_str
122
+ ):
123
  error_type = "credits"
124
  elif "429" in err_str or "rate" in err_str:
125
  error_type = "rate_limit"
 
135
  )
136
 
137
 
 
 
 
 
 
 
 
 
138
  @router.get("/config/model")
139
  async def get_model() -> dict:
140
  """Get current model and available models. No auth required."""
 
145
 
146
 
147
  @router.post("/config/model")
148
+ async def set_model(body: dict, user: dict = Depends(get_current_user)) -> dict:
 
 
149
  """Set the LLM model. Applies to new conversations."""
150
  model_id = body.get("model")
151
  if not model_id:
 
164
  ) -> dict:
165
  """Generate a short title for a chat session based on the first user message."""
166
  model = session_manager.config.model_name
167
+ hf_key = os.environ.get("INFERENCE_TOKEN")
168
+ api_key_kw = (
169
+ {"api_key": hf_key} if hf_key and model.startswith("huggingface/") else {}
170
+ )
171
  try:
172
  response = await acompletion(
173
  model=model,
 
185
  max_tokens=20,
186
  temperature=0.3,
187
  timeout=8,
188
+ **api_key_kw,
189
  )
190
  title = response.choices[0].message.content.strip().strip('"').strip("'")
191
  # Safety: cap at 50 chars
 
301
 
302
 
303
  @router.post("/undo/{session_id}")
304
+ async def undo_session(session_id: str, user: dict = Depends(get_current_user)) -> dict:
 
 
305
  """Undo the last turn in a session."""
306
  _check_session_access(session_id, user)
307
  success = await session_manager.undo(session_id)
 
352
  # Authenticate the WebSocket connection
353
  user = await get_ws_user(websocket)
354
  if not user:
355
+ logger.warning(
356
+ f"WebSocket rejected: authentication failed for session {session_id}"
357
+ )
358
  await websocket.accept()
359
  await websocket.close(code=4001, reason="Authentication required")
360
  return
 
382
  # knows the session is alive. The original ready event from _run_session
383
  # fires before the WS is connected and is always lost.
384
  try:
385
+ await websocket.send_json(
386
+ {
387
+ "event_type": "ready",
388
+ "data": {"message": "Agent initialized"},
389
+ }
390
+ )
391
  except Exception as e:
392
  logger.error(f"Failed to send ready event for session {session_id}: {e}")
393