Déploiement de Systèmes RAG en Production
RAG prêt pour la production : architecture, mise à l'échelle, surveillance, gestion des erreurs et meilleures pratiques opérationnelles pour des déploiements fiables.
- Auteur
- Équipe de Recherche Ailog
- Date de publication
- Temps de lecture
- 14 min de lecture
- Niveau
- advanced
Production vs Prototype
| Aspect | Prototype | Production | |--------|-----------|----------| | Disponibilité | Meilleur effort | SLA 99.9%+ | | Latence | Variable | p95 < 2s | | Gestion des erreurs | Plantages | Dégradation gracieuse | | Surveillance | Aucune | Complète | | Coût | Peu importe | Optimisé | | Échelle | Utilisateur unique | Milliers concurrents | | Données | Échantillon | Corpus complet |
Architecture de Production
Architecture de Base
`` ┌─────────┐ ┌──────────┐ ┌──────────┐ ┌─────┐ │ Requête │────>│ API │────>│ Base de │────>│ LLM │ │ Util. │ │ Gateway │ │ Données │ │ │ └─────────┘ └──────────┘ └──────────┘ └─────┘ │ │ │ v v v ┌──────────┐ ┌──────────┐ ┌──────────┐ │ Couche │ │ Service │ │ Surveillance│ │ Cache │ │ Embedding│ │ │ └──────────┘ └──────────┘ └──────────┘ `
Composants
API Gateway • Limitation de débit • Authentification • Routage des requêtes • Équilibrage de charge
Base de Données Vectorielle • Gérée (Pinecone) ou auto-hébergée (Qdrant) • Répliquée pour HA • Sauvegardes configurées
Service d'Embedding • Service séparé pour les embeddings • Traitement par lots • Mise en cache du modèle
Service LLM • Fournisseur API (OpenAI) ou auto-hébergé • Fournisseurs de secours • Streaming des réponses
Couche de Cache • Redis ou Memcached • Cache des requêtes fréquentes • Cache des embeddings
Surveillance • Prometheus + Grafana • Suivi des erreurs (Sentry) • Journalisation (stack ELK)
Stratégies de Mise à l'Échelle
Mise à l'Échelle Verticale
Base de Données Vectorielle • Plus de CPU : Recherche plus rapide • Plus de RAM : Index plus grands en mémoire • GPU : Recherche de similarité accélérée
Serveurs API • Plus de CPU : Gérer plus de requêtes concurrentes • Plus de RAM : Caches plus grands
Limites : • Maximum d'une seule machine • Coûteux à haut niveau • Pas de redondance
Mise à l'Échelle Horizontale
Équilibrage de Charge `python Plusieurs serveurs API derrière l'équilibreur de charge
Configuration NGINX upstream rag_api { least_conn; Router vers le serveur le moins occupé server api1.example.com:8000; server api2.example.com:8000; server api3.example.com:8000; }
server { listen 80;
location / { proxy_pass http://rag_api; proxy_set_header Host $host; } } `
Sharding de la Base de Données Vectorielle `python Shard par ID de document def get_shard(doc_id, num_shards=4): return hash(doc_id) % num_shards
Interroger tous les shards, fusionner les résultats async def distributed_search(query_vector, k=5): tasks = [ search_shard(shard_id, query_vector, k=k) for shard_id in range(num_shards) ]
all_results = await asyncio.gather(tasks)
Fusionner et reclasser merged = merge_results(all_results) return merged[:k] `
Réplicas de Lecture `python Écrire sur le primaire, lire depuis les réplicas class VectorDBCluster: def __init__(self, primary, replicas): self.primary = primary self.replicas = replicas self.replica_index = 0
def write(self, vector): Toutes les écritures vont au primaire self.primary.upsert(vector)
def search(self, query_vector, k=5): Lire depuis le replica (round-robin) replica = self.replicas[self.replica_index] self.replica_index = (self.replica_index + 1) % len(self.replicas)
return replica.search(query_vector, k=k) `
Auto-Scaling
`yaml Kubernetes HPA (Horizontal Pod Autoscaler) apiVersion: autoscaling/v2 kind: HorizontalPodAutoscaler metadata: name: rag-api-hpa spec: scaleTargetRef: apiVersion: apps/v1 kind: Deployment name: rag-api minReplicas: 2 maxReplicas: 10 metrics: • type: Resource resource: name: cpu target: type: Utilization averageUtilization: 70 • type: Resource resource: name: memory target: type: Utilization averageUtilization: 80 `
Mise en Cache
Cache de Requêtes
`python import redis import hashlib import json
class QueryCache: def __init__(self, redis_client, ttl=3600): self.redis = redis_client self.ttl = ttl
def get_cache_key(self, query, k): Hacher la requête pour la clé de cache query_hash = hashlib.md5(f"{query}:{k}".encode()).hexdigest() return f"rag:query:{query_hash}"
def get(self, query, k): key = self.get_cache_key(query, k) cached = self.redis.get(key)
if cached: return json.loads(cached) return None
def set(self, query, k, result): key = self.get_cache_key(query, k) self.redis.setex(key, self.ttl, json.dumps(result))
Utilisation cache = QueryCache(redis.Redis(host='localhost'))
def rag_query(query, k=5): Vérifier le cache cached = cache.get(query, k) if cached: return cached
Exécuter le pipeline RAG result = execute_rag_pipeline(query, k)
Mettre en cache le résultat cache.set(query, k, result)
return result `
Cache d'Embeddings
`python class EmbeddingCache: def __init__(self, redis_client): self.redis = redis_client
def get_embedding(self, text): key = f"emb:{hashlib.md5(text.encode()).hexdigest()}"
Vérifier le cache cached = self.redis.get(key) if cached: return np.frombuffer(cached, dtype=np.float32)
Générer l'embedding embedding = embed_model.encode(text)
Mettre en cache (pas de TTL - les embeddings ne changent pas) self.redis.set(key, embedding.tobytes())
return embedding `
Invalidation du Cache
`python def update_document(doc_id, new_content): Mettre à jour la base de données vectorielle embedding = embed(new_content) vector_db.upsert(doc_id, embedding, metadata={'content': new_content})
Invalider les entrées de cache associées invalidate_cache_for_document(doc_id)
def invalidate_cache_for_document(doc_id): Trouver toutes les requêtes en cache qui ont récupéré ce document pattern = f"rag:query:" for key in redis.scan_iter(match=pattern): cached_result = json.loads(redis.get(key))
Si ce document était dans les résultats, invalider if any(doc['id'] == doc_id for doc in cached_result.get('documents', [])): redis.delete(key) `
Gestion des Erreurs
Dégradation Gracieuse
`python class RobustRAG: def __init__(self, primary_llm, fallback_llm, vector_db): self.primary_llm = primary_llm self.fallback_llm = fallback_llm self.vector_db = vector_db
async def query(self, user_query, k=5): try: Essayer la récupération primaire contexts = await self.vector_db.search(user_query, k=k) except Exception as e: logger.error(f"Vector DB error: {e}")
Secours : retourner des contextes vides contexts = [] Ou : utiliser une recherche par mots-clés de secours contexts = await self.keyword_search_fallback(user_query)
try: Essayer le LLM primaire answer = await self.primary_llm.generate( query=user_query, contexts=contexts ) except Exception as e: logger.error(f"Primary LLM error: {e}")
try: Basculer vers le LLM secondaire answer = await self.fallback_llm.generate( query=user_query, contexts=contexts ) except Exception as e2: logger.error(f"Fallback LLM error: {e2}")
Secours ultime answer = "I'm experiencing technical difficulties. Please try again later."
return { 'answer': answer, 'contexts': contexts, 'fallback_used': False Suivre pour la surveillance } `
Logique de Retry
`python from tenacity import retry, stop_after_attempt, wait_exponential
@retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10), reraise=True ) async def resilient_llm_call(llm, prompt): """ Retry avec backoff exponentiel • Tentative 1 : immédiat • Tentative 2 : attendre 1s • Tentative 3 : attendre 2s """ return await llm.generate(prompt) `
Circuit Breaker
`python class CircuitBreaker: def __init__(self, failure_threshold=5, timeout=60): self.failure_count = 0 self.failure_threshold = failure_threshold self.timeout = timeout self.state = 'CLOSED' CLOSED, OPEN, HALF_OPEN self.last_failure_time = None
async def call(self, func, args, kwargs): if self.state == 'OPEN': if time.time() - self.last_failure_time > self.timeout: self.state = 'HALF_OPEN' else: raise Exception("Circuit breaker is OPEN")
try: result = await func(args, *kwargs)
Succès : réinitialiser if self.state == 'HALF_OPEN': self.state = 'CLOSED' self.failure_count = 0
return result
except Exception as e: self.failure_count += 1 self.last_failure_time = time.time()
if self.failure_count >= self.failure_threshold: self.state = 'OPEN'
raise e
Utilisation llm_breaker = CircuitBreaker(failure_threshold=5, timeout=60)
async def safe_llm_call(prompt): return await llm_breaker.call(llm.generate, prompt) `
Surveillance
Métriques
`python from prometheus_client import Counter, Histogram, Gauge
Métriques de requête query_counter = Counter('rag_queries_total', 'Total RAG queries') query_duration = Histogram('rag_query_duration_seconds', 'Query duration') error_counter = Counter('rag_errors_total', 'Total errors', ['type'])
Métriques de composants retrieval_latency = Histogram('rag_retrieval_latency_seconds', 'Retrieval latency') llm_latency = Histogram('rag_llm_latency_seconds', 'LLM latency')
Métriques de qualité avg_precision = Gauge('rag_precision_at_5', 'Average precision@5') thumbs_up_rate = Gauge('rag_thumbs_up_rate', 'Thumbs up rate')
Utilisation @query_duration.time() async def handle_query(query): query_counter.inc()
try: Récupération with retrieval_latency.time(): contexts = await retrieve(query)
Génération with llm_latency.time(): answer = await generate(query, contexts)
return answer
except Exception as e: error_counter.labels(type=type(e).__name__).inc() raise `
Journalisation
`python import logging import json
Journalisation structurée class JSONFormatter(logging.Formatter): def format(self, record): log_data = { 'timestamp': self.formatTime(record), 'level': record.levelname, 'message': record.getMessage(), 'module': record.module, }
if hasattr(record, 'query'): log_data['query'] = record.query if hasattr(record, 'latency'): log_data['latency'] = record.latency
return json.dumps(log_data)
Configurer le logger logger = logging.getLogger('rag') handler = logging.StreamHandler() handler.setFormatter(JSONFormatter()) logger.addHandler(handler) logger.setLevel(logging.INFO)
Utilisation logger.info('Query processed', extra={ 'query': user_query, 'latency': latency_ms, 'num_contexts': len(contexts) }) `
Alertes
`yaml Règles d'alerte Prometheus groups: • name: rag_alerts rules: • alert: HighErrorRate expr: rate(rag_errors_total[5m]) > 0.1 for: 5m annotations: summary: "High error rate in RAG system" description: "Error rate is {{ $value }} errors/sec" • alert: HighLatency expr: histogram_quantile(0.95, rate(rag_query_duration_seconds_bucket[5m])) > 5 for: 10m annotations: summary: "High query latency" description: "p95 latency is {{ $value }}s" • alert: LowPrecision expr: rag_precision_at_5 < 0.6 for: 30m annotations: summary: "Retrieval precision degraded" description: "Precision@5 is {{ $value }}" `
Pipeline de Données
Ingestion de Documents
`python class DocumentIngestionPipeline: def __init__(self, vector_db, embedding_service): self.vector_db = vector_db self.embedding_service = embedding_service
async def ingest_document(self, document): try: Extraire le texte text = extract_text(document) Découper chunks = chunk_document(text, chunk_size=512, overlap=50) Créer les embeddings (par lots) embeddings = await self.embedding_service.embed_batch(chunks) Télécharger vers la base de données vectorielle (par lots) await self.vector_db.upsert_batch( ids=[f"{document.id}_{i}" for i in range(len(chunks))], embeddings=embeddings, metadatas=[{ 'doc_id': document.id, 'chunk_index': i, 'content': chunk } for i, chunk in enumerate(chunks)] )
logger.info(f"Ingested document {document.id}: {len(chunks)} chunks")
except Exception as e: logger.error(f"Failed to ingest document {document.id}: {e}") raise
Job en arrière-plan async def batch_ingestion(): pipeline = DocumentIngestionPipeline(vector_db, embedding_service)
Obtenir les nouveaux documents new_docs = await get_new_documents()
Traiter en parallèle tasks = [pipeline.ingest_document(doc) for doc in new_docs] await asyncio.gather(tasks, return_exceptions=True) `
Mises à Jour Incrémentales
`python async def update_document(doc_id, new_content): Supprimer les anciens chunks await vector_db.delete(filter={'doc_id': doc_id})
Ingérer la nouvelle version await ingestion_pipeline.ingest_document({ 'id': doc_id, 'content': new_content })
Invalider les caches invalidate_cache_for_document(doc_id) `
Optimisation des Coûts
Embeddings
`python Embedding par lots pour un meilleur débit/coût async def cost_optimized_embedding(texts, batch_size=100): embeddings = []
for i in range(0, len(texts), batch_size): batch = texts[i:i+batch_size]
Appel API unique pour le lot batch_embeddings = await embedding_api.embed_batch(batch) embeddings.extend(batch_embeddings)
Limitation de débit await asyncio.sleep(0.1)
return embeddings `
Appels LLM
`python Réduire l'utilisation de tokens def optimize_context(query, chunks, max_tokens=2000): """ Ajuster autant de contexte pertinent que possible dans le budget de tokens """ selected_chunks = [] total_tokens = 0
for chunk in chunks: chunk_tokens = count_tokens(chunk)
if total_tokens + chunk_tokens <= max_tokens: selected_chunks.append(chunk) total_tokens += chunk_tokens else: break
return selected_chunks `
Stratégie de Cache • Mettre en cache les embeddings indéfiniment (déterministe) • Mettre en cache les requêtes fréquentes (TTL 1 heure) • Mettre en cache les patterns de requêtes populaires (TTL 24 heures) • Invalider lors des mises à jour de contenu
Sécurité
Authentification API
`python from fastapi import Depends, HTTPException, Security from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
security = HTTPBearer()
async def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)): token = credentials.credentials
Vérifier le JWT try: payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"]) return payload except jwt.ExpiredSignatureError: raise HTTPException(status_code=401, detail="Token expired") except jwt.InvalidTokenError: raise HTTPException(status_code=401, detail="Invalid token")
@app.post("/query") async def query(request: QueryRequest, user = Depends(verify_token)): Traiter la requête return await rag_system.query(request.query) `
Limitation de Débit
`python from slowapi import Limiter from slowapi.util import get_remote_address
limiter = Limiter(key_func=get_remote_address)
@app.post("/query") @limiter.limit("100/hour") async def query(request: QueryRequest): return await rag_system.query(request.query) `
Sanitisation des Entrées
`python def sanitize_query(query: str) -> str: Limiter la longueur max_length = 1000 query = query[:max_length]
Supprimer les caractères potentiellement dangereux query = re.sub(r'[^\w\s\?.,!-]', '', query)
Prévenir l'injection de prompt injection_patterns = [ r'ignore previous instructions', r'disregard above', r'system:', ]
for pattern in injection_patterns: if re.search(pattern, query, re.IGNORECASE): raise ValueError("Potentially harmful query detected")
return query ``
Checklist de Déploiement • [ ] Base de données vectorielle : Répliquée, sauvegardée • [ ] API : Auto-scaling, health checks • [ ] Mise en cache : Cluster Redis, politique d'éviction • [ ] Surveillance : Métriques, logs, alertes • [ ] Gestion des erreurs : Retries, secours, circuit breakers • [ ] Sécurité : Authentification, limitation de débit, validation des entrées • [ ] Documentation : Docs API, runbooks • [ ] Tests : Tests de charge, tests de chaos • [ ] Plan de rollback : Déploiement blue-green • [ ] Astreinte : Runbooks, procédures d'escalade
Prochaines Étapes
L'optimisation du traitement des requêtes et la gestion efficace des fenêtres de contexte sont critiques pour le coût et la qualité. Les guides suivants couvrent les techniques d'optimisation de requêtes et les stratégies de gestion de fenêtre de contexte.