|
|
import asyncio |
|
|
|
|
|
class BatchScheduler:
    """Collect prompts submitted by concurrent coroutines into batches.

    Producers call :meth:`add` to enqueue a prompt; a consumer calls
    :meth:`get_batch` to pull up to ``max_batch`` pending entries at once.
    Each entry is a ``(prompt, asyncio.Queue)`` tuple, where the queue is
    the one handed back to the producer by :meth:`add` (presumably the
    channel the consumer uses to deliver results -- confirm against the
    caller).
    """

    def __init__(self, max_batch=8, max_wait_ms=30):
        """Create an empty scheduler.

        Args:
            max_batch: Maximum number of entries returned by one
                :meth:`get_batch` call.
            max_wait_ms: Upper bound, in milliseconds, that
                :meth:`get_batch` waits for a batch to fill up.
        """
        # Pending (prompt, per-request queue) tuples, oldest first.
        self.queue = []
        self.max_batch = max_batch
        self.max_wait_ms = max_wait_ms
        # Serializes every mutation of ``self.queue``.
        self.lock = asyncio.Lock()
        # Set while at least ``max_batch`` entries are pending, so that
        # get_batch() can return early instead of sleeping the full wait.
        self._batch_ready = asyncio.Event()

    async def add(self, prompt: str) -> asyncio.Queue:
        """Enqueue ``prompt`` and return the queue created for this request.

        Args:
            prompt: The prompt to schedule.

        Returns:
            A fresh ``asyncio.Queue`` that is stored next to the prompt in
            the pending list.
        """
        response_queue = asyncio.Queue()
        async with self.lock:
            self.queue.append((prompt, response_queue))
            if len(self.queue) >= self.max_batch:
                # Wake any consumer that is still waiting for more work.
                self._batch_ready.set()
        return response_queue

    async def get_batch(self):
        """Return the next batch of pending entries.

        Returns ``None`` immediately when nothing is pending. Otherwise
        waits until either ``max_batch`` entries are queued or
        ``max_wait_ms`` milliseconds have elapsed, whichever comes first,
        then removes and returns up to ``max_batch`` entries.

        Returns:
            A list of ``(prompt, asyncio.Queue)`` tuples in submission
            order, or ``None`` when there is nothing to process.
        """
        if not self.queue:
            return None

        # Give producers a chance to fill the batch, but stop waiting as
        # soon as it is full: a full batch gains nothing from more delay.
        try:
            await asyncio.wait_for(
                self._batch_ready.wait(), self.max_wait_ms / 1000
            )
        except asyncio.TimeoutError:
            # Timed out with a partial batch; serve what is available.
            pass

        async with self.lock:
            batch = self.queue[:self.max_batch]
            self.queue = self.queue[self.max_batch:]
            if len(self.queue) < self.max_batch:
                # What is left no longer makes a full batch, so the next
                # call must wait again.
                self._batch_ready.clear()

        # Another consumer may have drained the queue while we waited.
        return batch if batch else None
|
|
|