コンテンツにスキップ

RAGパイプライン実装ガイド: SageMakerでの構築パターンと最適化

この記事はRAGパイプラインの実装に焦点を絞った実践ガイドです

SageMakerでのスケーラブルなRAG実装をマスターしましょう

ゴール

  • チャンキング戦略の定量比較実装
  • エンベディングモデルの性能測定
  • リトリーバル精度の自動評価

アーキテクチャ概要

RAGパイプラインの基本構成は以下の3段階で構築します。

graph LR
    A[Document Input] --> B[Chunking]
    B --> C[Embedding]
    C --> D[Vector Store]
    D --> E[Retrieval]
    E --> F[LLM Generation]

RAG方式選定マトリクス

RAGシステムの方式選定は、ユースケースの要件によって大きく異なります。以下のマトリクスを参考に、最適なアプローチを選択してください(AWS Prescriptive Guidance 2026を参照)。

方式Recall@kFaithfulnessレイテンシコスト推奨ユースケース
Vector検索のみ中 (0.72)低 (<150ms)FAQ、単純なQ&A
Hybrid検索 (Vector + BM25)高 (0.85)中〜高中 (150-300ms)社内文書検索、技術ドキュメント
Hybrid + Re-ranking高 (0.91)中〜高 (300-500ms)中〜高法務・コンプライアンス文書
Hybrid + Re-ranking + Semantic Cache高 (0.91)低 (キャッシュヒット時<50ms)高頻度アクセスのカスタマーサポート
Streaming RAG中〜高 (0.83)体感低 (TTFT<200ms)リアルタイムチャット、対話型UI

方式選定の判断基準

  • 精度優先: Hybrid + Re-ranking(Recall@k > 0.90を目指す場合)
  • レイテンシ優先: Vector検索のみ、またはSemantic Cache併用
  • コスト優先: Vector検索のみ(エンベディング計算のみで完結)
  • バランス型: Hybrid検索(多くのプロダクション環境で推奨)

実装ステップ

ステップ1: チャンキングは目的関数で選ぶ

チャンキング戦略は単一の手法を選ぶのではなく、目的関数に応じて最適な手法を選定します。以下のトレードオフを理解した上で選択してください。

チャンキング手法Recall@kFaithfulnessレイテンシコスト適用場面
固定サイズ (512トークン)0.720.68プロトタイプ、大量文書の初期処理
適応型 (200-800トークン)0.840.79汎用的な文書検索
セマンティック分割0.890.88精度重視の専門文書
親子チャンク (Parent-Child)0.870.91中〜高中〜高コンテキスト保持が重要な長文書
文単位 + オーバーラップ0.800.82法務文書、正確な引用が必要な場面

チャンキング選定の注意点

Recall@kが高くてもFaithfulnessが低い場合、LLMが検索結果を正確に反映しない応答を生成するリスクがあります。両方の指標をバランスよく評価してください。

from typing import List, Dict
import tiktoken

def adaptive_chunking(text: str,
                      min_size: int = 200,
                      max_size: int = 800) -> List[str]:
    encoder = tiktoken.get_encoding("cl100k_base")
    tokens = encoder.encode(text)

    chunks = []
    current = []
    for token in tokens:
        current.append(token)
        if len(current) >= min_size:
            if len(current) >= max_size:
                chunks.append(encoder.decode(current))
                current = []
    return chunks

ステップ2: エンベディングモデルの選定

import boto3
from sagemaker.huggingface import HuggingFaceModel

def deploy_embedding_model(model_id: str = "BAAI/bge-small-en-v1.5"):
    role = "arn:aws:iam::xxx:role/SageMakerRole"

    huggingface_model = HuggingFaceModel(
        model_data=f"s3://models/{model_id}.tar.gz",
        role=role,
        transformers_version="4.37",
        pytorch_version="2.1",
        py_version="py310"
    )

    predictor = huggingface_model.deploy(
        initial_instance_count=1,
        instance_type="ml.g5.2xlarge"
    )
    return predictor

ステップ3: リトリーバル最適化

def hybrid_retrieval(query: str,
                    k: int = 5,
                    alpha: float = 0.7) -> List[Dict]:
    # Semantic search
    semantic_results = vector_store.similarity_search(
        query, k=k*2
    )

    # Keyword search
    keyword_results = bm25_search(query, k=k*2)

    # Hybrid scoring
    combined = {}
    for doc in semantic_results:
        combined[doc.id] = alpha * doc.score
    for doc in keyword_results:
        if doc.id in combined:
            combined[doc.id] += (1-alpha) * doc.score
        else:
            combined[doc.id] = (1-alpha) * doc.score

    return sorted(combined.items(),
                 key=lambda x: x[1],
                 reverse=True)[:k]

