"""OpenAI MCP Server (by arthurcolle).

Group Relative Policy Optimization (GRPO) for multi-agent learning in Claude Code.
This module provides a multi-agent GRPO implementation that learns from interactions.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
from dataclasses import dataclass
from collections import deque
import random
import time
@dataclass
class Experience:
    """A single step (transition) of experience for reinforcement learning.

    Attributes:
        state: State features observed before the action.
        action: Action taken in ``state``.
        reward: Scalar reward received for the action.
        next_state: State features observed after the action.
        done: Whether the episode ended on this step.
        info: Optional extra metadata about the step.
    """

    state: Any
    action: Any
    reward: float
    next_state: Any
    done: bool
    info: Optional[Dict[str, Any]] = None
class ExperienceBuffer:
    """Bounded FIFO buffer that stores experiences and samples them uniformly."""

    def __init__(self, capacity: int = 100000):
        """Initialize the experience buffer.

        Args:
            capacity: Maximum number of experiences to store. Once full, adding
                a new experience discards the oldest one (deque ``maxlen``).
        """
        self.buffer = deque(maxlen=capacity)

    def add(self, experience: "Experience") -> None:
        """Append an experience, evicting the oldest one if the buffer is full."""
        self.buffer.append(experience)

    def sample(self, batch_size: int) -> List["Experience"]:
        """Sample up to ``batch_size`` distinct experiences uniformly at random.

        Returns fewer items (possibly none) when the buffer holds fewer
        experiences than requested; a non-positive ``batch_size`` yields ``[]``.
        """
        k = max(0, min(batch_size, len(self.buffer)))
        return random.sample(self.buffer, k)

    def __len__(self) -> int:
        """Return the number of experiences currently stored."""
        return len(self.buffer)
class PolicyNetwork(nn.Module):
    """MLP mapping a batch of state vectors to unnormalized action logits."""

    def __init__(self, input_dim: int, hidden_dims: List[int], output_dim: int):
        """Initialize the policy network.

        Args:
            input_dim: Dimension of the input state.
            hidden_dims: Hidden layer widths. Each hidden layer is a Linear
                followed by ReLU; an empty list yields a single Linear layer.
            output_dim: Dimension of the action space (number of logits).
        """
        super().__init__()
        layers: List[nn.Module] = []
        in_features = input_dim
        # Linear + ReLU for every hidden layer (the first one consumes the input).
        for width in hidden_dims:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.ReLU())
            in_features = width
        # Output layer: raw logits, no activation.
        layers.append(nn.Linear(in_features, output_dim))
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return action logits for the states in ``x``."""
        return self.network(x)
