import json
import logging
import uuid
from datetime import datetime, timezone
from typing import Dict, Any, Optional
class AIAuditLogger:
    """Production AI audit logger with compliance-grade fields.

    Captures identity, model behavior, and system decisions.
    """

    # Pricing per 1M tokens (verified as of 2024-11-15).
    # Note: Pricing data must be kept current with provider updates.
    # Layout: provider -> model -> {"input": rate, "output": rate}.
    _PRICING: Dict[str, Dict[str, Dict[str, float]]] = {
        "anthropic": {
            "claude-3-5-sonnet": {"input": 3.00, "output": 15.00},
            "haiku-3.5": {"input": 1.25, "output": 5.00},
        },
        # NOTE(review): provider key assumed to be "openai" -- confirm
        # against the values callers pass as ``model_provider``.
        "openai": {
            "gpt-4o": {"input": 5.00, "output": 15.00},
            "gpt-4o-mini": {"input": 0.15, "output": 0.60},
        },
    }

    def __init__(self, service_name: str, tenant_id: str):
        """Bind the logger to one service and one tenant.

        Args:
            service_name: Recorded on every event and used as the suffix of
                the ``ai.audit.<service_name>`` logger name.
            tenant_id: Recorded on every event.
        """
        self.service_name = service_name
        self.tenant_id = tenant_id
        self.logger = logging.getLogger(f"ai.audit.{service_name}")

    def log_inference(
        self,
        user_id: str,
        user_type: str,
        model_provider: str,
        model_name: str,
        request: Dict[str, Any],
        response: Dict[str, Any],
        context_sources: Optional[list] = None,
        safety_filters: Optional[Dict[str, Any]] = None,
        system_decisions: Optional[Dict[str, Any]] = None,
        correlation_id: Optional[str] = None,
        ip_address: Optional[str] = None
    ) -> str:
        """Log a complete AI inference event with all compliance fields.

        NOTE(review): the leading parameters (in particular ``user_id``) and
        the names of the nested event sections ("request", "parameters",
        "response", "context") were reconstructed -- confirm against existing
        callers and downstream log consumers.

        Args:
            user_id: Identifier of the caller.
            user_type: One of service-account, arthur-managed, idp-managed.
            model_provider: Provider name; also used for the pricing lookup.
            model_name: Requested model; also used for the pricing lookup.
            request: Request payload. Keys read: ``messages``,
                ``system_prompt``, ``temperature``, ``max_tokens``, ``top_p``,
                ``token_count``.
            response: Response payload. Keys read: ``content``,
                ``finish_reason``, ``token_count``, ``model_name``,
                ``retrieval_time_ms``, ``latency_ms``, ``time_per_token_ms``.
            context_sources: Retrieved sources (RAG or tool use), if any.
            safety_filters: Safety filter results; a default is recorded when
                omitted or empty.
            system_decisions: Routing decisions; a default is recorded when
                omitted or empty.
            correlation_id: Existing trace ID; a UUID4 is generated if falsy.
            ip_address: Caller IP address, if known.

        Returns:
            The correlation_id for request tracing.
        """
        # Generate correlation ID if not provided
        if not correlation_id:
            correlation_id = str(uuid.uuid4())

        # High-precision timestamp
        timestamp = datetime.now(timezone.utc).isoformat()

        # Cost & Performance (from verified pricing data)
        input_cost = self._calculate_cost(
            model_provider, model_name,
            request.get("token_count", 0),
            is_input=True,
        )
        output_cost = self._calculate_cost(
            model_provider, model_name,
            response.get("token_count", 0),
            is_input=False,
        )

        audit_event = {
            "timestamp": timestamp,
            # Identity & Access Layer
            "event_id": str(uuid.uuid4()),
            "correlation_id": correlation_id,
            "tenant_id": self.tenant_id,
            "user_id": user_id,
            "user_type": user_type,  # service-account, arthur-managed, idp-managed
            "ip_address": ip_address,
            "service_name": self.service_name,
            "model_provider": model_provider,
            "model_name": model_name,
            "request": {
                "messages": request.get("messages", []),
                "system_prompt": request.get("system_prompt", ""),
                "parameters": {
                    "temperature": request.get("temperature", 0.7),
                    "max_tokens": request.get("max_tokens", 1000),
                    "top_p": request.get("top_p", 1.0)
                },
                "token_count": request.get("token_count", 0)
            },
            "response": {
                "content": response.get("content", ""),
                "finish_reason": response.get("finish_reason", "unknown"),
                "token_count": response.get("token_count", 0),
                # The provider may report a more specific model version.
                "model_name": response.get("model_name", model_name)
            },
            # Context Retrieval (if RAG or tool use)
            "context": {
                "sources": context_sources or [],
                "source_count": len(context_sources) if context_sources else 0,
                "retrieval_time_ms": response.get("retrieval_time_ms", 0)
            },
            # NOTE(review): the default may have carried more keys than
            # "content_filtered"; only this one was recoverable.
            "safety_filters": safety_filters or {
                "content_filtered": False,
            },
            "system_decisions": system_decisions or {
                "model_selected": model_name,
                "routing_reason": "default",
                "fallback_triggered": False,
            },
            "cost_metrics": {
                "input_cost": input_cost,
                "output_cost": output_cost,
                "total_cost": input_cost + output_cost,
                "latency_ms": response.get("latency_ms", 0),
                "time_per_token_ms": response.get("time_per_token_ms", 0)
            },
        }

        # Log with appropriate severity
        # NOTE(review): INFO is assumed here -- confirm the intended level.
        self.logger.info(
            "AI inference event: %s",
            correlation_id,
            extra={"audit_event": audit_event}
        )

        # Also emit as structured JSON for aggregation
        print(json.dumps(audit_event, indent=2))

        return correlation_id

    def _calculate_cost(
        self,
        provider: str,
        model: str,
        tokens: int,
        is_input: bool,
    ) -> float:
        """Calculate cost based on verified pricing data.

        Provider and model names are matched case-insensitively against the
        class pricing table.

        Args:
            provider: Model provider name.
            model: Model name.
            tokens: Number of tokens to price.
            is_input: True to use the input rate, False for the output rate.

        Returns:
            ``tokens / 1_000_000`` times the matching rate, or 0.0 when the
            provider/model pair has no pricing entry.
        """
        provider_key = provider.lower()
        model_key = model.lower()

        rates = self._PRICING.get(provider_key, {}).get(model_key)
        if rates is None:
            # Default to zero if pricing unknown
            return 0.0

        rate = rates["input" if is_input else "output"]
        return (tokens / 1_000_000) * rate
# Demo entry point: configures logging, builds an audit logger and records
# one simulated inference event.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    audit_logger = AIAuditLogger(
        service_name="ai-gateway-prod",
        tenant_id="tenant-acme-corp"
    )

    # Simulate an inference request
    correlation_id = audit_logger.log_inference(
        # NOTE(review): placeholder identity values -- the demo's original
        # user_id/user_type arguments were not recoverable; confirm.
        user_id="demo-user",
        user_type="service-account",
        model_provider="anthropic",
        model_name="claude-3-5-sonnet",
        request={
            "messages": [{"role": "user", "content": "Explain quantum computing"}],
            "system_prompt": "You are a helpful assistant.",
        },
        response={
            "content": "Quantum computing leverages quantum mechanical phenomena...",
            "model_name": "claude-3-5-sonnet-20241022",
            "time_per_token_ms": 9.77
        },
        context_sources=[
            {"type": "vector_db", "query": "quantum computing basics", "results": 3}
        ],
        safety_filters={
            "content_filtered": False,
        },
        system_decisions={
            "model_selected": "claude-3-5-sonnet",
            "routing_reason": "default",
            "fallback_triggered": False,
        },
        ip_address="203.0.113.42"
    )

    print(f"Logged event with correlation_id: {correlation_id}")