Learning Module

The learning module implements learning policies that improve routing, agent, and tool decisions based on historical interaction outcomes. It provides a LearningPolicy ABC taxonomy with specialized sub-ABCs for intelligence (model routing) and agent behavior (in-context examples and tool-use strategies), plus reward functions for scoring inference results.

Abstract Base Classes

RouterPolicy

RouterPolicy

Bases: ABC

Model selection policy (used by the learning system).

Functions

select_model abstractmethod

select_model(context: 'RoutingContext') -> str

Select the best model key for the given routing context.

Source code in src/openjarvis/learning/_stubs.py
@abstractmethod
def select_model(self, context: "RoutingContext") -> str:
    """Select the best model key for the given routing context."""

QueryAnalyzer

QueryAnalyzer

Bases: ABC

Query analysis for routing contexts.

Functions

analyze abstractmethod

analyze(query: str, **kwargs: object) -> 'RoutingContext'

Analyze a query and return a RoutingContext.

Source code in src/openjarvis/learning/_stubs.py
@abstractmethod
def analyze(self, query: str, **kwargs: object) -> "RoutingContext":
    """Analyze a query and return a RoutingContext."""

RoutingContext

RoutingContext is defined in core/types.py.

RewardFunction

RewardFunction

Bases: ABC

Compute a scalar reward for a routing decision.

Functions

compute abstractmethod

compute(context: 'RoutingContext', model_key: str, response: str, **kwargs: object) -> float

Return reward in [0, 1].

Source code in src/openjarvis/learning/_stubs.py
@abstractmethod
def compute(
    self,
    context: "RoutingContext",
    model_key: str,
    response: str,
    **kwargs: object,
) -> float:
    """Return reward in [0, 1]."""

LearningPolicy Taxonomy

The learning system defines a hierarchy of learning policy ABCs:

  • LearningPolicy -- base ABC for all learning policies
  • IntelligenceLearningPolicy -- specialization for model routing decisions
  • AgentLearningPolicy -- specialization for agent behavior advice (ICL examples, tool-use strategies)
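
Each LearningPolicy implementation documented below (SFTRouterPolicy, AgentAdvisorPolicy, ICLUpdaterPolicy) exposes an update(trace_store, **kwargs) -> Dict[str, Any] method, so a custom policy can slot into the same taxonomy. A minimal sketch, assuming the sub-ABCs are importable from the same _stubs module as the ABCs above (the import path is an assumption):

from typing import Any, Dict

from openjarvis.learning._stubs import IntelligenceLearningPolicy  # path assumed


class TraceCountingPolicy(IntelligenceLearningPolicy):
    """Illustrative policy that only counts traces per model."""

    def __init__(self) -> None:
        self._counts: Dict[str, int] = {}

    def update(self, trace_store: Any, **kwargs: object) -> Dict[str, Any]:
        # Same update(trace_store) shape as the concrete policies below.
        for trace in trace_store.list_traces():
            model = trace.model or "unknown"
            self._counts[model] = self._counts.get(model, 0) + 1
        return {"updated": True, "counts": dict(self._counts)}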

Policy Implementations

TraceDrivenPolicy

TraceDrivenPolicy

TraceDrivenPolicy(analyzer: Optional[TraceAnalyzer] = None, *, available_models: Optional[List[str]] = None, default_model: str = '', fallback_model: str = '')

Bases: RouterPolicy

Router policy that learns from historical traces.

Maintains a mapping of query_class → best_model derived from trace outcomes. Falls back to the provided default when no trace data is available for a query class.

The policy is updated by calling update_from_traces, which reads the TraceAnalyzer and recomputes the mapping.

Source code in src/openjarvis/learning/trace_policy.py
def __init__(
    self,
    analyzer: Optional[TraceAnalyzer] = None,
    *,
    available_models: Optional[List[str]] = None,
    default_model: str = "",
    fallback_model: str = "",
) -> None:
    self._analyzer = analyzer
    self._available = available_models or []
    self._default = default_model
    self._fallback = fallback_model
    # Learned mapping: query_class → model key
    self._policy_map: Dict[str, str] = {}
    # Track confidence: query_class → sample count
    self._confidence: Dict[str, int] = {}
    # Minimum samples before trusting learned policy
    self.min_samples: int = 5
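
A minimal usage sketch, constructed without a TraceAnalyzer so it only exercises the fallback chain (the model keys are placeholders; in practice you would pass an analyzer and call update_from_traces periodically, as documented below):

from openjarvis.learning.trace_policy import TraceDrivenPolicy

policy = TraceDrivenPolicy(
    available_models=["fast-model", "strong-model"],   # placeholder model keys
    default_model="fast-model",
    fallback_model="strong-model",
)

# No analyzer is configured, so there is nothing to learn from yet:
print(policy.update_from_traces())   # {'error': 'no analyzer configured'}
print(policy.policy_map)             # {} until traces have been analyzed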

Attributes

policy_map property

policy_map: Dict[str, str]

Current learned routing decisions (read-only copy).

Functions

select_model

select_model(context: RoutingContext) -> str

Select the best model based on learned policy or fallback.

Source code in src/openjarvis/learning/trace_policy.py
def select_model(self, context: RoutingContext) -> str:
    """Select the best model based on learned policy or fallback."""
    query_class = classify_query(context.query)

    # Use learned policy if we have enough confidence
    if (
        query_class in self._policy_map
        and self._confidence.get(query_class, 0) >= self.min_samples
    ):
        model = self._policy_map[query_class]
        if not self._available or model in self._available:
            return model

    # Fallback chain
    avail = self._available
    if self._default and (not avail or self._default in avail):
        return self._default
    if self._fallback and (not avail or self._fallback in avail):
        return self._fallback
    if self._available:
        return self._available[0]
    return self._default or ""

update_from_traces

update_from_traces(*, since: Optional[float] = None, until: Optional[float] = None) -> Dict[str, Any]

Recompute the policy map from trace history.

Returns a summary of what changed for logging/debugging.

Source code in src/openjarvis/learning/trace_policy.py
def update_from_traces(
    self,
    *,
    since: Optional[float] = None,
    until: Optional[float] = None,
) -> Dict[str, Any]:
    """Recompute the policy map from trace history.

    Returns a summary of what changed for logging/debugging.
    """
    if self._analyzer is None:
        return {"error": "no analyzer configured"}

    traces = self._analyzer._store.list_traces(
        since=since, until=until, limit=10_000
    )
    if not traces:
        return {"updated": False, "reason": "no traces"}

    # Group traces by query class
    groups: Dict[str, list] = {}
    for t in traces:
        qclass = classify_query(t.query)
        groups.setdefault(qclass, []).append(t)

    old_map = dict(self._policy_map)
    changes: Dict[str, Dict[str, str]] = {}

    for qclass, class_traces in groups.items():
        # Score each model for this query class
        model_scores: Dict[str, _ModelScore] = {}
        for t in class_traces:
            if not t.model:
                continue
            if t.model not in model_scores:
                model_scores[t.model] = _ModelScore()
            score = model_scores[t.model]
            score.count += 1
            score.total_latency += t.total_latency_seconds
            if t.outcome == "success":
                score.successes += 1
            if t.feedback is not None:
                score.feedback_sum += t.feedback
                score.feedback_count += 1

        if not model_scores:
            continue

        # Pick the best model: weighted score of success_rate and feedback
        best_model = max(
            model_scores.items(),
            key=lambda kv: kv[1].composite_score(),
        )[0]

        self._policy_map[qclass] = best_model
        self._confidence[qclass] = sum(s.count for s in model_scores.values())

        if old_map.get(qclass) != best_model:
            changes[qclass] = {
                "old": old_map.get(qclass, ""),
                "new": best_model,
            }

    return {
        "updated": True,
        "query_classes": len(groups),
        "total_traces": len(traces),
        "changes": changes,
    }
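
update_from_traces relies on a small _ModelScore accumulator whose definition is not reproduced on this page. A hypothetical reconstruction, assuming the same 60/40 outcome-versus-feedback weighting that SFTRouterPolicy documents below (the real composite_score may differ, for example by also penalising latency):

from dataclasses import dataclass


@dataclass
class _ModelScoreSketch:
    """Hypothetical stand-in for the private _ModelScore helper."""

    count: int = 0
    successes: int = 0
    total_latency: float = 0.0
    feedback_sum: float = 0.0
    feedback_count: int = 0

    def composite_score(self) -> float:
        if self.count == 0:
            return 0.0
        success_rate = self.successes / self.count
        avg_feedback = (
            self.feedback_sum / self.feedback_count if self.feedback_count else 0.5
        )
        # Assumed weighting; not the library's actual formula.
        return 0.6 * success_rate + 0.4 * avg_feedback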

observe

observe(query: str, model: str, outcome: Optional[str], feedback: Optional[float]) -> None

Record a single observation for online (incremental) updates.

This is a lighter-weight alternative to update_from_traces for use cases where you want to update the policy after every interaction rather than in batch.

Source code in src/openjarvis/learning/trace_policy.py
def observe(
    self,
    query: str,
    model: str,
    outcome: Optional[str],
    feedback: Optional[float],
) -> None:
    """Record a single observation for online (incremental) updates.

    This is a lighter-weight alternative to :meth:`update_from_traces`
    for use cases where you want to update the policy after every
    interaction rather than in batch.
    """
    qclass = classify_query(query)
    current_count = self._confidence.get(qclass, 0)

    # Simple exponential moving average for online update
    if qclass not in self._policy_map:
        self._policy_map[qclass] = model
        self._confidence[qclass] = 1
        return

    self._confidence[qclass] = current_count + 1

    # Only switch models if the new model shows clearly better outcomes
    if outcome == "success" and feedback is not None and feedback > 0.7:
        # Weight new evidence against existing policy
        if current_count < self.min_samples:
            self._policy_map[qclass] = model
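
A sketch of the online path, recording a single interaction against the policy instance from the usage sketch above (the model key and feedback value are illustrative):

# `policy` is the TraceDrivenPolicy instance created earlier.
policy.observe(
    query="Summarise the quarterly report in three bullet points",
    model="strong-model",   # placeholder model key
    outcome="success",
    feedback=0.9,
)
print(policy.policy_map)    # the query's class now maps to "strong-model"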

classify_query

classify_query

classify_query(query: str) -> str

Classify a query into a broad category for routing.

Source code in src/openjarvis/learning/trace_policy.py
def classify_query(query: str) -> str:
    """Classify a query into a broad category for routing."""
    if _CODE_RE.search(query):
        return "code"
    if _MATH_RE.search(query):
        return "math"
    if len(query) < 50:
        return "short"
    if len(query) > 500:
        return "long"
    return "general"

SFTRouterPolicy

SFTRouterPolicy

SFTRouterPolicy(*, min_samples: int = 5)

Bases: IntelligenceLearningPolicy

Trace-driven router that learns query_class → model mappings.

Reads historical traces, groups by query class (code, math, short, long, general), scores each model via a composite metric (60% outcome + 40% feedback), and produces a routing table that maps query classes to their best-performing model.

Source code in src/openjarvis/learning/sft_policy.py
def __init__(self, *, min_samples: int = 5) -> None:
    self._min_samples = min_samples
    self._policy_map: Dict[str, str] = {}

Functions

update

update(trace_store: Any, **kwargs: object) -> Dict[str, Any]

Analyze trace outcomes and update the policy map.

Source code in src/openjarvis/learning/sft_policy.py
def update(self, trace_store: Any, **kwargs: object) -> Dict[str, Any]:
    """Analyze trace outcomes and update the policy map."""
    try:
        traces = trace_store.list_traces()
    except Exception:
        return {"updated": False, "reason": "Could not access trace store"}

    # Group traces by query class and model
    class_model_scores: Dict[str, Dict[str, List[float]]] = {}
    for trace in traces:
        query_class = self._classify_query(trace.query)
        model = trace.model or "unknown"
        outcome_score = 1.0 if trace.outcome == "success" else 0.0
        feedback = trace.feedback if trace.feedback is not None else 0.5

        composite = 0.6 * outcome_score + 0.4 * feedback

        if query_class not in class_model_scores:
            class_model_scores[query_class] = {}
        if model not in class_model_scores[query_class]:
            class_model_scores[query_class][model] = []
        class_model_scores[query_class][model].append(composite)

    # Update policy map: best model per query class
    changes = {}
    for qclass, model_scores in class_model_scores.items():
        best_model = None
        best_score = -1.0
        for model, scores in model_scores.items():
            if len(scores) >= self._min_samples:
                avg = sum(scores) / len(scores)
                if avg > best_score:
                    best_score = avg
                    best_model = model
        if best_model and best_model != self._policy_map.get(qclass):
            self._policy_map[qclass] = best_model
            changes[qclass] = best_model

    return {
        "updated": bool(changes),
        "changes": changes,
        "policy_map": dict(self._policy_map),
    }
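
A runnable sketch using an in-memory stand-in for the trace store. TraceStub and InMemoryTraceStore are invented for this example; only the query, model, outcome, and feedback attributes that update reads are modelled.

from dataclasses import dataclass
from typing import List, Optional

from openjarvis.learning.sft_policy import SFTRouterPolicy


@dataclass
class TraceStub:
    query: str
    model: Optional[str]
    outcome: str
    feedback: Optional[float]


class InMemoryTraceStore:
    def __init__(self, traces: List[TraceStub]) -> None:
        self._traces = traces

    def list_traces(self) -> List[TraceStub]:
        return self._traces


store = InMemoryTraceStore(
    [TraceStub("fix the bug in my python function", "strong-model", "success", 0.9)] * 5
)
policy = SFTRouterPolicy(min_samples=5)
print(policy.update(store))   # {'updated': True, 'changes': {...}, 'policy_map': {...}}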

AgentAdvisorPolicy

AgentAdvisorPolicy

AgentAdvisorPolicy(*, advisor_engine: Any = None, advisor_model: str = '', max_traces: int = 50)

Bases: AgentLearningPolicy

Higher-level LM analyzes traces, suggests agent structure changes.

Does NOT auto-apply changes — returns recommendations that can be reviewed or applied via config.

Source code in src/openjarvis/learning/agent_advisor.py
def __init__(
    self,
    *,
    advisor_engine: Any = None,
    advisor_model: str = "",
    max_traces: int = 50,
) -> None:
    self._advisor_engine = advisor_engine
    self._advisor_model = advisor_model
    self._max_traces = max_traces

Functions

update

update(trace_store: Any, **kwargs: object) -> Dict[str, Any]

Analyze traces and return agent improvement recommendations.

Source code in src/openjarvis/learning/agent_advisor.py
def update(self, trace_store: Any, **kwargs: object) -> Dict[str, Any]:
    """Analyze traces and return agent improvement recommendations."""
    try:
        traces = trace_store.list_traces()
    except Exception:
        return {"recommendations": [], "confidence": 0.0}

    # Collect failing or slow traces
    problem_traces = []
    for trace in traces[-self._max_traces :]:
        is_failing = trace.outcome != "success"
        is_slow = (trace.total_latency_seconds or 0) > 5.0
        if is_failing or is_slow:
            problem_traces.append(trace)

    if not problem_traces:
        return {
            "recommendations": [],
            "confidence": 1.0,
            "message": "No problematic traces found",
        }

    # Analyze patterns without LM (structural analysis)
    recommendations = self._analyze_patterns(problem_traces)

    # If advisor engine available, get LM-guided recommendations
    if self._advisor_engine and self._advisor_model:
        try:
            lm_recs = self._get_lm_recommendations(problem_traces)
            recommendations.extend(lm_recs)
        except Exception:
            pass

    confidence = 1.0 - (len(problem_traces) / max(len(traces), 1))
    return {
        "recommendations": recommendations,
        "confidence": round(confidence, 2),
        "analyzed_traces": len(traces),
        "problem_traces": len(problem_traces),
    }
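
A sketch of the review loop. It assumes a trace store like the stand-in above whose traces also carry the total_latency_seconds attribute read here; with no advisor engine configured, only the structural analysis runs.

from openjarvis.learning.agent_advisor import AgentAdvisorPolicy

advisor = AgentAdvisorPolicy(max_traces=50)
report = advisor.update(trace_store)       # `trace_store` assumed as above

for rec in report["recommendations"]:
    print(rec)                             # reviewed manually; nothing is auto-applied
print("confidence:", report["confidence"])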

ICLUpdaterPolicy

ICLUpdaterPolicy

ICLUpdaterPolicy(*, min_score: float = 0.7, max_examples: int = 20, min_skill_occurrences: int = 3)

Bases: AgentLearningPolicy

Updates in-context examples and discovers skills from traces.

Analyzes traces for successful tool call patterns, extracts in-context learning examples, and discovers reusable multi-tool sequences ("skills"). This updates agent logic (ICL examples and tool-use strategies), not tool implementations themselves.

Source code in src/openjarvis/learning/icl_updater.py
def __init__(
    self,
    *,
    min_score: float = 0.7,
    max_examples: int = 20,
    min_skill_occurrences: int = 3,
) -> None:
    self._min_score = min_score
    self._max_examples = max_examples
    self._min_skill_occurrences = min_skill_occurrences
    self._examples: List[Dict[str, Any]] = []
    self._skills: List[Dict[str, Any]] = []

Functions

update

update(trace_store: Any, **kwargs: object) -> Dict[str, Any]

Analyze traces and extract ICL examples + skills.

Source code in src/openjarvis/learning/icl_updater.py
def update(self, trace_store: Any, **kwargs: object) -> Dict[str, Any]:
    """Analyze traces and extract ICL examples + skills."""
    try:
        traces = trace_store.list_traces()
    except Exception:
        return {"examples": [], "skills": []}

    # Extract high-scoring traces with tool calls
    new_examples: List[Dict[str, Any]] = []
    tool_sequences: List[List[str]] = []

    for trace in traces:
        # Only consider successful traces
        if trace.outcome != "success":
            continue

        feedback = trace.feedback if trace.feedback is not None else 0.5
        if feedback < self._min_score:
            continue

        # Extract tool call steps
        tool_steps = [
            s
            for s in (trace.steps or [])
            if s.step_type.value == "tool_call"
        ]

        if tool_steps:
            # Build ICL example
            tool_names = [
                s.metadata.get("tool_name", "")
                for s in tool_steps
            ]
            example = {
                "query": trace.query,
                "tools_used": tool_names,
                "outcome": trace.outcome,
                "score": feedback,
            }
            new_examples.append(example)
            tool_sequences.append(tool_names)

    # Keep top examples by score
    new_examples.sort(key=lambda x: x["score"], reverse=True)
    self._examples = new_examples[: self._max_examples]

    # Discover skills: recurring multi-tool sequences
    self._skills = self._discover_skills(tool_sequences)

    return {
        "examples": list(self._examples),
        "skills": list(self._skills),
        "traces_analyzed": len(traces),
    }
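
The skill-discovery step (_discover_skills) is not reproduced on this page. Given the min_skill_occurrences parameter, a plausible reading is counting recurring multi-tool sequences; the sketch below is hypothetical, not the library's implementation.

from collections import Counter
from typing import Any, Dict, List


def discover_skills_sketch(
    tool_sequences: List[List[str]], min_occurrences: int = 3
) -> List[Dict[str, Any]]:
    """Hypothetical illustration of recurring multi-tool sequence mining."""
    counts = Counter(
        tuple(seq) for seq in tool_sequences if len(seq) > 1  # multi-tool only
    )
    return [
        {"tools": list(seq), "occurrences": n}
        for seq, n in counts.most_common()
        if n >= min_occurrences
    ]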

GRPORouterPolicy

GRPORouterPolicy

GRPORouterPolicy(**kwargs: Any)

Bases: RouterPolicy

Placeholder for GRPO-trained router policy (Phase 5).

Raises NotImplementedError until training infrastructure is ready.

Source code in src/openjarvis/learning/grpo_policy.py
def __init__(self, **kwargs: Any) -> None:
    self._kwargs = kwargs

Reward Functions

HeuristicRewardFunction

HeuristicRewardFunction

HeuristicRewardFunction(*, weight_latency: float = 0.4, weight_cost: float = 0.3, weight_efficiency: float = 0.3, max_latency: float = 30.0, max_cost: float = 0.01)

Bases: RewardFunction

Computes a scalar reward based on latency, cost, and token efficiency.

Each component is normalised to [0, 1] and combined via a weighted sum.

Source code in src/openjarvis/learning/heuristic_reward.py
def __init__(
    self,
    *,
    weight_latency: float = 0.4,
    weight_cost: float = 0.3,
    weight_efficiency: float = 0.3,
    max_latency: float = 30.0,
    max_cost: float = 0.01,
) -> None:
    self.weight_latency = weight_latency
    self.weight_cost = weight_cost
    self.weight_efficiency = weight_efficiency
    self.max_latency = max_latency
    self.max_cost = max_cost
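
The compute body is not shown above. Given the stated normalisation and weighted sum, it plausibly resembles the sketch below; the latency, cost, and token inputs are assumptions for this illustration, not the library's documented kwargs.

def compute_sketch(
    *,
    latency_seconds: float,
    cost_usd: float,
    prompt_tokens: int,
    completion_tokens: int,
    weight_latency: float = 0.4,
    weight_cost: float = 0.3,
    weight_efficiency: float = 0.3,
    max_latency: float = 30.0,
    max_cost: float = 0.01,
) -> float:
    """Hypothetical reconstruction of HeuristicRewardFunction.compute."""
    # Each component is normalised into [0, 1]; higher is better.
    latency_score = 1.0 - min(latency_seconds, max_latency) / max_latency
    cost_score = 1.0 - min(cost_usd, max_cost) / max_cost
    total = prompt_tokens + completion_tokens
    efficiency_score = completion_tokens / total if total else 0.0
    return (
        weight_latency * latency_score
        + weight_cost * cost_score
        + weight_efficiency * efficiency_score
    )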