import json

from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional, Tuple

import openai
@dataclass
class ModelTier:
    """Configuration for a model tier in the routing system."""
    # Model identifier sent to the API (NOTE(review): reconstructed — confirm
    # the exact model names against the provider catalog).
    name: str
    # Provider label recorded in routing metadata.
    provider: str
    # Combined input/output cost estimate, USD per million tokens.
    cost_per_million_tokens: float
    # Max complexity this tier can handle (also caps user confidence).
    confidence_threshold: float
    # Largest context the tier accepts; compared against a character-based
    # length upstream — TODO confirm units (chars vs tokens).
    max_context: int
    # Human-readable summary surfaced in routing metadata.
    description: str
class IntelligentModelRouter:
    """
    Production-ready model router with complexity-based selection,
    confidence thresholds, and automatic fallback logic.

    Each request is scored for complexity, then routed to the cheapest
    tier whose confidence threshold covers that score; API errors trigger
    an automatic retry on the most capable tier.
    """

    def __init__(self, api_key: str, provider: str = "openai"):
        """Create a router backed by an OpenAI-compatible client.

        Args:
            api_key: API key for the underlying provider.
            provider: Provider label recorded in tier metadata.
        """
        self.client = openai.OpenAI(api_key=api_key)
        self.provider = provider

        # Cost-optimized tier configuration.
        # Ordered from cheapest to most expensive — select_model relies on
        # this ordering to pick the cheapest viable tier first.
        # NOTE(review): model names and max_context values reconstructed
        # from the per-tier cost comments — confirm against the provider's
        # current catalog before deploying.
        self.model_tiers: List[ModelTier] = [
            ModelTier(
                name="gpt-4.1-nano",
                provider=provider,
                cost_per_million_tokens=0.5,   # $0.10 + $0.40
                confidence_threshold=0.65,
                max_context=32_000,
                description="Ultra-cheap for simple tasks",
            ),
            ModelTier(
                name="gpt-4.1-mini",
                provider=provider,
                cost_per_million_tokens=2.0,   # $0.40 + $1.60
                confidence_threshold=0.75,
                max_context=128_000,
                description="Balanced cost/performance",
            ),
            ModelTier(
                name="gpt-4.1",
                provider=provider,
                cost_per_million_tokens=10.0,  # $2.00 + $8.00
                confidence_threshold=0.95,
                max_context=1_000_000,
                description="High-quality for complex tasks",
            ),
        ]

        # Track usage for cost monitoring.
        self.usage_log: List[Dict] = []

    def estimate_complexity(self, messages: List[Dict],
                            context_length: Optional[int] = None) -> float:
        """
        Estimate task complexity based on message characteristics.

        Returns a score between 0.0 (trivial) and 1.0 (very complex),
        clamped to a floor of 0.1 so no request scores as free.

        NOTE(review): the per-signal weights below are heuristic
        reconstructions — tune against real traffic.
        """
        text = " ".join([msg["content"] for msg in messages])
        text_lower = text.lower()
        complexity_score = 0.0

        # Length factor (normalized to 0-1): longer prompts skew complex.
        length_factor = min(len(text) / 1000, 1.0)
        complexity_score += length_factor * 0.2

        complex_keywords = [
            "analyze", "reason", "calculate", "code", "program", "debug",
            "architect", "design", "strategize", "evaluate", "compare",
            "synthesize", "implement", "optimize", "refactor",
        ]
        simple_keywords = [
            "hello", "hi", "thanks", "what is", "define", "list",
            "explain", "summarize", "translate", "who", "when", "where",
        ]

        # Adjust score based on keywords.  Substring matching is cheap but
        # coarse: short keywords like "hi" can fire inside longer words.
        for keyword in complex_keywords:
            if keyword in text_lower:
                complexity_score += 0.15
        for keyword in simple_keywords:
            if keyword in text_lower:
                complexity_score -= 0.1

        # Check for code-related content (case-sensitive on purpose:
        # "def " / "class " are syntax tokens, not prose).
        code_indicators = ["```", "def ", "function", "class ", "import ", "return "]
        if any(indicator in text for indicator in code_indicators):
            complexity_score += 0.3

        # Check for multiple documents or long context.
        if context_length and context_length > 5000:
            complexity_score += 0.2

        # Cap between 0.1 and 1.0.
        return max(0.1, min(1.0, complexity_score))

    def select_model(self, complexity_score: float,
                     user_confidence: float = 0.8,
                     context_length: int = 0) -> "ModelTier":
        """
        Select the cheapest model that can handle the complexity.

        Args:
            complexity_score: Output of estimate_complexity (0.0-1.0).
            user_confidence: Caller-supplied quality requirement; a tier is
                only viable when its threshold covers this value too.
            context_length: Request context size, filtered against each
                tier's max_context.

        Returns:
            The cheapest viable ModelTier, falling back to the most
            capable tier when none qualifies.
        """
        # Filter by context window.
        viable_tiers = [t for t in self.model_tiers if t.max_context >= context_length]
        if not viable_tiers:
            # Fallback to highest context model.
            viable_tiers = [self.model_tiers[-1]]

        # Find cheapest tier that meets complexity and confidence
        # requirements (tiers are pre-sorted cheapest-first).
        for tier in viable_tiers:
            if (complexity_score <= tier.confidence_threshold and
                    user_confidence <= tier.confidence_threshold):
                return tier

        # Fallback to highest tier.
        return viable_tiers[-1]

    def route_request(self, messages: List[Dict],
                      user_confidence: float = 0.8,
                      context_length: Optional[int] = None) -> Tuple[str, Dict]:
        """
        Main routing function. Returns response and metadata including
        model used, costs, and complexity analysis.

        Raises:
            Whatever the API client raises if BOTH the selected tier and
            the fallback tier fail.
        """
        # Calculate context length if not provided.
        # NOTE(review): this is a character count, later multiplied by a
        # per-token price — a rough overestimate; confirm intended units.
        if context_length is None:
            context_length = sum(len(msg["content"]) for msg in messages)

        complexity = self.estimate_complexity(messages, context_length)

        # Select appropriate model.
        selected_tier = self.select_model(complexity, user_confidence, context_length)

        # Make API call with fallback logic.
        try:
            response = self.client.chat.completions.create(
                model=selected_tier.name,
                messages=messages,
            )
            output_text = response.choices[0].message.content
            output_tokens = len(output_text.split())  # Approximate

            # Calculate estimated cost.
            estimated_cost = (
                selected_tier.cost_per_million_tokens / 1_000_000 *
                (context_length + output_tokens)
            )
            metadata = {
                "model_used": selected_tier.name,
                "provider": selected_tier.provider,
                "complexity_score": round(complexity, 3),
                "confidence_threshold": selected_tier.confidence_threshold,
                "estimated_cost_usd": round(estimated_cost, 6),
                "tier_description": selected_tier.description,
                "context_length": context_length,
                "output_tokens": output_tokens,
            }

            # Log usage for monitoring.
            self._log_usage(metadata)
            return output_text, metadata
        except Exception:
            # Automatic fallback to highest tier.
            fallback_tier = self.model_tiers[-1]
            response = self.client.chat.completions.create(
                model=fallback_tier.name,
                messages=messages,
            )
            output_text = response.choices[0].message.content
            output_tokens = len(output_text.split())
            metadata = {
                "model_used": fallback_tier.name,
                "provider": fallback_tier.provider,
                "complexity_score": round(complexity, 3),
                "note": "Used fallback due to error",
                "estimated_cost_usd": round(
                    fallback_tier.cost_per_million_tokens / 1_000_000 *
                    (context_length + output_tokens),
                    6,
                ),
            }
            self._log_usage(metadata)
            return output_text, metadata

    def _log_usage(self, metadata: Dict):
        """Internal method to track usage for cost monitoring."""
        log_entry = {
            "timestamp": datetime.now().isoformat(),
            **metadata,
        }
        self.usage_log.append(log_entry)
        # Write to file periodically (in production, use proper logging).
        if len(self.usage_log) >= 100:
            self._flush_logs()

    def _flush_logs(self):
        """Write accumulated logs to file."""
        with open("model_router_usage.jsonl", "a") as f:
            for entry in self.usage_log:
                f.write(json.dumps(entry) + "\n")
        # Reset so flushed entries are not re-written on the next flush.
        self.usage_log = []
if __name__ == "__main__":
    # Demo driver: route three prompts of increasing complexity and print
    # the routing decision for each.  Requires a real API key.
    router = IntelligentModelRouter(api_key="your-api-key")

    test_cases = [
        {
            "name": "Simple question",
            "messages": [{"role": "user", "content": "What is the capital of France?"}],
        },
        {
            "name": "Code generation",
            "messages": [{"role": "user", "content": "Write a Python function to calculate fibonacci numbers"}],
        },
        {
            "name": "Complex analysis",
            "messages": [{"role": "user", "content": "Analyze the trade-offs between microservices and monolithic architecture for a high-traffic e-commerce platform"}],
        },
    ]

    for test in test_cases:
        print(f"\n--- {test['name']} ---")
        response, metadata = router.route_request(test["messages"])
        print(f"Model: {metadata['model_used']}")
        print(f"Complexity: {metadata['complexity_score']}")
        print(f"Cost: ${metadata['estimated_cost_usd']}")
        print(f"Response: {response[:100]}...")