import asyncio
import time
from typing import List, Dict, Any, Optional
from dataclasses import dataclass

import aiohttp

# Request status marker
PROCESSING = "processing"


@dataclass
class BatchRequest:
    """A single inference request waiting to be batched."""
    prompt: str
    max_tokens: int
    temperature: float
    future: asyncio.Future


class ProductionBatchScheduler:
    """
    Production-grade dynamic batching with:
    - Queue size monitoring for autoscaling
    - Memory-aware batch sizing
    - Timeout-based batching windows
    - Comprehensive metrics collection
    """

    def __init__(
        self,
        model_endpoint: str,
        max_batch_size: int = 32,
        queue_threshold: int = 5,
        max_wait_ms: int = 100,  # batching window in milliseconds; tune per workload
        max_memory_gb: float = 80.0,
    ):
        self.model_endpoint = model_endpoint
        self.max_batch_size = max_batch_size
        self.queue_threshold = queue_threshold
        self.max_wait_ms = max_wait_ms
        self.max_memory_gb = max_memory_gb

        self.request_queue: List[BatchRequest] = []
        self._batch_task: Optional[asyncio.Task] = None
        self.metrics: Dict[str, List[Any]] = {
            "queue_sizes": [],
            "batch_sizes": [],
            "processing_times": [],
            "scale_events": [],
        }
    async def submit_request(
        self,
        prompt: str,
        max_tokens: int = 256,     # generation defaults; tune per workload
        temperature: float = 0.7,
    ) -> Any:
        """
        Add request to queue with automatic batching.
        Returns response when batch completes.
        """
        future = asyncio.get_running_loop().create_future()
        request = BatchRequest(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            future=future,
        )
        self.request_queue.append(request)
        self.metrics["queue_sizes"].append(len(self.request_queue))

        # Ensure the batching loop is running; it drains the queue once the
        # timeout window elapses, so requests below the threshold are not starved.
        if self._batch_task is None or self._batch_task.done():
            self._batch_task = asyncio.create_task(self._process_batch_loop())

        return await future
    async def _process_batch_loop(self):
        """Continuous batching loop with timeout window."""
        while self.request_queue:
            # Wait for the batching window so more requests can accumulate
            await asyncio.sleep(self.max_wait_ms / 1000)

            batch = self._form_batch()
            if batch:
                await self._execute_batch(batch)

            # Check for scale-up / scale-down signals
            self._evaluate_scaling()
    def _form_batch(self) -> List[BatchRequest]:
        """Form batch respecting memory constraints and max size."""
        if not self.request_queue:
            return []

        # Estimate memory usage (simplified): assume a flat cost per request
        estimated_memory_per_request = 0.5  # GB per request
        memory_limited_size = max(1, int(self.max_memory_gb / estimated_memory_per_request))

        # Take the minimum of: queue size, max batch size, memory limit
        batch_size = min(len(self.request_queue), self.max_batch_size, memory_limited_size)

        batch = self.request_queue[:batch_size]
        self.request_queue = self.request_queue[batch_size:]
        self.metrics["batch_sizes"].append(len(batch))
        return batch
    async def _execute_batch(self, batch: List[BatchRequest]):
        """Execute batch against model endpoint."""
        start_time = time.time()
        payload = {
            "inputs": [req.prompt for req in batch],
            "max_tokens": max(req.max_tokens for req in batch),
            "temperature": batch[0].temperature,  # Use first request's temperature
            "return_full_text": False,
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.model_endpoint}/generate",
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=30),
                ) as response:
                    response.raise_for_status()
                    result = await response.json()

            outputs = result.get("outputs", [])
            for i, req in enumerate(batch):
                if i < len(outputs):
                    req.future.set_result(outputs[i])
                else:
                    req.future.set_exception(Exception("Missing output"))

            processing_time = time.time() - start_time
            self.metrics["processing_times"].append(processing_time)

        except Exception as e:
            # Fail all requests in batch
            for req in batch:
                if not req.future.done():
                    req.future.set_exception(e)
            self.metrics["processing_times"].append(-1)  # Error indicator
    def _evaluate_scaling(self):
        """Generate scaling signals based on queue dynamics."""
        if not self.metrics["queue_sizes"]:
            return

        recent_queue_sizes = self.metrics["queue_sizes"][-10:]
        avg_queue = sum(recent_queue_sizes) / len(recent_queue_sizes)
        current_queue = len(self.request_queue)

        # A queue well beyond the batching threshold signals a scale-up
        if current_queue > self.queue_threshold * 2:
            self.metrics["scale_events"].append({
                "timestamp": time.time(),
                "queue_size": current_queue,
                "avg_queue_size": avg_queue,
                "reason": "queue_threshold_exceeded",
            })
            print(f"🚨 SCALE UP: Queue={current_queue}, Threshold={self.queue_threshold}")

        # An empty queue with consistently small batches signals low utilization
        if current_queue == 0 and len(self.metrics["batch_sizes"]) > 10:
            avg_batch = sum(self.metrics["batch_sizes"][-10:]) / 10
            if avg_batch < self.max_batch_size / 2:
                self.metrics["scale_events"].append({
                    "timestamp": time.time(),
                    "avg_batch_size": avg_batch,
                    "reason": "low_utilization",
                })
                print(f"✅ SCALE DOWN: Avg batch={avg_batch:.2f}")
    def get_metrics(self) -> Dict[str, Any]:
        """Return comprehensive metrics for monitoring."""
        if not self.metrics["processing_times"]:
            return {"status": "no_data"}

        valid_times = [t for t in self.metrics["processing_times"] if t > 0]
        avg_latency = sum(valid_times) / len(valid_times) if valid_times else 0
        batch_sizes = self.metrics["batch_sizes"]

        return {
            "avg_latency_s": avg_latency,
            "total_requests": sum(batch_sizes),
            "avg_batch_size": sum(batch_sizes) / len(batch_sizes) if batch_sizes else 0,
            "scale_events": len(self.metrics["scale_events"]),
            "current_queue": len(self.request_queue),
        }
# Production usage example
async def production_example():
    scheduler = ProductionBatchScheduler(
        model_endpoint="http://vllm-service:8000",
    )

    # Simulate production load with a burst of concurrent requests
    async def generate_load():
        tasks = [
            asyncio.create_task(
                scheduler.submit_request(prompt=f"Analyze this transaction: {i}")
            )
            for i in range(50)  # burst size is illustrative
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        return results

    results = await generate_load()
    failures = sum(1 for r in results if isinstance(r, Exception))
    print(f"Completed {len(results) - failures}/{len(results)} requests")

    metrics = scheduler.get_metrics()
    print(f"Metrics: {metrics}")


if __name__ == "__main__":
    asyncio.run(production_example())
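
# To run the monitor sketch above alongside the load test (again, illustrative),
# it could be started as a background task inside production_example():
#
#   monitor = asyncio.create_task(autoscaling_monitor(scheduler, poll_interval_s=1.0))
#   ...
#   monitor.cancel()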