meirk-brd commited on
Commit
4e729a5
·
1 Parent(s): bcba5ba

fix dataset request

Browse files
Files changed (1) hide show
  1. tool.py +50 -8
tool.py CHANGED
@@ -1,5 +1,6 @@
1
  from __future__ import annotations
2
 
 
3
  import json
4
  import os
5
  import time
@@ -70,14 +71,16 @@ class BrightDataDatasetTool(Tool):
70
  def forward(
71
  self,
72
  dataset: str,
73
- url: Optional[str] = None,
74
- keyword: Optional[str] = None,
75
- first_name: Optional[str] = None,
76
- last_name: Optional[str] = None,
77
- days_limit: Optional[str] = None,
78
- num_of_reviews: Optional[str] = None,
79
- num_of_comments: Optional[str] = None,
80
- ) -> str:
 
 
81
  api_token = os.getenv("BRIGHT_DATA_API_TOKEN")
82
  if not api_token:
83
  raise ValueError("BRIGHT_DATA_API_TOKEN not found in environment variables")
@@ -197,6 +200,45 @@ class BrightDataDatasetTool(Tool):
197
 
198
  raise TimeoutError(f"Timeout waiting for snapshot {snapshot_id} after {max_attempts} seconds")
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  def _get_gradio_app_code(self, tool_module_name: str = "tool") -> str:
201
  choices = sorted(self.datasets.keys())
202
  dataset_fields = {key: value["inputs"] for key, value in self.datasets.items()}
 
1
  from __future__ import annotations
2
 
3
+ import ast
4
  import json
5
  import os
6
  import time
 
71
  def forward(
72
  self,
73
  dataset: str,
74
+ url: Optional[str] = None,
75
+ keyword: Optional[str] = None,
76
+ first_name: Optional[str] = None,
77
+ last_name: Optional[str] = None,
78
+ days_limit: Optional[str] = None,
79
+ num_of_reviews: Optional[str] = None,
80
+ num_of_comments: Optional[str] = None,
81
+ ) -> str:
82
+ url = self._coerce_url_input(url)
83
+
84
  api_token = os.getenv("BRIGHT_DATA_API_TOKEN")
85
  if not api_token:
86
  raise ValueError("BRIGHT_DATA_API_TOKEN not found in environment variables")
 
200
 
201
  raise TimeoutError(f"Timeout waiting for snapshot {snapshot_id} after {max_attempts} seconds")
202
 
203
+ def _coerce_url_input(self, raw: Optional[Any]) -> Optional[str]:
204
+ if raw is None:
205
+ return None
206
+
207
+ if isinstance(raw, str):
208
+ if raw.strip().startswith("{") and "orig_name" in raw:
209
+ parsed = self._parse_file_dict_string(raw)
210
+ if parsed:
211
+ raw = parsed
212
+ else:
213
+ return self._ensure_scheme(raw)
214
+ else:
215
+ return self._ensure_scheme(raw)
216
+
217
+ if isinstance(raw, dict):
218
+ orig_name = raw.get("orig_name")
219
+ if isinstance(orig_name, str) and orig_name:
220
+ return self._ensure_scheme(orig_name)
221
+
222
+ url_value = raw.get("url")
223
+ if isinstance(url_value, str):
224
+ if url_value.startswith(("http://", "https://")):
225
+ return url_value
226
+ return None
227
+
228
+ return None
229
+
230
+ def _ensure_scheme(self, url: str) -> str:
231
+ if url.startswith(("http://", "https://")):
232
+ return url
233
+ return f"https://{url}"
234
+
235
+ def _parse_file_dict_string(self, value: str) -> Optional[dict]:
236
+ try:
237
+ parsed = ast.literal_eval(value)
238
+ return parsed if isinstance(parsed, dict) else None
239
+ except (ValueError, SyntaxError):
240
+ return None
241
+
242
  def _get_gradio_app_code(self, tool_module_name: str = "tool") -> str:
243
  choices = sorted(self.datasets.keys())
244
  dataset_fields = {key: value["inputs"] for key, value in self.datasets.items()}