class ValueNetwork(nn.Module):
    """MLP mapping a batch of state vectors to scalar value estimates."""

    def __init__(self, input_dim: int, hidden_dims: List[int]):
        """Initialize the value network.

        Args:
            input_dim: Dimension of the input state.
            hidden_dims: Hidden layer widths. Each hidden layer is a Linear
                followed by ReLU; an empty list yields a single Linear layer.
        """
        super().__init__()
        layers: List[nn.Module] = []
        in_features = input_dim
        # Linear + ReLU for every hidden layer (the first one consumes the input).
        for width in hidden_dims:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.ReLU())
            in_features = width
        # Output layer: one scalar value per state, no activation.
        layers.append(nn.Linear(in_features, 1))
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return value estimates of shape ``(batch, 1)`` for the states in ``x``."""
        return self.network(x)
class GRPO:
    """Group Relative Policy Optimization agent for multi-agent learning.

    GRPO extends PPO (clipped surrogate objective, optional GAE, KL-based
    early stopping) by blending each agent's own advantages with a
    "relative" advantage: the same advantages scaled by the ratio of this
    agent's mean reward to the group's mean reward.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dims: Optional[List[int]] = None,
        lr_policy: float = 3e-4,
        lr_value: float = 1e-3,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_ratio: float = 0.2,
        target_kl: float = 0.01,
        value_coef: float = 0.5,
        entropy_coef: float = 0.01,
        max_grad_norm: float = 0.5,
        use_gae: bool = True,
        normalize_advantages: bool = True,
        relative_advantage_weight: float = 0.5,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """Initialize the GRPO agent.

        Args:
            state_dim: Dimension of the state space.
            action_dim: Number of discrete actions.
            hidden_dims: Hidden layer widths used by both the policy and the
                value network. Defaults to ``[64, 64]``.
            lr_policy: Adam learning rate for the policy network.
            lr_value: Adam learning rate for the value network.
            gamma: Discount factor.
            gae_lambda: Lambda for Generalized Advantage Estimation.
            clip_ratio: PPO clipping parameter for the probability ratio.
            target_kl: Target approximate KL divergence for early stopping.
            value_coef: Value loss coefficient. Stored for configuration
                compatibility; the value network has its own optimizer and
                ``_update_value`` does not scale its loss by it.
            entropy_coef: Weight of the entropy bonus in the policy loss.
            max_grad_norm: Maximum gradient norm for clipping.
            use_gae: Use GAE if True, otherwise one-step TD advantages.
            normalize_advantages: Standardize advantages before updating.
            relative_advantage_weight: Mixing weight of the relative advantage
                (0 = own advantages only, 1 = relative advantages only).
            device: Torch device the networks run on.
        """
        # Avoid a shared mutable default list.
        if hidden_dims is None:
            hidden_dims = [64, 64]

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_ratio = clip_ratio
        self.target_kl = target_kl
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.use_gae = use_gae
        self.normalize_advantages = normalize_advantages
        self.relative_advantage_weight = relative_advantage_weight
        self.device = device

        # Networks
        self.policy = PolicyNetwork(state_dim, list(hidden_dims), action_dim).to(device)
        self.value = ValueNetwork(state_dim, list(hidden_dims)).to(device)

        # Separate optimizers for policy and value networks
        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=lr_policy)
        self.value_optimizer = optim.Adam(self.value.parameters(), lr=lr_value)

        # Experiences collected since the last call to ``learn``
        self.buffer = ExperienceBuffer()

        # Group-level bookkeeping for relative advantage computation
        self.group_rewards = []
        self.agent_id = None  # set via set_agent_id when joining a group

    def set_agent_id(self, agent_id: str) -> None:
        """Set the agent's ID within the group (key into ``group_rewards``)."""
        self.agent_id = agent_id

    def _state_batch(self, state: np.ndarray) -> torch.Tensor:
        """Convert one state into a float32 tensor with a leading batch dim of 1 on the agent's device."""
        return torch.as_tensor(np.asarray(state, dtype=np.float32), device=self.device).unsqueeze(0)

    def get_action(self, state: np.ndarray, deterministic: bool = False) -> Tuple[int, float]:
        """Choose an action for ``state``.

        Args:
            state: The current state.
            deterministic: If True return the most likely action, otherwise
                sample from the policy distribution.

        Returns:
            Tuple of (action index, log probability of that action).
        """
        with torch.no_grad():
            logits = self.policy(self._state_batch(state))
            distribution = Categorical(logits=logits)
            if deterministic:
                action_tensor = torch.argmax(logits, dim=-1)
            else:
                action_tensor = distribution.sample()
            # action_tensor lives on the policy's device, so log_prob never mixes devices.
            log_prob = distribution.log_prob(action_tensor).item()
        return int(action_tensor.item()), log_prob

    def get_value(self, state: np.ndarray) -> float:
        """Return the value network's estimate for a single ``state``."""
        with torch.no_grad():
            return self.value(self._state_batch(state)).item()

    def learn(
        self,
        batch_size: int = 64,
        epochs: int = 10,
        group_rewards: Optional[Dict[str, List[float]]] = None,
    ) -> Dict[str, float]:
        """Update the policy and value networks from the buffered experience.

        Args:
            batch_size: Minibatch size. Training is skipped (all-zero metrics
                returned, buffer kept) while the buffer holds fewer experiences.
            epochs: Maximum number of passes over the buffered data.
            group_rewards: Rewards collected by every agent in the group,
                keyed by agent ID; used for relative advantages.

        Returns:
            Mean ``policy_loss``, ``value_loss``, ``entropy`` and ``kl`` over
            the minibatch updates that were actually completed.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")

        metrics = {"policy_loss": 0.0, "value_loss": 0.0, "entropy": 0.0, "kl": 0.0}
        if len(self.buffer) < batch_size:
            return metrics

        states, actions, old_log_probs, returns, advantages = self._prepare_training_data(
            group_rewards
        )
        num_samples = states.shape[0]
        num_updates = 0
        stop_early = False

        for _ in range(epochs):
            permutation = torch.randperm(num_samples, device=states.device)
            for start in range(0, num_samples, batch_size):
                idx = permutation[start:start + batch_size]

                policy_loss, entropy, kl = self._update_policy(
                    states[idx], actions[idx], old_log_probs[idx], advantages[idx]
                )
                # Hard stop: the policy moved too far from the data-collecting policy.
                if kl > 1.5 * self.target_kl:
                    stop_early = True
                    break

                value_loss = self._update_value(states[idx], returns[idx])

                metrics["policy_loss"] += policy_loss
                metrics["value_loss"] += value_loss
                metrics["entropy"] += entropy
                metrics["kl"] += kl
                num_updates += 1

            if stop_early:
                break
            # Soft stop: mean KL per completed update exceeds the target.
            if num_updates and metrics["kl"] / num_updates > self.target_kl:
                break

        # Average over updates that actually ran (early stopping included).
        if num_updates:
            for key in metrics:
                metrics[key] /= num_updates

        # On-policy: discard the consumed experience.
        self.buffer.buffer.clear()
        return metrics

    def _prepare_training_data(
        self, group_rewards: Optional[Dict[str, List[float]]] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Turn the buffered experience into training tensors.

        The buffer is treated as one time-ordered sequence of transitions.

        Args:
            group_rewards: Rewards collected by all agents in the group.

        Returns:
            Tuple of (states, actions, old_log_probs, returns, advantages).
        """
        experiences = list(self.buffer.buffer)

        states = torch.as_tensor(
            np.asarray([exp.state for exp in experiences], dtype=np.float32), device=self.device
        )
        next_states = torch.as_tensor(
            np.asarray([exp.next_state for exp in experiences], dtype=np.float32), device=self.device
        )
        actions = torch.as_tensor(
            [int(exp.action) for exp in experiences], dtype=torch.long, device=self.device
        )
        rewards = torch.as_tensor(
            [float(exp.reward) for exp in experiences], dtype=torch.float32, device=self.device
        )
        dones = torch.as_tensor(
            [float(exp.done) for exp in experiences], dtype=torch.float32, device=self.device
        )

        with torch.no_grad():
            # squeeze(-1) (not squeeze()) keeps a 1-D tensor even for a single sample.
            values = self.value(states).squeeze(-1)
            next_values = self.value(next_states).squeeze(-1)
            # One batched forward pass for the log-probs of the taken actions.
            old_log_probs = Categorical(logits=self.policy(states)).log_prob(actions)

        if self.use_gae:
            advantages = self._compute_gae(rewards, values, next_values, dones)
        else:
            # One-step TD advantages
            advantages = rewards + self.gamma * next_values * (1 - dones) - values

        # Value targets use the agent's own (un-mixed) advantages.
        returns = advantages + values

        if group_rewards is not None and self.agent_id in group_rewards:
            relative_advantages = self._compute_relative_advantages(advantages, group_rewards)
            weight = self.relative_advantage_weight
            advantages = (1 - weight) * advantages + weight * relative_advantages

        # std() of a single element is NaN, so only normalize with 2+ samples.
        if self.normalize_advantages and advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        return states, actions, old_log_probs, returns, advantages

    def _compute_gae(
        self, rewards: torch.Tensor, values: torch.Tensor,
        next_values: torch.Tensor, dones: torch.Tensor
    ) -> torch.Tensor:
        """Compute advantages with Generalized Advantage Estimation.

        Args:
            rewards: Per-step rewards.
            values: Value estimates of the states.
            next_values: Value estimates of the next states.
            dones: 1.0 where the episode ended at that step, else 0.0.

        Returns:
            Advantage estimates, same shape as ``rewards``.
        """
        advantages = torch.zeros_like(rewards)
        running = 0.0
        # Walk backwards so each step accumulates the discounted future TD errors;
        # a done flag cuts both bootstrapping and accumulation.
        for t in reversed(range(rewards.shape[0])):
            not_done = 1.0 - dones[t]
            delta = rewards[t] + self.gamma * next_values[t] * not_done - values[t]
            running = delta + self.gamma * self.gae_lambda * not_done * running
            advantages[t] = running
        return advantages

    def _compute_relative_advantages(
        self, advantages: torch.Tensor, group_rewards: Dict[str, List[float]]
    ) -> torch.Tensor:
        """Scale advantages by this agent's performance relative to the group.

        The factor is (agent mean reward) / (mean of per-agent mean rewards).
        It falls back to 1.0 when this agent has no entry or the group mean is
        (near) zero, where the ratio would be undefined or explode.
        NOTE(review): with a negative group mean the ratio's meaning inverts.

        Args:
            advantages: This agent's advantages.
            group_rewards: Rewards collected by all agents in the group.

        Returns:
            Relative advantages, same shape as ``advantages``.
        """
        agent_mean_rewards = {
            agent_id: sum(rewards) / max(1, len(rewards))
            for agent_id, rewards in group_rewards.items()
        }

        relative_factor = 1.0
        if self.agent_id in agent_mean_rewards:
            group_mean_reward = sum(agent_mean_rewards.values()) / len(agent_mean_rewards)
            if abs(group_mean_reward) > 1e-8:
                relative_factor = agent_mean_rewards[self.agent_id] / group_mean_reward

        return advantages * relative_factor

    def _compute_log_prob(self, state: np.ndarray, action: int) -> float:
        """Return the current policy's log probability of ``action`` in ``state``."""
        with torch.no_grad():
            distribution = Categorical(logits=self.policy(self._state_batch(state)))
            action_tensor = torch.as_tensor([int(action)], device=self.device)
            return distribution.log_prob(action_tensor).item()

    def _update_policy(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        old_log_probs: torch.Tensor,
        advantages: torch.Tensor
    ) -> Tuple[float, float, float]:
        """Take one clipped-PPO gradient step on the policy network.

        Args:
            states: Batch of states.
            actions: Batch of actions taken.
            old_log_probs: Log probabilities of ``actions`` before training.
            advantages: Batch of advantages.

        Returns:
            Tuple of (policy_loss, entropy, approximate KL divergence).
        """
        distribution = Categorical(logits=self.policy(states))
        new_log_probs = distribution.log_prob(actions)

        ratio = torch.exp(new_log_probs - old_log_probs)
        surrogate1 = ratio * advantages
        surrogate2 = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * advantages
        # Negative because the optimizer minimizes while PPO maximizes the surrogate.
        policy_loss = -torch.min(surrogate1, surrogate2).mean()

        entropy = distribution.entropy().mean()
        loss = policy_loss - self.entropy_coef * entropy

        with torch.no_grad():
            kl = (old_log_probs - new_log_probs).mean().item()

        self.policy_optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.policy_optimizer.step()

        return policy_loss.item(), entropy.item(), kl

    def _update_value(self, states: torch.Tensor, returns: torch.Tensor) -> float:
        """Take one MSE gradient step on the value network.

        Args:
            states: Batch of states.
            returns: Batch of value targets.

        Returns:
            The value loss before the step.
        """
        # squeeze(-1) keeps shape (batch,) so it matches ``returns`` even for batch size 1.
        values = self.value(states).squeeze(-1)
        value_loss = F.mse_loss(values, returns)

        self.value_optimizer.zero_grad()
        value_loss.backward()
        nn.utils.clip_grad_norm_(self.value.parameters(), self.max_grad_norm)
        self.value_optimizer.step()

        return value_loss.item()
