コンテンツにスキップ

推薦システムA/Bテスト実装 - 統計的厳密性とガードレール設計

この記事は朝の記事のフォローアップです

基礎記事: Soraアプリ推薦アルゴリズム実装

対象読者: 推薦システムの実装経験がある中級〜上級エンジニア

ゴール

  • サンプルサイズ計算と統計検定の実装方法を習得
  • ガードレール指標によるリスク管理フレームワークを理解
  • バケット分割とトラフィック管理の実践パターンを確認

アーキテクチャ概要

推薦システムのA/Bテストは、単なるランダム振り分けでは不十分だ。統計的検出力の確保、複数指標の同時監視、ユーザー体験の保護を両立させる必要がある。

graph TD
    A[実験設計] --> B[サンプルサイズ計算]
    B --> C[バケット分割]
    C --> D[トラフィック配分]
    D --> E[ガードレール監視]
    E --> F{異常検知}
    F -->|正常| G[データ収集継続]
    F -->|異常| H[自動停止]
    G --> I[統計検定]
    I --> J[意思決定]

実装ステップ

ステップ1: サンプルサイズ計算

検出力分析の実装:

import numpy as np
from scipy import stats

def calculate_sample_size(
    baseline_rate: float,
    mde: float,  # Minimum Detectable Effect
    alpha: float = 0.05,
    power: float = 0.80
) -> int:
    """
    二項分布指標のサンプルサイズ計算

    Args:
        baseline_rate: ベースライン転換率(例: 0.05 = 5%)
        mde: 検出したい最小効果(例: 0.01 = 1pp)
        alpha: 第一種過誤率(有意水準)
        power: 検出力(1 - 第二種過誤率)

    Returns:
        各群に必要なサンプル数
    """
    # Z値の計算
    z_alpha = stats.norm.ppf(1 - alpha / 2)
    z_beta = stats.norm.ppf(power)

    # 期待値
    p1 = baseline_rate
    p2 = baseline_rate + mde
    p_pooled = (p1 + p2) / 2

    # サンプルサイズ公式
    numerator = (z_alpha * np.sqrt(2 * p_pooled * (1 - p_pooled)) +
                 z_beta * np.sqrt(p1 * (1 - p1) + p2 * (1 - p2))) ** 2
    denominator = (p2 - p1) ** 2

    n = int(np.ceil(numerator / denominator))
    return n

# 実行例
n_per_group = calculate_sample_size(
    baseline_rate=0.05,  # 現在の創作率5%
    mde=0.01,            # 1pp改善を検出したい
    alpha=0.05,
    power=0.80
)
print(f"各群に必要なユーザー数: {n_per_group:,}")
# 出力例: 各群に必要なユーザー数: 6,194

ステップ2: バケット管理システム

ハッシュベース安定バケット:

import hashlib
from typing import Literal

class BucketManager:
    """
    ユーザーIDベースの安定したバケット割り当て
    """

    def __init__(self, experiment_id: str, num_buckets: int = 100):
        self.experiment_id = experiment_id
        self.num_buckets = num_buckets

    def assign_bucket(self, user_id: str) -> int:
        """
        ユーザーをバケットに割り当て(決定論的)
        """
        # ユーザーIDと実験IDを結合してハッシュ化
        hash_input = f"{self.experiment_id}:{user_id}"
        hash_value = hashlib.md5(hash_input.encode()).hexdigest()

        # 0-99のバケット番号に変換
        bucket = int(hash_value, 16) % self.num_buckets
        return bucket

    def get_variant(
        self,
        user_id: str,
        control_pct: float = 50.0,
        treatment_pct: float = 50.0
    ) -> Literal['control', 'treatment', 'holdout']:
        """
        バケット番号から実験群を決定

        Args:
            control_pct: コントロール群の割合(0-100)
            treatment_pct: トリートメント群の割合(0-100)
            残りはholdout群(分析対象外)
        """
        bucket = self.assign_bucket(user_id)

        if bucket < control_pct:
            return 'control'
        elif bucket < control_pct + treatment_pct:
            return 'treatment'
        else:
            return 'holdout'

# 使用例
manager = BucketManager(experiment_id="inspiration_weight_v1")
variant = manager.get_variant(
    user_id="user_12345",
    control_pct=45.0,
    treatment_pct=45.0
)
print(f"ユーザー割り当て: {variant}")

ステップ3: ガードレール監視システム

リアルタイム異常検知:

from dataclasses import dataclass
from typing import Dict, List
import pandas as pd

@dataclass
class GuardrailThreshold:
    """ガードレール閾値定義"""
    metric_name: str
    min_value: float = None
    max_value: float = None
    relative_change: float = None  # ベースラインからの相対変化率

