"""
Group Relative Policy Optimization (GRPO) for multi-agent learning in Claude Code.
This module provides a multi-agent GRPO implementation that learns from interactions.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
from dataclasses import dataclass
from collections import deque
import random
import time
@dataclass
class Experience:
"""A single step of experience for reinforcement learning."""
state: Any
action: Any
reward: float
next_state: Any
done: bool
info: Optional[Dict[str, Any]] = None
class ExperienceBuffer:
"""Buffer to store and sample experiences for training."""
def __init__(self, capacity: int = 100000):
"""
Initialize the experience buffer.
Args:
capacity: Maximum number of experiences to store
"""
self.buffer = deque(maxlen=capacity)
def add(self, experience: Experience) -> None:
"""Add an experience to the buffer."""
self.buffer.append(experience)
def sample(self, batch_size: int) -> List[Experience]:
"""Sample a batch of experiences from the buffer."""
return random.sample(self.buffer, min(batch_size, len(self.buffer)))
def __len__(self) -> int:
"""Get the current size of the buffer."""
return len(self.buffer)
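# Illustrative usage sketch for ExperienceBuffer (not executed on import). The
# 4-dimensional toy states and binary actions below are assumptions made purely
# for demonstration; real callers supply their own state and action types.
def _example_experience_buffer_usage() -> None:
    """Minimal sketch: fill a small buffer with random transitions and sample a batch."""
    buffer = ExperienceBuffer(capacity=1000)
    for _ in range(10):
        buffer.add(Experience(
            state=np.random.randn(4),
            action=random.randint(0, 1),
            reward=random.random(),
            next_state=np.random.randn(4),
            done=False,
        ))
    batch = buffer.sample(batch_size=4)
    assert len(batch) == 4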
class PolicyNetwork(nn.Module):
"""Neural network to represent a policy."""
def __init__(self, input_dim: int, hidden_dims: List[int], output_dim: int):
"""
Initialize the policy network.
Args:
input_dim: Dimension of the input state
hidden_dims: List of hidden layer dimensions
output_dim: Dimension of the action space
"""
super(PolicyNetwork, self).__init__()
# Create the input layer
layers = [nn.Linear(input_dim, hidden_dims[0]), nn.ReLU()]
# Create hidden layers
for i in range(len(hidden_dims) - 1):
layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
layers.append(nn.ReLU())
# Create output layer
layers.append(nn.Linear(hidden_dims[-1], output_dim))
self.network = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through the network."""
return self.network(x)
class ValueNetwork(nn.Module):
"""Neural network to represent a value function."""
def __init__(self, input_dim: int, hidden_dims: List[int]):
"""
Initialize the value network.
Args:
input_dim: Dimension of the input state
hidden_dims: List of hidden layer dimensions
"""
super(ValueNetwork, self).__init__()
# Create the input layer
layers = [nn.Linear(input_dim, hidden_dims[0]), nn.ReLU()]
# Create hidden layers
for i in range(len(hidden_dims) - 1):
layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
layers.append(nn.ReLU())
# Create output layer (scalar value)
layers.append(nn.Linear(hidden_dims[-1], 1))
self.network = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through the network."""
return self.network(x)
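# Illustrative sketch of the tensor shapes produced by PolicyNetwork and
# ValueNetwork (not executed on import). The 8-dimensional state and 3-action
# space are assumptions chosen only for demonstration.
def _example_network_shapes() -> None:
    """Minimal sketch: run a random batch of states through both networks."""
    policy = PolicyNetwork(input_dim=8, hidden_dims=[64, 64], output_dim=3)
    value = ValueNetwork(input_dim=8, hidden_dims=[64, 64])
    states = torch.randn(5, 8)          # batch of 5 states
    logits = policy(states)             # (5, 3) unnormalized action scores
    values = value(states)              # (5, 1) scalar value per state
    actions = Categorical(logits=logits).sample()  # (5,) sampled action indices
    assert logits.shape == (5, 3) and values.shape == (5, 1) and actions.shape == (5,)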
class GRPO:
"""
Group Relative Policy Optimization implementation for multi-agent learning.
GRPO extends PPO by considering relative performance within a group of agents.
"""
def __init__(
self,
state_dim: int,
action_dim: int,
hidden_dims: List[int] = [64, 64],
lr_policy: float = 3e-4,
lr_value: float = 1e-3,
gamma: float = 0.99,
gae_lambda: float = 0.95,
clip_ratio: float = 0.2,
target_kl: float = 0.01,
value_coef: float = 0.5,
entropy_coef: float = 0.01,
max_grad_norm: float = 0.5,
use_gae: bool = True,
normalize_advantages: bool = True,
relative_advantage_weight: float = 0.5,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
Initialize the GRPO agent.
Args:
state_dim: Dimension of the state space
action_dim: Dimension of the action space
hidden_dims: Dimensions of hidden layers in networks
lr_policy: Learning rate for policy network
lr_value: Learning rate for value network
gamma: Discount factor
gae_lambda: Lambda for GAE
clip_ratio: PPO clipping parameter
target_kl: Target KL divergence for early stopping
value_coef: Value loss coefficient
entropy_coef: Entropy bonus coefficient
max_grad_norm: Maximum gradient norm for clipping
use_gae: Whether to use GAE
normalize_advantages: Whether to normalize advantages
relative_advantage_weight: Weight for relative advantage component
device: Device to run the model on
"""
self.state_dim = state_dim
self.action_dim = action_dim
self.gamma = gamma
self.gae_lambda = gae_lambda
self.clip_ratio = clip_ratio
self.target_kl = target_kl
self.value_coef = value_coef
self.entropy_coef = entropy_coef
self.max_grad_norm = max_grad_norm
self.use_gae = use_gae
self.normalize_advantages = normalize_advantages
self.relative_advantage_weight = relative_advantage_weight
self.device = device
# Initialize networks
self.policy = PolicyNetwork(state_dim, hidden_dims, action_dim).to(device)
self.value = ValueNetwork(state_dim, hidden_dims).to(device)
# Initialize optimizers
self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=lr_policy)
self.value_optimizer = optim.Adam(self.value.parameters(), lr=lr_value)
# Initialize experience buffer
self.buffer = ExperienceBuffer()
# Group-level buffers for relative advantage computation
self.group_rewards = []
self.agent_id = None # Will be set when joining a group
def set_agent_id(self, agent_id: str) -> None:
"""Set the agent's ID within the group."""
self.agent_id = agent_id
def get_action(self, state: np.ndarray, deterministic: bool = False) -> Tuple[int, float]:
"""
Get an action from the policy for the given state.
Args:
state: The current state
deterministic: Whether to return the most likely action
Returns:
Tuple of (action, log probability)
"""
# Convert state to tensor
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
# Get action distributions
with torch.no_grad():
logits = self.policy(state_tensor)
distribution = Categorical(logits=logits)
if deterministic:
action = torch.argmax(logits, dim=1).item()
else:
action = distribution.sample().item()
        log_prob = distribution.log_prob(torch.tensor(action, device=self.device)).item()
return action, log_prob
def get_value(self, state: np.ndarray) -> float:
"""
Get the estimated value of a state.
Args:
state: The state to evaluate
Returns:
The estimated value
"""
# Convert state to tensor
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
# Get value estimate
with torch.no_grad():
value = self.value(state_tensor).item()
return value
def learn(
self,
batch_size: int = 64,
epochs: int = 10,
group_rewards: Optional[Dict[str, List[float]]] = None
) -> Dict[str, float]:
"""
Update policy and value networks based on collected experience.
Args:
batch_size: Size of batches to use for updates
epochs: Number of epochs to train for
group_rewards: Rewards collected by all agents in the group
Returns:
Dictionary of training metrics
"""
        if len(self.buffer) < batch_size:
            return {"policy_loss": 0.0, "value_loss": 0.0, "entropy": 0.0, "kl": 0.0}
# Prepare data for training
states, actions, old_log_probs, returns, advantages = self._prepare_training_data(
group_rewards)
        # Training metrics
        metrics = {
            "policy_loss": 0.0,
            "value_loss": 0.0,
            "entropy": 0.0,
            "kl": 0.0,
        }
        # Count batch updates actually performed (KL early stopping may cut training short)
        num_updates = 0
# Run training for multiple epochs
for epoch in range(epochs):
# Generate random indices for batching
indices = np.random.permutation(len(states))
# Process in batches
for start_idx in range(0, len(states), batch_size):
# Get batch indices
batch_indices = indices[start_idx:start_idx + batch_size]
# Extract batch data
batch_states = states[batch_indices]
batch_actions = actions[batch_indices]
batch_old_log_probs = old_log_probs[batch_indices]
batch_returns = returns[batch_indices]
batch_advantages = advantages[batch_indices]
# Update policy
policy_loss, entropy, kl = self._update_policy(
batch_states, batch_actions, batch_old_log_probs, batch_advantages)
# Early stopping based on KL divergence
if kl > 1.5 * self.target_kl:
break
                # Update value function
                value_loss = self._update_value(batch_states, batch_returns)
                # Update metrics
                metrics["policy_loss"] += policy_loss
                metrics["value_loss"] += value_loss
                metrics["entropy"] += entropy
                metrics["kl"] += kl
                num_updates += 1
            # Check for early stopping after each epoch (mean KL over completed updates)
            if num_updates > 0 and metrics["kl"] / num_updates > self.target_kl:
                break
        # Normalize metrics by the number of updates actually performed
        for key in metrics:
            metrics[key] /= max(1, num_updates)
# Clear buffer after training
self.buffer = ExperienceBuffer()
return metrics
def _prepare_training_data(
self, group_rewards: Optional[Dict[str, List[float]]] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Prepare data for training from the experience buffer.
Args:
group_rewards: Rewards collected by all agents in the group
Returns:
Tuple of (states, actions, old_log_probs, returns, advantages)
"""
# Collect experiences from buffer
experiences = list(self.buffer.buffer)
# Extract components
        states = torch.as_tensor(np.array([exp.state for exp in experiences]), dtype=torch.float32).to(self.device)
        actions = torch.as_tensor([exp.action for exp in experiences], dtype=torch.long).to(self.device)
        rewards = torch.as_tensor([exp.reward for exp in experiences], dtype=torch.float32).to(self.device)
        next_states = torch.as_tensor(np.array([exp.next_state for exp in experiences]), dtype=torch.float32).to(self.device)
        dones = torch.as_tensor([float(exp.done) for exp in experiences], dtype=torch.float32).to(self.device)
# Compute values for all states and next states
        with torch.no_grad():
            values = self.value(states).squeeze(-1)
            next_values = self.value(next_states).squeeze(-1)
# Compute advantages and returns
if self.use_gae:
# Generalized Advantage Estimation
advantages = self._compute_gae(rewards, values, next_values, dones)
else:
# Regular advantages
advantages = rewards + self.gamma * next_values * (1 - dones) - values
# Compute returns (for value function)
returns = advantages + values
# If group rewards are provided, compute relative advantages
if group_rewards is not None and self.agent_id in group_rewards:
relative_advantages = self._compute_relative_advantages(
advantages, group_rewards)
# Combine regular and relative advantages
advantages = (1 - self.relative_advantage_weight) * advantages + \
self.relative_advantage_weight * relative_advantages
# Normalize advantages if enabled
if self.normalize_advantages:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# Get old log probabilities
old_log_probs = torch.FloatTensor(
[self._compute_log_prob(exp.state, exp.action) for exp in experiences]
).to(self.device)
return states, actions, old_log_probs, returns, advantages
def _compute_gae(
self, rewards: torch.Tensor, values: torch.Tensor,
next_values: torch.Tensor, dones: torch.Tensor
) -> torch.Tensor:
"""
Compute advantages using Generalized Advantage Estimation.
Args:
rewards: Batch of rewards
values: Batch of state values
next_values: Batch of next state values
dones: Batch of done flags
Returns:
Batch of advantage estimates
"""
# Initialize advantages
advantages = torch.zeros_like(rewards)
# Initialize gae
gae = 0
# Compute advantages in reverse order
for t in reversed(range(len(rewards))):
# Compute TD error
delta = rewards[t] + self.gamma * next_values[t] * (1 - dones[t]) - values[t]
# Update gae
gae = delta + self.gamma * self.gae_lambda * (1 - dones[t]) * gae
# Store advantage
advantages[t] = gae
return advantages
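    # Reference for _compute_gae above: with the TD error
    #     delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    # the advantage is the discounted sum of TD errors, computed backwards as
    #     A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
    # with the advantage beyond the final step taken as zero. lambda = 1 recovers
    # Monte Carlo advantages and lambda = 0 recovers one-step TD advantages.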
def _compute_relative_advantages(
self, advantages: torch.Tensor, group_rewards: Dict[str, List[float]]
) -> torch.Tensor:
"""
Compute relative advantages compared to other agents in the group.
Args:
advantages: This agent's advantages
group_rewards: Rewards collected by all agents in the group
Returns:
Relative advantages
"""
# Compute mean reward for each agent
agent_mean_rewards = {
agent_id: sum(rewards) / max(1, len(rewards))
for agent_id, rewards in group_rewards.items()
}
# Compute mean reward across all agents
group_mean_reward = sum(agent_mean_rewards.values()) / len(agent_mean_rewards)
# Compute relative performance factor
# Higher if this agent is doing better than the group average
if self.agent_id in agent_mean_rewards:
relative_factor = agent_mean_rewards[self.agent_id] / (group_mean_reward + 1e-8)
else:
relative_factor = 1.0
# Apply the relative factor to the advantages
relative_advantages = advantages * relative_factor
return relative_advantages
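    # Worked example for the relative scaling above (illustrative numbers): with
    # group mean rewards {"a": 2.0, "b": 1.0} the group mean is 1.5, so agent "a"
    # gets relative_factor = 2.0 / 1.5 ~ 1.33 and its advantages are amplified,
    # while agent "b" gets 1.0 / 1.5 ~ 0.67 and its advantages are dampened before
    # being mixed back in via relative_advantage_weight.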
def _compute_log_prob(self, state: np.ndarray, action: int) -> float:
"""
Compute the log probability of an action given a state.
Args:
state: The state
action: The action
Returns:
The log probability
"""
# Convert state to tensor
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
# Get action distribution
with torch.no_grad():
logits = self.policy(state_tensor)
distribution = Categorical(logits=logits)
log_prob = distribution.log_prob(torch.tensor(action, device=self.device)).item()
return log_prob
def _update_policy(
self,
states: torch.Tensor,
actions: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor
) -> Tuple[float, float, float]:
"""
Update the policy network using PPO.
Args:
states: Batch of states
actions: Batch of actions
old_log_probs: Batch of old log probabilities
advantages: Batch of advantages
Returns:
Tuple of (policy_loss, entropy, kl_divergence)
"""
# Get action distributions
logits = self.policy(states)
distribution = Categorical(logits=logits)
# Get new log probabilities
new_log_probs = distribution.log_prob(actions)
# Compute probability ratio
ratio = torch.exp(new_log_probs - old_log_probs)
# Compute surrogate objectives
surrogate1 = ratio * advantages
surrogate2 = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * advantages
# Compute policy loss (negative because we're maximizing)
policy_loss = -torch.min(surrogate1, surrogate2).mean()
# Compute entropy bonus
entropy = distribution.entropy().mean()
# Add entropy bonus to loss
loss = policy_loss - self.entropy_coef * entropy
# Compute approximate KL divergence for monitoring
with torch.no_grad():
kl = (old_log_probs - new_log_probs).mean().item()
# Update policy network
self.policy_optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy_optimizer.step()
return policy_loss.item(), entropy.item(), kl
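    # Reference for _update_policy above: with probability ratio
    #     r_t(theta) = pi_theta(a_t | s_t) / pi_theta_old(a_t | s_t)
    # the clipped PPO surrogate maximized here is
    #     E_t[ min(r_t(theta) * A_t, clip(r_t(theta), 1 - eps, 1 + eps) * A_t) ]
    # with an entropy bonus (weighted by entropy_coef) added to encourage exploration.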
def _update_value(self, states: torch.Tensor, returns: torch.Tensor) -> float:
"""
Update the value network.
Args:
states: Batch of states
returns: Batch of returns
Returns:
Value loss
"""
# Get value predictions
        values = self.value(states).squeeze(-1)
# Compute value loss
value_loss = F.mse_loss(values, returns)
# Update value network
self.value_optimizer.zero_grad()
value_loss.backward()
nn.utils.clip_grad_norm_(self.value.parameters(), self.max_grad_norm)
self.value_optimizer.step()
return value_loss.item()
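# Illustrative end-to-end sketch for a single GRPO agent on a toy random
# environment (not executed on import). The 4-dimensional state, 2-action space,
# and random transition dynamics are assumptions standing in for a real task.
def _example_single_agent_grpo() -> None:
    """Minimal sketch: collect random-environment experience, then run one learn() call."""
    agent = GRPO(state_dim=4, action_dim=2, hidden_dims=[32, 32], device="cpu")
    agent.set_agent_id("demo_agent")
    state = np.random.randn(4)
    for step in range(128):
        action, _log_prob = agent.get_action(state)
        next_state = np.random.randn(4)       # toy dynamics: random next state
        reward = float(np.random.rand())      # toy reward signal
        done = (step % 32 == 31)              # end an "episode" every 32 steps
        agent.buffer.add(Experience(state, action, reward, next_state, done))
        state = np.random.randn(4) if done else next_state
    metrics = agent.learn(batch_size=64, epochs=3)
    print({k: round(v, 4) for k, v in metrics.items()})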
class MultiAgentGroupRL:
"""
Multi-agent reinforcement learning system using GRPO for Claude Code.
This class manages multiple GRPO agents that learn in a coordinated way.
"""
def __init__(
self,
agent_configs: List[Dict[str, Any]],
feature_extractor: Callable[[Dict[str, Any]], np.ndarray],
reward_function: Callable[[Dict[str, Any], str, Any], float],
update_interval: int = 1000,
training_epochs: int = 10,
batch_size: int = 64,
save_dir: str = "./models",
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
Initialize the multi-agent RL system.
Args:
agent_configs: List of configurations for each agent
feature_extractor: Function to extract state features
reward_function: Function to compute rewards
update_interval: How often to update agents (in steps)
training_epochs: Number of epochs to train for each update
batch_size: Batch size for training
save_dir: Directory to save models
device: Device to run on
"""
self.feature_extractor = feature_extractor
self.reward_function = reward_function
self.update_interval = update_interval
self.training_epochs = training_epochs
self.batch_size = batch_size
self.save_dir = save_dir
self.device = device
# Initialize agents
self.agents = {}
for config in agent_configs:
agent_id = config["id"]
state_dim = config["state_dim"]
action_dim = config["action_dim"]
# Create GRPO agent
agent = GRPO(
state_dim=state_dim,
action_dim=action_dim,
hidden_dims=config.get("hidden_dims", [64, 64]),
device=device,
**{k: v for k, v in config.items() if k not in ["id", "state_dim", "action_dim", "hidden_dims"]}
)
# Set agent ID
agent.set_agent_id(agent_id)
self.agents[agent_id] = agent
# Track steps for periodic updates
self.total_steps = 0
# Store rewards for relative advantage computation
self.agent_rewards = {agent_id: [] for agent_id in self.agents}
def select_action(
self, agent_id: str, observation: Dict[str, Any], deterministic: bool = False
) -> Tuple[Any, float]:
"""
Select an action for the specified agent.
Args:
agent_id: ID of the agent
observation: Current observation
deterministic: Whether to select deterministically
Returns:
Tuple of (action, log probability)
"""
if agent_id not in self.agents:
raise ValueError(f"Unknown agent ID: {agent_id}")
# Extract features
state = self.feature_extractor(observation)
# Get action from agent
action, log_prob = self.agents[agent_id].get_action(state, deterministic)
return action, log_prob
def observe(
self,
agent_id: str,
observation: Dict[str, Any],
action: Any,
reward: float,
next_observation: Dict[str, Any],
done: bool,
info: Optional[Dict[str, Any]] = None
) -> None:
"""
Record an observation for the specified agent.
Args:
agent_id: ID of the agent
observation: Current observation
action: Action taken
reward: Reward received
next_observation: Next observation
done: Whether the episode is done
info: Additional information
"""
if agent_id not in self.agents:
raise ValueError(f"Unknown agent ID: {agent_id}")
# Extract features
state = self.feature_extractor(observation)
next_state = self.feature_extractor(next_observation)
# Create experience
exp = Experience(
state=state,
action=action,
reward=reward,
next_state=next_state,
done=done,
info=info
)
# Add experience to agent's buffer
self.agents[agent_id].buffer.add(exp)
# Store reward for relative advantage computation
self.agent_rewards[agent_id].append(reward)
# Increment step counter
self.total_steps += 1
# Perform updates if needed
if self.total_steps % self.update_interval == 0:
self.update_all_agents()
def update_all_agents(self) -> Dict[str, Dict[str, float]]:
"""
Update all agents' policies.
Returns:
Dictionary of training metrics for each agent
"""
# Store metrics for each agent
metrics = {}
# Update each agent
for agent_id, agent in self.agents.items():
# Train the agent with group rewards
agent_metrics = agent.learn(
batch_size=self.batch_size,
epochs=self.training_epochs,
group_rewards=self.agent_rewards
)
metrics[agent_id] = agent_metrics
# Reset reward tracking
self.agent_rewards = {agent_id: [] for agent_id in self.agents}
return metrics
def save_agents(self, suffix: str = "") -> None:
"""
Save all agents' models.
Args:
suffix: Optional suffix for saved files
"""
import os
# Create save directory if it doesn't exist
os.makedirs(self.save_dir, exist_ok=True)
# Save each agent
for agent_id, agent in self.agents.items():
# Create file path
file_path = os.path.join(self.save_dir, f"{agent_id}{suffix}.pt")
# Save model
torch.save({
"policy_state_dict": agent.policy.state_dict(),
"value_state_dict": agent.value.state_dict(),
"policy_optimizer_state_dict": agent.policy_optimizer.state_dict(),
"value_optimizer_state_dict": agent.value_optimizer.state_dict(),
}, file_path)
def load_agents(self, suffix: str = "") -> None:
"""
Load all agents' models.
Args:
suffix: Optional suffix for loaded files
"""
import os
# Load each agent
for agent_id, agent in self.agents.items():
# Create file path
file_path = os.path.join(self.save_dir, f"{agent_id}{suffix}.pt")
# Check if file exists
if not os.path.exists(file_path):
print(f"Warning: Model file not found for agent {agent_id}")
continue
# Load model
checkpoint = torch.load(file_path, map_location=self.device)
# Load state dicts
agent.policy.load_state_dict(checkpoint["policy_state_dict"])
agent.value.load_state_dict(checkpoint["value_state_dict"])
agent.policy_optimizer.load_state_dict(checkpoint["policy_optimizer_state_dict"])
agent.value_optimizer.load_state_dict(checkpoint["value_optimizer_state_dict"])
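# Illustrative sketch wiring MultiAgentGroupRL with two toy agents (not executed
# on import). The feature extractor and reward values below are stand-ins
# (assumptions) for the real Claude Code observation pipeline.
def _example_multi_agent_group() -> None:
    """Minimal sketch: drive two agents with random observations until one update fires."""
    def extract_features(observation: Dict[str, Any]) -> np.ndarray:
        return np.asarray(observation["features"], dtype=np.float32)
    def compute_reward(observation: Dict[str, Any], agent_id: str, action: Any) -> float:
        return float(np.random.rand())        # placeholder reward signal
    system = MultiAgentGroupRL(
        agent_configs=[
            {"id": "agent_a", "state_dim": 4, "action_dim": 2},
            {"id": "agent_b", "state_dim": 4, "action_dim": 2},
        ],
        feature_extractor=extract_features,
        reward_function=compute_reward,
        update_interval=128,
        batch_size=32,
        device="cpu",
    )
    for step in range(128):
        agent_id = "agent_a" if step % 2 == 0 else "agent_b"
        obs = {"features": np.random.randn(4)}
        next_obs = {"features": np.random.randn(4)}
        action, _ = system.select_action(agent_id, obs)
        system.observe(agent_id, obs, action, reward=float(np.random.rand()),
                       next_observation=next_obs, done=False)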
class ToolSelectionGRPO:
"""
Specialized GRPO implementation for tool selection in Claude Code.
This class adapts the MultiAgentGroupRL for the specific context of tool selection.
"""
def __init__(
self,
tool_registry: Any, # Should be a reference to the tool registry
context_evaluator: Callable, # Function to evaluate quality of response given context
state_dim: int = 768, # Embedding dimension for query
num_agents: int = 3, # Number of agents in the group
update_interval: int = 100,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
Initialize the GRPO tool selector.
Args:
tool_registry: Registry containing available tools
context_evaluator: Function to evaluate response quality
state_dim: Dimension of state features
num_agents: Number of agents in the group
update_interval: How often to update agents
device: Device to run on
"""
self.tool_registry = tool_registry
self.context_evaluator = context_evaluator
# Get all available tools
self.tool_names = tool_registry.get_all_tool_names()
self.action_dim = len(self.tool_names)
# Define agent configurations
agent_configs = [
{
"id": f"tool_agent_{i}",
"state_dim": state_dim,
"action_dim": self.action_dim,
"hidden_dims": [256, 128],
"relative_advantage_weight": 0.7 if i > 0 else 0.3, # Different weights
"entropy_coef": 0.02 if i == 0 else 0.01, # Different exploration rates
}
for i in range(num_agents)
]
# Initialize multi-agent RL system
self.rl_system = MultiAgentGroupRL(
agent_configs=agent_configs,
feature_extractor=self._extract_features,
reward_function=self._compute_reward,
update_interval=update_interval,
device=device,
)
# Track current episode
self.current_episode = {agent_id: {} for agent_id in self.rl_system.agents}
def select_tool(self, user_query: str, context: Dict[str, Any], visualizer=None) -> str:
"""
Select the best tool to use for a given user query and context.
Args:
user_query: The user's query
context: The current conversation context
visualizer: Optional visualizer to display the selection process
Returns:
The name of the best tool to use
"""
# Create observation
observation = {
"query": user_query,
"context": context,
}
# If visualizer is provided, start it
if visualizer:
visualizer.start()
visualizer.add_execution(
execution_id="tool_selection",
tool_name="GRPO Tool Selection",
parameters={"query": user_query[:100] + "..." if len(user_query) > 100 else user_query}
)
# Select agent to use (round-robin for now)
agent_id = f"tool_agent_{self.rl_system.total_steps % len(self.rl_system.agents)}"
# Update visualizer if provided
if visualizer:
visualizer.update_progress("tool_selection", 0.3)
# Get action from agent
action_idx, _ = self.rl_system.select_action(
agent_id=agent_id,
observation=observation,
deterministic=False # Use exploratory actions during learning
)
# Update visualizer if provided
if visualizer:
visualizer.update_progress("tool_selection", 0.6)
# Store initial information for the episode
self.current_episode[agent_id] = {
"observation": observation,
"action_idx": action_idx,
"initial_quality": self.context_evaluator(context),
}
# Map action index to tool name
tool_name = self.tool_names[action_idx]
# Complete visualization if provided
if visualizer:
# Create detailed metrics for visualization
agent_data = {}
for aid, agent in self.rl_system.agents.items():
# Get all tool probabilities for this agent
with torch.no_grad():
                    state = self._extract_features(observation)
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(agent.device)
logits = agent.policy(state_tensor)
probs = F.softmax(logits, dim=1).squeeze().cpu().numpy()
# Add to metrics
agent_data[aid] = {
"selected": aid == agent_id,
"tool_probabilities": {
self.tool_names[i]: float(prob)
for i, prob in enumerate(probs)
}
}
# Complete the visualization
visualizer.complete_execution(
execution_id="tool_selection",
result={
"selected_tool": tool_name,
"selected_agent": agent_id,
"agent_data": agent_data
},
status="success"
)
visualizer.stop()
return tool_name
def observe_result(
self, agent_id: str, result: Any, context: Dict[str, Any], done: bool = True
) -> None:
"""
Observe the result of using a tool.
Args:
agent_id: The ID of the agent that selected the tool
result: The result of using the tool
context: The updated context after using the tool
done: Whether the interaction is complete
"""
if agent_id not in self.current_episode:
return
# Get episode information
episode = self.current_episode[agent_id]
observation = episode["observation"]
action_idx = episode["action_idx"]
initial_quality = episode["initial_quality"]
# Create next observation
next_observation = {
"query": observation["query"],
"context": context,
"result": result,
}
# Compute reward
reward = self._compute_reward(observation, action_idx, result, context, initial_quality)
# Record observation
self.rl_system.observe(
agent_id=agent_id,
observation=observation,
action=action_idx,
reward=reward,
next_observation=next_observation,
done=done,
)
# Clear episode if done
if done:
self.current_episode[agent_id] = {}
def _extract_features(self, observation: Dict[str, Any]) -> np.ndarray:
"""Extract features from an observation."""
# This would ideally use an embedding model
# For now, return a random vector as a placeholder
return np.random.randn(768)
def _compute_reward(
self,
observation: Dict[str, Any],
action_idx: int,
result: Any,
context: Dict[str, Any],
initial_quality: float
) -> float:
"""Compute the reward for an action."""
# Compute the quality improvement
final_quality = self.context_evaluator(context)
quality_improvement = final_quality - initial_quality
# Base reward on quality improvement
reward = max(0, quality_improvement * 10) # Scale for better learning
return reward
def update(self) -> Dict[str, Dict[str, float]]:
"""
Trigger an update of all agents.
Returns:
Dictionary of training metrics
"""
return self.rl_system.update_all_agents()
def save(self, suffix: str = "") -> None:
"""Save all agents."""
self.rl_system.save_agents(suffix)
def load(self, suffix: str = "") -> None:
"""Load all agents."""
self.rl_system.load_agents(suffix)
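# Illustrative sketch for ToolSelectionGRPO with a dummy tool registry and a
# trivial context evaluator (not executed on import). The registry stand-in and
# context fields are assumptions for demonstration; the only interface assumed
# from the real registry is get_all_tool_names(), as used above.
class _DummyToolRegistry:
    """Minimal stand-in exposing the single method ToolSelectionGRPO relies on."""
    def get_all_tool_names(self) -> List[str]:
        return ["bash", "edit_file", "search"]
def _example_tool_selection() -> None:
    """Minimal sketch: select a tool for a fake query and feed back a fake result."""
    selector = ToolSelectionGRPO(
        tool_registry=_DummyToolRegistry(),
        context_evaluator=lambda context: float(len(context)),  # toy quality score
        state_dim=768,
        num_agents=3,
        update_interval=100,
        device="cpu",
    )
    context: Dict[str, Any] = {"history": []}
    tool = selector.select_tool("List the files in the repo", context)
    # select_tool routes round-robin on total_steps; nothing has been recorded
    # yet, so the call above was handled by "tool_agent_0".
    context["last_result"] = "file_a.py\nfile_b.py"
    selector.observe_result("tool_agent_0", result=context["last_result"], context=context)
    print("selected tool:", tool)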