class MultiAgentGroupRL:
    """Multi-agent reinforcement learning system using GRPO for Claude Code.

    Manages several GRPO agents that are trained together: every agent's
    update receives the rewards collected by all agents since the last update,
    which GRPO uses for its relative advantage term.
    """

    def __init__(
        self,
        agent_configs: List[Dict[str, Any]],
        feature_extractor: Callable[[Dict[str, Any]], np.ndarray],
        reward_function: Callable[[Dict[str, Any], str, Any], float],
        update_interval: int = 1000,
        training_epochs: int = 10,
        batch_size: int = 64,
        save_dir: str = "./models",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """Initialize the multi-agent RL system.

        Args:
            agent_configs: One dict per agent. Requires ``id``, ``state_dim``
                and ``action_dim``; optional ``hidden_dims`` and ``device``.
                Every other key is forwarded to ``GRPO`` as a keyword argument.
            feature_extractor: Maps an observation dict to a state vector.
            reward_function: Stored on the instance; not called by this class.
            update_interval: Train all agents every this many observed steps
                (a value <= 0 disables automatic updates).
            training_epochs: Number of epochs per update.
            batch_size: Minibatch size for training.
            save_dir: Directory used by ``save_agents``/``load_agents``.
            device: Default device for agents without a ``device`` config key.
        """
        self.feature_extractor = feature_extractor
        self.reward_function = reward_function
        self.update_interval = update_interval
        self.training_epochs = training_epochs
        self.batch_size = batch_size
        self.save_dir = save_dir
        self.device = device

        # Keys consumed explicitly below; everything else goes to GRPO as kwargs.
        # "device" is included so a config-level device cannot collide with device=.
        reserved_keys = {"id", "state_dim", "action_dim", "hidden_dims", "device"}

        self.agents = {}
        for config in agent_configs:
            agent_id = config["id"]
            extra_kwargs = {k: v for k, v in config.items() if k not in reserved_keys}
            agent = GRPO(
                state_dim=config["state_dim"],
                action_dim=config["action_dim"],
                hidden_dims=config.get("hidden_dims", [64, 64]),
                device=config.get("device", device),
                **extra_kwargs,
            )
            agent.set_agent_id(agent_id)
            self.agents[agent_id] = agent

        # Steps observed across all agents; drives periodic updates.
        self.total_steps = 0

        # Rewards per agent since the last update (for relative advantages).
        self.agent_rewards = {agent_id: [] for agent_id in self.agents}

    def select_action(
        self, agent_id: str, observation: Dict[str, Any], deterministic: bool = False
    ) -> Tuple[Any, float]:
        """Select an action for the given agent.

        Args:
            agent_id: ID of the agent.
            observation: Current observation.
            deterministic: Select the most likely action instead of sampling.

        Returns:
            Tuple of (action, log probability).

        Raises:
            ValueError: If ``agent_id`` is unknown.
        """
        if agent_id not in self.agents:
            raise ValueError(f"Unknown agent ID: {agent_id}")

        state = self.feature_extractor(observation)
        action, log_prob = self.agents[agent_id].get_action(state, deterministic)
        return action, log_prob

    def observe(
        self,
        agent_id: str,
        observation: Dict[str, Any],
        action: Any,
        reward: float,
        next_observation: Dict[str, Any],
        done: bool,
        info: Optional[Dict[str, Any]] = None
    ) -> None:
        """Record a transition for the given agent and train periodically.

        Args:
            agent_id: ID of the agent.
            observation: Observation before the action.
            action: Action taken.
            reward: Reward received.
            next_observation: Observation after the action.
            done: Whether the episode is done.
            info: Additional information stored with the experience.

        Raises:
            ValueError: If ``agent_id`` is unknown.
        """
        if agent_id not in self.agents:
            raise ValueError(f"Unknown agent ID: {agent_id}")

        experience = Experience(
            state=self.feature_extractor(observation),
            action=action,
            reward=reward,
            next_state=self.feature_extractor(next_observation),
            done=done,
            info=info,
        )
        self.agents[agent_id].buffer.add(experience)

        # Keep the reward for relative advantage computation at the next update.
        self.agent_rewards.setdefault(agent_id, []).append(reward)

        self.total_steps += 1
        # Guard against modulo by zero when automatic updates are disabled.
        if self.update_interval > 0 and self.total_steps % self.update_interval == 0:
            self.update_all_agents()

    def update_all_agents(self) -> Dict[str, Dict[str, float]]:
        """Train every agent using the group's collected rewards.

        Returns:
            Training metrics keyed by agent ID.
        """
        metrics = {}
        for agent_id, agent in self.agents.items():
            metrics[agent_id] = agent.learn(
                batch_size=self.batch_size,
                epochs=self.training_epochs,
                group_rewards=self.agent_rewards,
            )

        # Start a fresh reward window for the next update.
        self.agent_rewards = {agent_id: [] for agent_id in self.agents}
        return metrics

    def save_agents(self, suffix: str = "") -> None:
        """Save every agent's networks and optimizers to ``save_dir``.

        Each agent is written to ``<save_dir>/<agent_id><suffix>.pt``.

        Args:
            suffix: Optional suffix appended to each file name.
        """
        import os

        os.makedirs(self.save_dir, exist_ok=True)
        for agent_id, agent in self.agents.items():
            file_path = os.path.join(self.save_dir, f"{agent_id}{suffix}.pt")
            torch.save({
                "policy_state_dict": agent.policy.state_dict(),
                "value_state_dict": agent.value.state_dict(),
                "policy_optimizer_state_dict": agent.policy_optimizer.state_dict(),
                "value_optimizer_state_dict": agent.value_optimizer.state_dict(),
            }, file_path)

    def load_agents(self, suffix: str = "") -> None:
        """Load every agent's networks and optimizers from ``save_dir``.

        Agents whose checkpoint file is missing are skipped with a warning.

        Args:
            suffix: Optional suffix appended to each file name.
        """
        import os

        for agent_id, agent in self.agents.items():
            file_path = os.path.join(self.save_dir, f"{agent_id}{suffix}.pt")
            if not os.path.exists(file_path):
                print(f"Warning: Model file not found for agent {agent_id}")
                continue

            # NOTE(review): torch.load is pickle-based; only load checkpoints
            # from trusted locations.
            checkpoint = torch.load(file_path, map_location=self.device)
            agent.policy.load_state_dict(checkpoint["policy_state_dict"])
            agent.value.load_state_dict(checkpoint["value_state_dict"])
            agent.policy_optimizer.load_state_dict(checkpoint["policy_optimizer_state_dict"])
            agent.value_optimizer.load_state_dict(checkpoint["value_optimizer_state_dict"])
class ToolSelectionGRPO:
    """GRPO-based tool selector for Claude Code.

    Wraps ``MultiAgentGroupRL`` with one discrete action per registered tool.
    Agents are used round-robin, and rewards come from the change in context
    quality reported by ``context_evaluator``.
    """

    def __init__(
        self,
        tool_registry: Any,  # Should be a reference to the tool registry
        context_evaluator: Callable,  # Function to evaluate quality of response given context
        state_dim: int = 768,  # Embedding dimension for query
        num_agents: int = 3,  # Number of agents in the group
        update_interval: int = 100,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """Initialize the GRPO tool selector.

        Args:
            tool_registry: Registry exposing ``get_all_tool_names()``.
            context_evaluator: Maps a context dict to a quality score.
            state_dim: Dimension of state features.
            num_agents: Number of agents in the group (at least 1).
            update_interval: Train all agents every this many observed steps.
            device: Device to run on.

        Raises:
            ValueError: If ``num_agents`` is less than 1.
        """
        if num_agents < 1:
            raise ValueError("num_agents must be at least 1")

        self.tool_registry = tool_registry
        self.context_evaluator = context_evaluator
        self.state_dim = state_dim

        # One action per available tool; list() so indices can be looked up.
        self.tool_names = list(tool_registry.get_all_tool_names())
        self.action_dim = len(self.tool_names)

        # Agents differ in relative-advantage weight and exploration (entropy) rate.
        agent_configs = [
            {
                "id": f"tool_agent_{i}",
                "state_dim": state_dim,
                "action_dim": self.action_dim,
                "hidden_dims": [256, 128],
                "relative_advantage_weight": 0.7 if i > 0 else 0.3,
                "entropy_coef": 0.02 if i == 0 else 0.01,
            }
            for i in range(num_agents)
        ]

        self.rl_system = MultiAgentGroupRL(
            agent_configs=agent_configs,
            feature_extractor=self._extract_features,
            reward_function=self._compute_reward,
            update_interval=update_interval,
            device=device,
        )

        # Pending (not yet observed) selection per agent; {} means none.
        self.current_episode = {agent_id: {} for agent_id in self.rl_system.agents}

    def select_tool(self, user_query: str, context: Dict[str, Any], visualizer=None) -> str:
        """Select a tool for ``user_query`` in ``context``.

        The chosen agent's selection is stored in ``current_episode`` so that
        ``observe_result`` can reward it later.

        Args:
            user_query: The user's query.
            context: The current conversation context.
            visualizer: Optional visualizer to display the selection process.

        Returns:
            The name of the selected tool.
        """
        observation = {
            "query": user_query,
            "context": context,
        }

        # NOTE(review): the visualizer methods that open/close this visualization
        # were lost from this file; add_tool_call / complete_tool_call are assumed
        # (only update_progress is attested) — confirm against the visualizer API.
        if visualizer:
            visualizer.add_tool_call(
                "tool_selection",
                tool_name="GRPO Tool Selection",
                parameters={"query": user_query[:100] + "..." if len(user_query) > 100 else user_query},
            )

        # Round-robin over agents, keyed on the total number of observed steps.
        agent_id = f"tool_agent_{self.rl_system.total_steps % len(self.rl_system.agents)}"

        if visualizer:
            visualizer.update_progress("tool_selection", 0.3)

        action_idx, _ = self.rl_system.select_action(
            agent_id=agent_id,
            observation=observation,
            deterministic=False,  # Use exploratory actions during learning
        )

        if visualizer:
            visualizer.update_progress("tool_selection", 0.6)

        self.current_episode[agent_id] = {
            "observation": observation,
            "action_idx": action_idx,
            "initial_quality": self.context_evaluator(context),
        }

        tool_name = self.tool_names[action_idx]

        if visualizer:
            visualizer.complete_tool_call(
                "tool_selection",
                {
                    "selected_tool": tool_name,
                    "selected_agent": agent_id,
                    "agent_data": self._agent_tool_probabilities(observation, agent_id),
                },
            )

        return tool_name

    def _agent_tool_probabilities(
        self, observation: Dict[str, Any], selected_agent_id: str
    ) -> Dict[str, Dict[str, Any]]:
        """Return every agent's tool distribution for ``observation`` (visualization only)."""
        # Extract once so all agents are compared on the same features.
        state = self.rl_system.feature_extractor(observation)
        agent_data = {}
        for aid, agent in self.rl_system.agents.items():
            with torch.no_grad():
                state_tensor = torch.as_tensor(
                    np.asarray(state, dtype=np.float32), device=agent.device
                ).unsqueeze(0)
                # squeeze(0) keeps a 1-D array even when there is a single tool.
                probs = F.softmax(agent.policy(state_tensor), dim=-1).squeeze(0).cpu().numpy()
            agent_data[aid] = {
                "selected": aid == selected_agent_id,
                "tool_probabilities": {
                    name: float(prob) for name, prob in zip(self.tool_names, probs)
                },
            }
        return agent_data

    def observe_result(
        self, agent_id: str, result: Any, context: Dict[str, Any], done: bool = True
    ) -> None:
        """Reward the agent's pending tool selection with the observed result.

        Does nothing if the agent has no pending selection.

        Args:
            agent_id: The ID of the agent that selected the tool.
            result: The result of using the tool.
            context: The updated context after using the tool.
            done: Whether the interaction is complete (clears the pending selection).
        """
        episode = self.current_episode.get(agent_id)
        if not episode:
            return

        observation = episode["observation"]
        action_idx = episode["action_idx"]
        initial_quality = episode["initial_quality"]

        next_observation = {
            "query": observation["query"],
            "context": context,
            "result": result,
        }

        reward = self._compute_reward(observation, action_idx, result, context, initial_quality)

        self.rl_system.observe(
            agent_id=agent_id,
            observation=observation,
            action=action_idx,
            reward=reward,
            next_observation=next_observation,
            done=done,
        )

        if done:
            self.current_episode[agent_id] = {}

    def _extract_features(self, observation: Dict[str, Any]) -> np.ndarray:
        """Extract features from an observation.

        Placeholder: returns a random vector of length ``state_dim`` and ignores
        the observation. This would ideally use an embedding model.
        """
        return np.random.randn(self.state_dim)

    def _compute_reward(
        self,
        observation: Dict[str, Any],
        action_idx: int,
        result: Any,
        context: Dict[str, Any],
        initial_quality: float
    ) -> float:
        """Return ``max(0, 10 * quality improvement)`` for a tool use.

        The improvement is ``context_evaluator(context) - initial_quality``;
        the x10 scale is meant to aid learning.
        """
        final_quality = self.context_evaluator(context)
        quality_improvement = final_quality - initial_quality
        return max(0, quality_improvement * 10)

    def update(self) -> Dict[str, Dict[str, float]]:
        """Trigger an update of all agents and return their training metrics."""
        return self.rl_system.update_all_agents()

    def save(self, suffix: str = "") -> None:
        """Save all agents."""
        self.rl_system.save_agents(suffix)

    def load(self, suffix: str = "") -> None:
        """Load all agents."""
        self.rl_system.load_agents(suffix)