class GuardrailMonitor:
    """
    ガードレール指標の監視と自動停止
    """

    def __init__(self, thresholds: List[GuardrailThreshold]):
        self.thresholds = {t.metric_name: t for t in thresholds}
        self.violations = []

    def check_metrics(
        self,
        treatment_metrics: Dict[str, float],
        control_metrics: Dict[str, float]
    ) -> bool:
        """
        ガードレール違反をチェック

        Returns:
            True: 実験継続OK, False: 停止必要
        """
        self.violations = []

        for metric_name, threshold in self.thresholds.items():
            treatment_value = treatment_metrics.get(metric_name)
            control_value = control_metrics.get(metric_name)

            if treatment_value is None:
                continue

            # 絶対値チェック
            if threshold.min_value and treatment_value < threshold.min_value:
                self.violations.append(
                    f"{metric_name}: {treatment_value:.4f} < min {threshold.min_value}"
                )

            if threshold.max_value and treatment_value > threshold.max_value:
                self.violations.append(
                    f"{metric_name}: {treatment_value:.4f} > max {threshold.max_value}"
                )

            # 相対変化チェック
            if threshold.relative_change and control_value:
                pct_change = (treatment_value - control_value) / control_value
                if abs(pct_change) > threshold.relative_change:
                    self.violations.append(
                        f"{metric_name}: {pct_change:.2%} change > threshold {threshold.relative_change:.2%}"
                    )

        return len(self.violations) == 0

# 設定例
guardrails = [
    GuardrailThreshold(
        metric_name='avg_session_duration',
        min_value=300,  # 最低5分
        relative_change=0.30  # ±30%以内
    ),
    GuardrailThreshold(
        metric_name='bounce_rate',
        max_value=0.70,  # 最大70%
        relative_change=0.20
    ),
    GuardrailThreshold(
        metric_name='crash_rate',
        max_value=0.01  # 最大1%
    )
]

monitor = GuardrailMonitor(guardrails)

ベンチマーク: 統計検定比較

検定手法計算速度小サンプル精度多重比較対応推奨用途
t検定⭐⭐⭐⭐⭐単一指標・正規分布
Mann-Whitney U⭐⭐⭐⭐⭐非正規分布
Bootstrap⭐⭐⭐複雑な指標
Sequential⭐⭐⭐⭐⭐早期停止必要

失敗パターンと回避策

症状原因回避策
p値が0.049→0.051を往復サンプルサイズ不足事前計算の厳守・早期判断禁止
ガードレール誤検知頻発閾値が厳しすぎる過去データで閾値を校正
バケット偏り発生ハッシュ関数の偏りMD5/SHA256の使用・分布検証
新規ユーザーが常にtx群実験IDが固定実験ごとにIDを変更

統計検定の実装

Sequential Testing(逐次検定):

from typing import Tuple

class SequentialTest:
    """
    Always Valid Inference (AVI)による逐次検定
    """

    def __init__(self, alpha: float = 0.05):
        self.alpha = alpha
        # Robbins-Siegmund境界の簡易実装
        self.boundary_constant = np.sqrt(-2 * np.log(alpha))

    def test(
        self,
        control_conversions: int,
        control_total: int,
        treatment_conversions: int,
        treatment_total: int
    ) -> Tuple[bool, float, str]:
        """
        逐次A/Bテスト実行

        Returns:
            (有意差あり, p値推定, 決定: 'continue'/'stop_treatment_wins'/'stop_no_effect')
        """
        # 比率の差の推定
        p_control = control_conversions / control_total
        p_treatment = treatment_conversions / treatment_total
        diff = p_treatment - p_control

        # 標準誤差
        se = np.sqrt(
            p_control * (1 - p_control) / control_total +
            p_treatment * (1 - p_treatment) / treatment_total
        )

        # Z統計量
        z_score = diff / se if se > 0 else 0

        # Sequential境界(サンプルサイズ依存)
        n_total = control_total + treatment_total
        boundary = self.boundary_constant / np.sqrt(n_total)

        # 決定
        if abs(z_score) > boundary:
            if z_score > 0:
                return True, self._estimate_p_value(z_score), 'stop_treatment_wins'
            else:
                return True, self._estimate_p_value(z_score), 'stop_control_wins'
        else:
            return False, None, 'continue'

    def _estimate_p_value(self, z_score: float) -> float:
        """Z統計量からp値を推定"""
        return 2 * (1 - stats.norm.cdf(abs(z_score)))

自動化・拡張パターン

  1. Multi-Armed Bandit統合: Thompson Samplingによる動的トラフィック配分
  2. 階層ベイズモデル: ユーザーセグメント別効果の同時推定
  3. 因果推論フレームワーク: 交絡因子の調整とATEの推定
  4. メタ分析パイプライン: 過去実験からの事前分布構築
  5. 自動レポート生成: 統計検定結果の可視化とSlack通知

次のステップ


実装時の注意点:

  • Sequential Testingは早期停止を可能にするが、従来のp値解釈は適用不可
  • ガードレール閾値は過去データの95パーセンタイルから設定を推奨
  • バケット分割はユーザー単位が基本、セッション単位は交絡リスク大

参考文献:

  • Johari et al. (2017): "Peeking at A/B Tests"
  • Netflix Tech Blog: "Building Confidence in A/B Testing"
  • Optimizely Stats Engine: Sequential Testing whitepaper