ベンチマーク結果

チャンキング戦略平均レイテンシRecall@5FaithfulnessMRRNDCG@5コスト/1000クエリ
固定サイズ(512)120ms0.720.680.650.61$0.45
適応型(200-800)135ms0.840.790.780.74$0.52
セマンティック分割180ms0.890.880.850.82$0.68
親子チャンク195ms0.870.910.830.80$0.71

評価指標の説明

  • Recall@k: 上位k件の検索結果に正解文書が含まれる割合
  • Faithfulness: LLMの応答が検索結果に忠実である度合い
  • MRR (Mean Reciprocal Rank): 最初の正解文書の順位の逆数の平均
  • NDCG@k: 検索結果の順位を考慮した関連度の評価指標

失敗パターンと回避策

症状原因回避策
リトリーバル精度低下チャンクサイズ過小min_size=200以上に設定
タイムアウト頻発インスタンスタイプ不足ml.g5.2xlarge以上を使用
コスト超過全文エンベディングインクリメンタル更新実装

自動化・拡張案

  • GitHub Actionsでのパイプライン自動評価
  • A/Bテストによるチャンキング戦略の継続的最適化
  • CloudWatchメトリクスによるリアルタイム監視
  • Step Functionsでのワークフロー管理

評価パイプライン

RAGシステムの品質を継続的に担保するためには、定量的な評価パイプラインの構築が不可欠です。

リトリーバル評価指標

from typing import List, Dict
import numpy as np


def calculate_recall_at_k(retrieved_ids: List[str],
                          relevant_ids: List[str],
                          k: int = 5) -> float:
    """Recall@k: 上位k件に正解文書が含まれる割合"""
    retrieved_top_k = set(retrieved_ids[:k])
    relevant_set = set(relevant_ids)
    if not relevant_set:
        return 0.0
    return len(retrieved_top_k & relevant_set) / len(relevant_set)


def calculate_mrr(retrieved_ids: List[str],
                  relevant_ids: List[str]) -> float:
    """MRR: 最初の正解文書の順位の逆数"""
    relevant_set = set(relevant_ids)
    for i, doc_id in enumerate(retrieved_ids):
        if doc_id in relevant_set:
            return 1.0 / (i + 1)
    return 0.0


def calculate_ndcg_at_k(retrieved_ids: List[str],
                        relevance_scores: Dict[str, float],
                        k: int = 5) -> float:
    """NDCG@k: 順位を考慮した関連度評価"""
    dcg = 0.0
    for i, doc_id in enumerate(retrieved_ids[:k]):
        rel = relevance_scores.get(doc_id, 0.0)
        dcg += rel / np.log2(i + 2)

    ideal_scores = sorted(relevance_scores.values(), reverse=True)[:k]
    idcg = sum(rel / np.log2(i + 2) for i, rel in enumerate(ideal_scores))

    return dcg / idcg if idcg > 0 else 0.0

Faithfulness評価

def evaluate_faithfulness(response: str,
                         retrieved_contexts: List[str],
                         llm_client) -> float:
    """LLM応答が検索結果に忠実かどうかを評価"""
    prompt = f"""以下の応答が、提供されたコンテキストの情報のみに基づいているか評価してください。

コンテキスト:
{chr(10).join(retrieved_contexts)}

応答:
{response}

0.0(完全にコンテキスト外)から1.0(完全にコンテキストに忠実)のスコアで評価し、
数値のみを返してください。"""

    score = llm_client.invoke(prompt)
    return float(score.strip())


def run_evaluation_suite(test_queries: List[Dict],
                         retriever,
                         generator,
                         llm_evaluator) -> Dict:
    """評価スイートの一括実行"""
    results = {
        "recall_at_5": [],
        "mrr": [],
        "ndcg_at_5": [],
        "faithfulness": [],
    }

    for query_data in test_queries:
        query = query_data["query"]
        relevant_ids = query_data["relevant_doc_ids"]

        retrieved = retriever.search(query, k=5)
        retrieved_ids = [doc.id for doc in retrieved]
        contexts = [doc.content for doc in retrieved]
        response = generator.generate(query, contexts)

        results["recall_at_5"].append(
            calculate_recall_at_k(retrieved_ids, relevant_ids, k=5)
        )
        results["mrr"].append(
            calculate_mrr(retrieved_ids, relevant_ids)
        )
        results["faithfulness"].append(
            evaluate_faithfulness(response, contexts, llm_evaluator)
        )

    return {k: np.mean(v) for k, v in results.items()}

