"""
Custom streaming data loader for AGILLM training.
Pulls batches from the stream_server on the scraper box over HTTP.
Drop-in replacement for HuggingFace dataset streaming.
"""
import json
import time
from typing import Any, Dict, Iterator

import requests
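
# Wire-format note: the parsing in _fetch_batch assumes the server speaks
# newline-delimited JSON (NDJSON), one object per line, each carrying a
# "text" field, e.g.
#   {"text": "scraped document body ...", "url": "https://example.com"}
# The "url" key above is illustrative only; nothing but "text" is read here.
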
class ScraperStreamDataset:
    """
    Streams training data from the scraper server.
    Compatible with AGILLM's _stream() interface.
    """

    def __init__(
        self,
        server_url: str = "http://localhost:8888",
        batch_size: int = 100,
        text_field: str = "text",
        shuffle: bool = True,
    ):
        self.server_url = server_url
        self.batch_size = batch_size
        self.text_field = text_field  # key under which text is yielded
        self.shuffle = shuffle  # picks the shuffled vs. sequential endpoint
        self._buffer = []
    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self

    def __next__(self) -> Dict[str, Any]:
        # Refill the buffer lazily; an empty buffer even after a fetch means
        # the server had nothing to give, so end the iteration.
        if not self._buffer:
            self._fetch_batch()
        if not self._buffer:
            raise StopIteration
        return self._buffer.pop(0)
    def _fetch_batch(self):
        """Fetch one batch of newline-delimited JSON from the stream server."""
        endpoint = "/stream" if self.shuffle else "/sequential"
        try:
            with requests.get(
                f"{self.server_url}{endpoint}",
                params={"batch": self.batch_size},
                stream=True,
                timeout=30,
            ) as resp:
                resp.raise_for_status()
                for line in resp.iter_lines():
                    if not line:
                        continue
                    try:
                        obj = json.loads(line.decode("utf-8"))
                        self._buffer.append({self.text_field: obj.get("text", "")})
                    except json.JSONDecodeError:
                        # Skip partial or malformed lines rather than aborting.
                        continue
        except requests.RequestException as e:
            print(f"[StreamLoader] Fetch error: {e}")
    def get_status(self) -> dict:
        """Get server status."""
        try:
            resp = requests.get(f"{self.server_url}/status", timeout=10)
            return resp.json()
        except (requests.RequestException, ValueError):
            return {"error": "unreachable"}
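
# Minimal usage sketch (assumes a stream server is reachable at the given
# URL; train_step is a placeholder, not part of this module):
#
#   ds = ScraperStreamDataset(server_url="http://scraper-box:8888", batch_size=32)
#   for sample in ds:
#       train_step(sample["text"])
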
def create_stream_iterator(server_url: str = "http://localhost:8888", seed: int = 42):
    """
    Create an iterator compatible with AGILLM's _stream() function.
    Returns an infinite iterator of {"text": "..."} dicts.

    Note: shuffling happens server-side, so `seed` is accepted for interface
    compatibility but currently unused.
    """
    dataset = ScraperStreamDataset(server_url=server_url)
    while True:
        try:
            yield next(dataset)
        except StopIteration:
            # Buffer ran dry: retry the fetch, backing off briefly so an
            # unreachable server doesn't turn this loop into a busy spin.
            dataset._fetch_batch()
            if dataset._buffer:
                yield dataset._buffer.pop(0)
            else:
                time.sleep(1)
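
# Example: bound the infinite iterator for a smoke test. itertools.islice
# caps the number of samples pulled (names below match this module; nothing
# else is assumed):
#
#   from itertools import islice
#   for sample in islice(create_stream_iterator("http://localhost:8888"), 10):
#       print(sample["text"][:80])
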

if __name__ == "__main__":
    import sys

    url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8888"
    print(f"Testing stream from {url}")

    ds = ScraperStreamDataset(server_url=url, batch_size=5)
    print(f"Status: {ds.get_status()}")

    # Pull five samples and show a short preview of each.
    for i, item in enumerate(ds):
        text = item["text"]
        print(f"Sample {i}: {len(text)} chars - {text[:100]}...")
        if i >= 4:
            break
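
# To run this self-test directly (module filename is assumed):
#   python scraper_stream_loader.py http://localhost:8888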