import hashlib
import re
from dataclasses import dataclass
from datetime import datetime
from typing import Callable, Dict, List, Optional

import numpy as np
from sentence_transformers import SentenceTransformer


@dataclass
class DriftReport:
    """Comprehensive drift analysis result"""
    is_drift: bool
    confidence: float
    detected_at: datetime
    violation_type: str
    details: Dict
class SemanticDriftDetector:
    """
    Production-grade semantic drift detection for RAG systems.

    Implements three-layer validation: retrieval consistency,
    factual accuracy, and historical pattern analysis.
    """

    def __init__(
        self,
        embedding_model: str = 'all-MiniLM-L6-v2',
        retrieval_threshold: float = 0.3,
        historical_threshold: float = 0.95,
        materiality_threshold: float = 0.05
    ):
        """
        Initialize drift detector with configurable thresholds.

        Args:
            embedding_model: Sentence transformer model for similarity
            retrieval_threshold: Minimum pairwise similarity for retrieved docs (higher = stricter)
            historical_threshold: Minimum similarity to historical responses
            materiality_threshold: Financial materiality tolerance (5% is a common rule of thumb)
        """
        self.model = SentenceTransformer(embedding_model)
        self.retrieval_threshold = retrieval_threshold
        self.historical_threshold = historical_threshold
        self.materiality_threshold = materiality_threshold
        self.response_history: Dict[str, List[str]] = {}  # query_hash -> [responses]
    def check_retrieval_consistency(self, retrieved_docs: List[str]) -> DriftReport:
        """
        Layer 1: Validate that retrieved documents are mutually consistent.
        Flags drift when any document pair falls below the similarity threshold.
        """
        if len(retrieved_docs) < 2:
            return DriftReport(
                is_drift=False,
                confidence=1.0,
                detected_at=datetime.now(),
                violation_type="insufficient_docs",
                details={"count": len(retrieved_docs)}
            )

        # Normalize embeddings so the dot product equals cosine similarity
        embeddings = self.model.encode(retrieved_docs, normalize_embeddings=True)
        similarities = np.dot(embeddings, embeddings.T)

        # Extract upper triangle (pairwise similarities, excluding self)
        triu_indices = np.triu_indices_from(similarities, k=1)
        pairwise_sims = similarities[triu_indices]

        min_similarity = np.min(pairwise_sims)
        avg_similarity = np.mean(pairwise_sims)

        is_consistent = min_similarity > self.retrieval_threshold

        return DriftReport(
            is_drift=not is_consistent,
            confidence=float(min_similarity),
            detected_at=datetime.now(),
            violation_type="retrieval_contradiction",
            details={
                "min_similarity": float(min_similarity),
                "avg_similarity": float(avg_similarity),
                "threshold": self.retrieval_threshold,
                "doc_pairs_checked": len(pairwise_sims)
            }
        )
    def extract_factual_claims(self, response: str) -> List[Dict]:
        """
        Extract numeric claims and entities from response for validation.
        In production, integrate with NER or LLM-based extraction.
        """
        # Simplified example - production would use spaCy, an LLM, or richer regex patterns
        claims = []

        # Extract numeric patterns (currency amounts and percentages)
        numeric_patterns = [
            (r'\$?(\d+(?:\.\d+)?)(?:M|million|k|thousand)?', 'currency'),
            (r'(\d+(?:\.\d+)?)\s*%', 'percentage'),
        ]

        for pattern, claim_type in numeric_patterns:
            matches = re.finditer(pattern, response)
            for match in matches:
                claims.append({
                    'type': claim_type,
                    'text': match.group(0),
                    'value': float(match.group(1)),
                    'position': match.start()
                })

        return claims
    def validate_factual_consistency(
        self,
        response: str,
        ground_truth_source: Callable
    ) -> DriftReport:
        """
        Layer 2: Cross-reference extracted facts against ground truth.
        The ground truth source should be a function that returns authoritative values.
        """
        claims = self.extract_factual_claims(response)

        if not claims:
            return DriftReport(
                is_drift=False,
                confidence=1.0,
                detected_at=datetime.now(),
                violation_type="no_claims",
                details={"response_length": len(response)}
            )

        inconsistencies = []
        for claim in claims:
            if claim['type'] in ['currency', 'percentage']:
                try:
                    actual_value = ground_truth_source(claim['type'], claim['text'])
                    if actual_value is not None:
                        deviation = abs(claim['value'] - actual_value) / actual_value
                        if deviation > self.materiality_threshold:
                            inconsistencies.append({
                                'claim': claim,
                                'expected': actual_value,
                                'deviation': deviation
                            })
                except Exception as e:
                    # Log but don't fail on validation errors
                    print(f"Validation error for {claim}: {e}")

        is_consistent = len(inconsistencies) == 0

        return DriftReport(
            is_drift=not is_consistent,
            confidence=float(1.0 - (len(inconsistencies) / len(claims))),
            detected_at=datetime.now(),
            violation_type="factual_inconsistency",
            details={
                "total_claims": len(claims),
                "inconsistencies": len(inconsistencies),
                "materiality_threshold": self.materiality_threshold,
                "violation_examples": inconsistencies[:3]  # First 3 for brevity
            }
        )
    def check_historical_drift(
        self,
        query: str,
        new_response: str,
        top_k: int = 5
    ) -> DriftReport:
        """
        Layer 3: Compare the new response to historical responses for the same query.
        Detects gradual drift patterns over time.
        """
        query_hash = hashlib.sha256(query.encode()).hexdigest()

        # Get historical responses for this query
        historical = self.response_history.get(query_hash, [])

        if not historical:
            # No baseline yet - store and return no drift
            self._store_response(query_hash, new_response)
            return DriftReport(
                is_drift=False,
                confidence=1.0,
                detected_at=datetime.now(),
                violation_type="no_baseline",
                details={"action": "baseline_created"}
            )

        new_embedding = self.model.encode(new_response)

        # Compare with historical responses using cosine similarity
        similarities = []
        for hist_response in historical[-top_k:]:  # Last N responses
            hist_embedding = self.model.encode(hist_response)
            sim = np.dot(new_embedding, hist_embedding) / (
                np.linalg.norm(new_embedding) * np.linalg.norm(hist_embedding)
            )
            similarities.append(float(sim))

        max_similarity = max(similarities) if similarities else 0.0
        avg_similarity = np.mean(similarities) if similarities else 0.0

        is_drift = max_similarity < self.historical_threshold

        # Store new response for future comparisons
        self._store_response(query_hash, new_response)

        return DriftReport(
            is_drift=is_drift,
            confidence=float(max_similarity),
            detected_at=datetime.now(),
            violation_type="historical_drift",
            details={
                "max_similarity": float(max_similarity),
                "avg_similarity": float(avg_similarity),
                "threshold": self.historical_threshold,
                "historical_count": len(historical),
                "comparisons_made": len(similarities)
            }
        )
    def _store_response(self, query_hash: str, response: str):
        """Store response in history for future drift detection"""
        if query_hash not in self.response_history:
            self.response_history[query_hash] = []
        self.response_history[query_hash].append(response)

        # Keep only last 50 responses per query to manage memory
        if len(self.response_history[query_hash]) > 50:
            self.response_history[query_hash] = self.response_history[query_hash][-50:]
    def run_full_detection(
        self,
        query: str,
        response: str,
        retrieved_docs: List[str],
        ground_truth_source: Optional[Callable] = None
    ) -> Dict[str, DriftReport]:
        """
        Execute the complete three-layer drift detection pipeline.
        Returns a dictionary of all drift reports.
        """
        reports: Dict[str, DriftReport] = {}

        # Layer 1: Retrieval consistency
        reports['retrieval'] = self.check_retrieval_consistency(retrieved_docs)

        # Layer 2: Factual validation (if ground truth provided)
        if ground_truth_source is not None:
            reports['factual'] = self.validate_factual_consistency(
                response, ground_truth_source
            )

        # Layer 3: Historical drift
        reports['historical'] = self.check_historical_drift(query, response)

        return reports
# Production Integration Example
class ProductionRAGPipeline:
    """Complete RAG pipeline with integrated drift detection"""

    def __init__(self, llm_client, drift_detector: SemanticDriftDetector):
        self.llm_client = llm_client
        self.drift_detector = drift_detector
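

# Usage sketch (illustrative only): exercises the three-layer detector directly,
# outside the ProductionRAGPipeline wrapper. The sample query, response, documents,
# and the stub ground-truth lookup below are assumptions made for demonstration,
# not part of the production integration.
if __name__ == "__main__":
    detector = SemanticDriftDetector()

    def demo_ground_truth(claim_type: str, claim_text: str) -> Optional[float]:
        # Hypothetical authoritative source; a real implementation would query
        # a financial data store keyed by metric and reporting period.
        known_values = {'percentage': 12.0}
        return known_values.get(claim_type)

    reports = detector.run_full_detection(
        query="What was Q3 revenue growth?",
        response="Q3 revenue grew 14.5% to $120M.",
        retrieved_docs=[
            "Q3 revenue increased 12% year over year.",
            "Third-quarter revenue growth was approximately 12%.",
        ],
        ground_truth_source=demo_ground_truth,
    )

    for layer, report in reports.items():
        print(f"{layer}: drift={report.is_drift}, confidence={report.confidence:.2f}")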