ガードレール実装

プロダクション環境のRAGシステムには、セキュリティとコンプライアンスのためのガードレールが必須です。

PII(個人情報)フィルタリング

import re
from typing import Tuple


def filter_pii(text: str) -> Tuple[str, List[Dict]]:
    """入出力テキストからPIIを検出・マスキング"""
    pii_patterns = {
        "email": r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
        "phone_jp": r"0\d{1,4}-\d{1,4}-\d{4}",
        "my_number": r"\d{4}\s?\d{4}\s?\d{4}",
        "credit_card": r"\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}",
    }

    detected = []
    masked_text = text
    for pii_type, pattern in pii_patterns.items():
        matches = re.finditer(pattern, masked_text)
        for match in matches:
            detected.append({
                "type": pii_type,
                "position": match.span(),
                "masked": True,
            })
            masked_text = masked_text.replace(
                match.group(), f"[{pii_type.upper()}_REDACTED]"
            )

    return masked_text, detected

プロンプトインジェクション防御

def detect_prompt_injection(user_input: str) -> Dict:
    """プロンプトインジェクション攻撃の検出"""
    injection_patterns = [
        r"ignore\s+(previous|above|all)\s+(instructions?|prompts?)",
        r"system\s*prompt",
        r"you\s+are\s+now",
        r"pretend\s+(to\s+be|you\s+are)",
        r"jailbreak",
        r"DAN\s+mode",
    ]

    risk_score = 0.0
    matched_patterns = []

    for pattern in injection_patterns:
        if re.search(pattern, user_input, re.IGNORECASE):
            risk_score += 0.3
            matched_patterns.append(pattern)

    # 長い入力や特殊文字の多用も警戒
    if len(user_input) > 2000:
        risk_score += 0.1
    if user_input.count("```") > 4:
        risk_score += 0.1

    return {
        "risk_score": min(risk_score, 1.0),
        "is_blocked": risk_score >= 0.5,
        "matched_patterns": matched_patterns,
    }

データ境界の強制

def enforce_data_boundary(query: str,
                          retrieved_docs: List[Dict],
                          user_permissions: Dict) -> List[Dict]:
    """ユーザー権限に基づくデータアクセス境界の強制"""
    allowed_docs = []

    for doc in retrieved_docs:
        doc_classification = doc.get("classification", "public")
        doc_department = doc.get("department", "general")

        # 分類レベルのチェック
        if doc_classification == "confidential":
            if "confidential" not in user_permissions.get("access_levels", []):
                continue

        # 部門アクセスのチェック
        if doc_department != "general":
            if doc_department not in user_permissions.get("departments", []):
                continue

        allowed_docs.append(doc)

    if not allowed_docs:
        allowed_docs = [{"content": "アクセス可能な関連文書が見つかりませんでした。",
                         "source": "system"}]

    return allowed_docs


def build_safe_prompt(query: str,
                      contexts: List[str],
                      system_boundary: str = "") -> str:
    """安全なプロンプトの構築"""
    boundary_instruction = system_boundary or (
        "提供されたコンテキストの情報のみに基づいて回答してください。"
        "コンテキストに含まれない情報については「情報が見つかりませんでした」と回答してください。"
        "推測や外部知識による補完は行わないでください。"
    )

    return f"""[SYSTEM] {boundary_instruction}

[CONTEXT]
{chr(10).join(contexts)}

[USER QUERY]
{query}"""

ガードレールの重要性

ガードレールなしのRAGシステムをプロダクションに展開すると、PII漏洩、プロンプトインジェクション攻撃、権限外データへのアクセスなどのリスクがあります。必ずデプロイ前にこれらの防御層を実装してください。

次のステップ

このRAGパイプライン実装を基盤として、さらなる高度化と本格的なプロダクション運用を進めてください。

  • 評価の継続的実行: CI/CDパイプラインに評価スイートを組み込み、Recall@kやFaithfulnessの回帰を検出
  • ガードレールの強化: AWS Bedrock Guardrailsとの統合による多層防御
  • A/Bテスト: チャンキング戦略やリトリーバル方式の本番環境での比較検証
  • 監視ダッシュボード: CloudWatchメトリクスで評価指標をリアルタイム可視化