from typing import Any, Dict, List, Optional, Union, Set
import grpc
import json
import logging
import asyncio
import os
import time
from datetime import datetime
from mcp.server.fastmcp import FastMCP
# Initialize the FastMCP server that exposes every @mcp.tool() below.
mcp = FastMCP("sushi-control")

# Logging: timestamped records on the root logger; this module logs as "sushi-mcp".
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger("sushi-mcp")

# Address of the Sushi gRPC server. Both values are rewritten at runtime by the
# set_sushi_server tool; the port is kept as a string because it is only
# interpolated into a "host:port" target.
sushi_ip = "localhost"
sushi_port = "51051"

# Real-time parameter update state:
#   subscribed_to_updates   - whether the background listener should run
#   parameter_subscriptions - watched "processor_id:parameter_id" keys
#   notification_callbacks  - latest update payload per watched key
subscribed_to_updates = False
parameter_subscriptions: Set[str] = set()
notification_callbacks: Dict[str, Dict[str, Any]] = {}

# Preset storage, relative to the current working directory; created eagerly
# at import time so the preset tools can write into it.
PRESETS_DIR = "presets"
os.makedirs(PRESETS_DIR, exist_ok=True)
# Load the protobuf modules (these would be generated from the proto file)
import sushi_rpc_pb2
import sushi_rpc_pb2_grpc
# Helper: open a gRPC channel to the configured Sushi server.
def get_channel():
    """Open an insecure gRPC channel to the currently configured Sushi server.

    The target is read from the module-level ``sushi_ip``/``sushi_port``
    globals at call time, so a change made by ``set_sushi_server`` applies to
    the next call. Callers in this module use the result as a context manager
    so the channel is closed when they are done with it.
    """
    target = f"{sushi_ip}:{sushi_port}"
    logger.info(f"Connecting to Sushi at {target}")
    channel = grpc.insecure_channel(target)
    return channel
# Load the Sushi JSON configuration used for reference (e.g. by
# get_available_plugins). A missing or malformed file is not fatal: ``config``
# stays an empty dict and the error is logged.
config: Dict[str, Any] = {}
try:
    # JSON text is UTF-8 by specification; don't depend on the locale encoding.
    with open('default_config_sushi.json', 'r', encoding='utf-8') as config_file:
        config = json.load(config_file)
    logger.info("Loaded Sushi configuration file")
except (OSError, ValueError) as e:
    # OSError: file missing or unreadable. ValueError covers both
    # json.JSONDecodeError and UnicodeDecodeError for malformed content.
    logger.error(f"Failed to load configuration file: {e}")
    config = {}
# MCP Tool to set the Sushi server IP address
@mcp.tool()
async def set_sushi_server(ip: str, port: str = "51051") -> str:
    """Set the IP address and port of the Sushi server.

    Args:
        ip: IP address of the Sushi server (e.g., "192.168.1.100")
        port: Port of the Sushi gRPC server (default: "51051")
    """
    global sushi_ip, sushi_port
    # The new address is stored even if the probe below fails, so a server can
    # be configured before it is reachable.
    sushi_ip = ip
    sushi_port = port
    # Deadline for the connectivity probe. Without one, a blocking call to an
    # unreachable host can hang this tool (and the event loop) for a long time.
    probe_timeout_s = 5.0
    try:
        with get_channel() as channel:
            stub = sushi_rpc_pb2_grpc.SystemControllerStub(channel)
            response = stub.GetSushiVersion(
                sushi_rpc_pb2.GenericVoidValue(), timeout=probe_timeout_s
            )
            return f"Successfully connected to Sushi version {response.value} at {ip}:{port}"
    except Exception as e:
        # Best effort: report the failure to the client instead of raising.
        logger.error(f"Failed to connect to Sushi server: {e}")
        return f"Failed to connect to Sushi server at {ip}:{port}: {str(e)}"
# System Controller Tools
@mcp.tool()
async def get_sushi_info() -> Dict[str, Any]:
    """Get information about the Sushi system."""
    # Four unary queries on one channel: version, build info, and the engine's
    # input/output audio channel counts.
    with get_channel() as channel:
        system = sushi_rpc_pb2_grpc.SystemControllerStub(channel)
        version = system.GetSushiVersion(sushi_rpc_pb2.GenericVoidValue())
        build = system.GetBuildInfo(sushi_rpc_pb2.GenericVoidValue())
        n_inputs = system.GetInputAudioChannelCount(sushi_rpc_pb2.GenericVoidValue())
        n_outputs = system.GetOutputAudioChannelCount(sushi_rpc_pb2.GenericVoidValue())

        info: Dict[str, Any] = {}
        info["version"] = version.value
        info["build_options"] = list(build.build_options)
        info["audio_buffer_size"] = build.audio_buffer_size
        info["commit_hash"] = build.commit_hash
        info["build_date"] = build.build_date
        info["input_channels"] = n_inputs.value
        info["output_channels"] = n_outputs.value
        return info
# Transport Controller Tools
@mcp.tool()
async def get_transport_info() -> Dict[str, Any]:
    """Get information about the Sushi transport state."""

    def enum_name(message: Any, field: str) -> str:
        # Resolve an enum field via the message's own proto descriptor, so the
        # names always match the generated sushi_rpc_pb2 module instead of a
        # hand-maintained table.
        # NOTE(review): the old sync table mapped 3 -> "LINK"; Sushi's SyncMode
        # appears to also define a GATE mode, which that table did not know.
        number = getattr(message, field)
        field_descriptor = message.DESCRIPTOR.fields_by_name[field]
        enum_value = None
        if field_descriptor.enum_type is not None:
            enum_value = field_descriptor.enum_type.values_by_number.get(number)
        if enum_value is None:
            # Values the proto does not declare keep the old "UNKNOWN (n)" form.
            return f"UNKNOWN ({number})"
        return enum_value.name

    with get_channel() as channel:
        transport_controller = sushi_rpc_pb2_grpc.TransportControllerStub(channel)
        samplerate = transport_controller.GetSamplerate(sushi_rpc_pb2.GenericVoidValue())
        playing_mode = transport_controller.GetPlayingMode(sushi_rpc_pb2.GenericVoidValue())
        sync_mode = transport_controller.GetSyncMode(sushi_rpc_pb2.GenericVoidValue())
        time_signature = transport_controller.GetTimeSignature(sushi_rpc_pb2.GenericVoidValue())
        tempo = transport_controller.GetTempo(sushi_rpc_pb2.GenericVoidValue())
        return {
            "samplerate": samplerate.value,
            "playing_mode": enum_name(playing_mode, "mode"),
            "sync_mode": enum_name(sync_mode, "mode"),
            "time_signature": f"{time_signature.numerator}/{time_signature.denominator}",
            "tempo": tempo.value,
        }
@mcp.tool()
async def set_tempo(tempo: float) -> str:
    """Set the tempo of Sushi.

    Args:
        tempo: Tempo in BPM
    """
    # Single unary call; any gRPC error propagates to the caller.
    with get_channel() as channel:
        request = sushi_rpc_pb2.GenericFloatValue(value=tempo)
        sushi_rpc_pb2_grpc.TransportControllerStub(channel).SetTempo(request)
    return f"Tempo set to {tempo} BPM"
@mcp.tool()
async def set_playing_mode(mode: str) -> str:
    """Set the playing mode of Sushi.

    Args:
        mode: Playing mode (STOPPED, PLAYING, RECORDING); case-insensitive
    """
    mode_map = {
        "STOPPED": sushi_rpc_pb2.PlayingMode.Mode.STOPPED,
        "PLAYING": sushi_rpc_pb2.PlayingMode.Mode.PLAYING,
        "RECORDING": sushi_rpc_pb2.PlayingMode.Mode.RECORDING
    }
    # Accept "playing", " Playing " etc., the same way add_processor_to_track
    # normalizes its plugin_type argument.
    normalized = mode.strip().upper()
    if normalized not in mode_map:
        return f"Invalid mode: {mode}. Valid modes are: {', '.join(mode_map.keys())}"
    with get_channel() as channel:
        transport_controller = sushi_rpc_pb2_grpc.TransportControllerStub(channel)
        transport_controller.SetPlayingMode(
            sushi_rpc_pb2.PlayingMode(mode=mode_map[normalized])
        )
    return f"Playing mode set to {normalized}"
# Track and Processor Management Tools
@mcp.tool()
async def get_all_tracks() -> List[Dict[str, Any]]:
    """Get information about all tracks in Sushi."""
    # Track type enum number -> display name; built once rather than per track.
    track_type_names = {1: "REGULAR", 2: "PRE", 3: "POST"}
    with get_channel() as channel:
        graph = sushi_rpc_pb2_grpc.AudioGraphControllerStub(channel)
        reply = graph.GetAllTracks(sushi_rpc_pb2.GenericVoidValue())
        return [
            {
                "id": t.id,
                "name": t.name,
                "label": t.label,
                "channels": t.channels,
                "buses": t.buses,
                "type": track_type_names.get(t.type.type, f"UNKNOWN ({t.type.type})"),
                "processor_count": len(t.processors),
            }
            for t in reply.tracks
        ]
@mcp.tool()
async def get_track_processors(track_id: int) -> List[Dict[str, Any]]:
    """Get all processors on a track.

    Args:
        track_id: ID of the track
    """
    with get_channel() as channel:
        graph = sushi_rpc_pb2_grpc.AudioGraphControllerStub(channel)
        listing = graph.GetTrackProcessors(sushi_rpc_pb2.TrackIdentifier(id=track_id))

        def describe(proc_id: int) -> Dict[str, Any]:
            # Two extra round-trips per processor: static info, then bypass flag.
            ident = sushi_rpc_pb2.ProcessorIdentifier(id=proc_id)
            info = graph.GetProcessorInfo(ident)
            bypass = graph.GetProcessorBypassState(ident)
            return {
                "id": proc_id,
                "name": info.name,
                "label": info.label,
                "parameter_count": info.parameter_count,
                "program_count": info.program_count,
                "bypassed": bypass.value,
            }

        return [describe(p.id) for p in listing.processors]
@mcp.tool()
async def get_processor_parameters(processor_id: int) -> List[Dict[str, Any]]:
    """Get all parameters for a processor.

    Args:
        processor_id: ID of the processor
    """
    # Parameter type enum number -> display name; built once, not per parameter.
    param_type_names = {1: "BOOL", 2: "INT", 3: "FLOAT"}
    with get_channel() as channel:
        params = sushi_rpc_pb2_grpc.ParameterControllerStub(channel)
        listing = params.GetProcessorParameters(
            sushi_rpc_pb2.ProcessorIdentifier(id=processor_id)
        )
        result: List[Dict[str, Any]] = []
        for info in listing.parameters:
            ident = sushi_rpc_pb2.ParameterIdentifier(
                processor_id=processor_id,
                parameter_id=info.id,
            )
            # Numeric value first, then Sushi's formatted display string.
            # NOTE(review): current_value comes from GetParameterValue while
            # min/max are the *domain* bounds; if GetParameterValue is
            # normalized (0-1) these are in different units - confirm.
            current = params.GetParameterValue(ident)
            as_text = params.GetParameterValueAsString(ident)
            type_number = info.type.type
            result.append({
                "id": info.id,
                "name": info.name,
                "label": info.label,
                "type": param_type_names.get(type_number, f"UNKNOWN ({type_number})"),
                "unit": info.unit,
                "automatable": info.automatable,
                "min_value": info.min_domain_value,
                "max_value": info.max_domain_value,
                "current_value": current.value,
                "string_value": as_text.value,
            })
        return result
@mcp.tool()
async def set_parameter_value(processor_id: int, parameter_id: int, value: float) -> str:
    """Set a parameter value on a processor.

    Args:
        processor_id: ID of the processor
        parameter_id: ID of the parameter
        value: New value for the parameter
    """
    param_ident = sushi_rpc_pb2.ParameterIdentifier(
        processor_id=processor_id,
        parameter_id=parameter_id
    )
    with get_channel() as channel:
        parameter_controller = sushi_rpc_pb2_grpc.ParameterControllerStub(channel)
        # Validate against the parameter's declared range with one
        # GetParameterInfo call on this channel. The previous version called
        # get_processor_parameters(), which opened a second channel and made
        # two extra RPCs for every parameter of the processor.
        # NOTE(review): min/max are the *domain* bounds; confirm that
        # SetParameterValue expects domain values rather than normalized 0-1.
        try:
            param_info = parameter_controller.GetParameterInfo(param_ident)
        except Exception as e:
            # Best effort: if the range cannot be fetched, still attempt the set.
            logger.warning(f"Could not validate parameter value range: {e}")
        else:
            low = param_info.min_domain_value
            high = param_info.max_domain_value
            # "not within" (rather than "< low or > high") also rejects NaN.
            if not (low <= value <= high):
                return (f"Value {value} out of range for parameter {param_info.name}. "
                        f"Valid range: {low} to {high}")
        parameter_controller.SetParameterValue(
            sushi_rpc_pb2.ParameterValue(parameter=param_ident, value=value)
        )
        # Read back the value to confirm.
        # NOTE(review): if Sushi applies parameter changes asynchronously, this
        # may still report the previous value.
        value_response = parameter_controller.GetParameterValueAsString(param_ident)
        return f"Parameter set to {value_response.value}"
@mcp.tool()
async def bypass_processor(processor_id: int, bypassed: bool = True) -> str:
    """Bypass or enable a processor.

    Args:
        processor_id: ID of the processor
        bypassed: True to bypass, False to enable
    """
    with get_channel() as channel:
        graph = sushi_rpc_pb2_grpc.AudioGraphControllerStub(channel)
        target = sushi_rpc_pb2.ProcessorIdentifier(id=processor_id)
        # Fetch the processor name for the reply, then apply the bypass state.
        info = graph.GetProcessorInfo(target)
        graph.SetProcessorBypassState(
            sushi_rpc_pb2.ProcessorBypassStateSetRequest(processor=target, value=bypassed)
        )
    verb = "enabled"
    if bypassed:
        verb = "bypassed"
    return f"Processor '{info.name}' {verb}"
@mcp.tool()
async def add_processor_to_track(
    track_id: int,
    processor_name: str,
    processor_uid: str,
    plugin_type: str = "internal",
    path: str = ""
) -> str:
    """Add a processor to a track.

    Args:
        track_id: ID of the track
        processor_name: Name for the new processor
        processor_uid: UID of the processor (e.g., "sushi.testing.gain")
        plugin_type: Type of plugin (internal, vst2, vst3, lv2)
        path: Path to the plugin file (optional for internal plugins)
    """
    plugin_type_map = {
        "internal": sushi_rpc_pb2.PluginType.Type.INTERNAL,
        "vst2": sushi_rpc_pb2.PluginType.Type.VST2X,
        "vst3": sushi_rpc_pb2.PluginType.Type.VST3X,
        "lv2": sushi_rpc_pb2.PluginType.Type.LV2
    }
    # Normalize once: case and surrounding whitespace are not significant.
    kind = plugin_type.strip().lower()
    if kind not in plugin_type_map:
        return f"Invalid plugin type: {plugin_type}. Valid types are: {', '.join(plugin_type_map.keys())}"
    with get_channel() as channel:
        audio_graph_controller = sushi_rpc_pb2_grpc.AudioGraphControllerStub(channel)
        audio_graph_controller.CreateProcessorOnTrack(
            sushi_rpc_pb2.CreateProcessorRequest(
                name=processor_name,
                uid=processor_uid,
                # Previously hard-coded to "", which left no way to pass the
                # file path of an external (vst2/vst3/lv2) plugin.
                path=path,
                type=sushi_rpc_pb2.PluginType(type=plugin_type_map[kind]),
                track=sushi_rpc_pb2.TrackIdentifier(id=track_id),
                # Append to the back of the track's processor chain.
                position=sushi_rpc_pb2.ProcessorPosition(add_to_back=True)
            )
        )
    return f"Processor '{processor_name}' added to track ID {track_id}"
@mcp.tool()
async def remove_processor_from_track(track_id: int, processor_id: int) -> str:
    """Remove a processor from a track.

    Args:
        track_id: ID of the track
        processor_id: ID of the processor to remove
    """
    with get_channel() as channel:
        graph = sushi_rpc_pb2_grpc.AudioGraphControllerStub(channel)
        proc_ref = sushi_rpc_pb2.ProcessorIdentifier(id=processor_id)
        # Capture the name before deleting, for the confirmation message.
        doomed = graph.GetProcessorInfo(proc_ref)
        graph.DeleteProcessorFromTrack(
            sushi_rpc_pb2.DeleteProcessorRequest(
                processor=proc_ref,
                track=sushi_rpc_pb2.TrackIdentifier(id=track_id),
            )
        )
    return f"Processor '{doomed.name}' removed from track ID {track_id}"
# Audio Routing Tools
@mcp.tool()
async def get_track_connections(track_id: int) -> Dict[str, List[Dict[str, Any]]]:
    """Get audio connections for a track.

    Args:
        track_id: ID of the track
    """

    def as_dicts(reply: Any) -> List[Dict[str, Any]]:
        # Flatten a connection-list reply into plain channel-pair dicts.
        return [
            {"track_channel": c.track_channel, "engine_channel": c.engine_channel}
            for c in reply.connections
        ]

    with get_channel() as channel:
        routing = sushi_rpc_pb2_grpc.AudioRoutingControllerStub(channel)
        track = sushi_rpc_pb2.TrackIdentifier(id=track_id)
        in_reply = routing.GetInputConnectionsForTrack(track)
        out_reply = routing.GetOutputConnectionsForTrack(track)
    return {"inputs": as_dicts(in_reply), "outputs": as_dicts(out_reply)}
@mcp.tool()
async def connect_input_to_track(track_id: int, track_channel: int, engine_channel: int) -> str:
    """Connect an input channel to a track.

    Args:
        track_id: ID of the track
        track_channel: Channel on the track (0-based)
        engine_channel: Engine input channel (0-based)
    """
    with get_channel() as channel:
        link = sushi_rpc_pb2.AudioConnection(
            track=sushi_rpc_pb2.TrackIdentifier(id=track_id),
            track_channel=track_channel,
            engine_channel=engine_channel,
        )
        sushi_rpc_pb2_grpc.AudioRoutingControllerStub(channel).ConnectInputChannelToTrack(link)
    return f"Connected engine input channel {engine_channel} to track ID {track_id}, channel {track_channel}"
@mcp.tool()
async def connect_track_to_output(track_id: int, track_channel: int, engine_channel: int) -> str:
    """Connect a track to an output channel.

    Args:
        track_id: ID of the track
        track_channel: Channel on the track (0-based)
        engine_channel: Engine output channel (0-based)
    """
    with get_channel() as channel:
        link = sushi_rpc_pb2.AudioConnection(
            track=sushi_rpc_pb2.TrackIdentifier(id=track_id),
            track_channel=track_channel,
            engine_channel=engine_channel,
        )
        sushi_rpc_pb2_grpc.AudioRoutingControllerStub(channel).ConnectOutputChannelFromTrack(link)
    return f"Connected track ID {track_id}, channel {track_channel} to engine output channel {engine_channel}"
# Available plugins helper
@mcp.tool()
async def get_available_plugins() -> Dict[str, Any]:
    """Get a list of available plugins from the config.

    Only internal plugins referenced by the loaded configuration file are listed.
    """
    if not config:
        return {"error": "Configuration not loaded"}
    # Plugin lists from the regular tracks plus the optional pre/post tracks.
    # `or {}` / `or []` also tolerate explicit JSON nulls.
    plugin_lists = [track.get("plugins") or [] for track in config.get("tracks") or []]
    for section in ("pre_track", "post_track"):
        plugin_lists.append((config.get(section) or {}).get("plugins") or [])
    # Skip entries without a string uid: a None among strings used to make
    # sorted() raise TypeError.
    internal_plugins = {
        plugin.get("uid")
        for plugins in plugin_lists
        for plugin in plugins
        if plugin.get("type") == "internal" and isinstance(plugin.get("uid"), str)
    }
    return {
        "internal_plugins": sorted(internal_plugins)
    }
# -------------------- MIDI CONTROLLER TOOLS --------------------
@mcp.tool()
async def get_midi_ports() -> Dict[str, int]:
    """Get the number of MIDI input and output ports."""
    with get_channel() as channel:
        midi = sushi_rpc_pb2_grpc.MidiControllerStub(channel)
        n_in = midi.GetInputPorts(sushi_rpc_pb2.GenericVoidValue()).value
        n_out = midi.GetOutputPorts(sushi_rpc_pb2.GenericVoidValue()).value
    return {"input_ports": n_in, "output_ports": n_out}
@mcp.tool()
async def get_midi_keyboard_connections() -> Dict[str, List[Dict[str, Any]]]:
    """Get all MIDI keyboard input and output connections."""
    with get_channel() as channel:
        midi_controller = sushi_rpc_pb2_grpc.MidiControllerStub(channel)
        audio_graph_controller = sushi_rpc_pb2_grpc.AudioGraphControllerStub(channel)
        kbd_inputs = midi_controller.GetAllKbdInputConnections(sushi_rpc_pb2.GenericVoidValue())
        kbd_outputs = midi_controller.GetAllKbdOutputConnections(sushi_rpc_pb2.GenericVoidValue())

        # Track names by id, so each track is looked up at most once even when
        # it appears in several keyboard connections.
        track_names: Dict[int, str] = {}

        def track_name(track: Any) -> str:
            if track.id not in track_names:
                try:
                    track_names[track.id] = audio_graph_controller.GetTrackInfo(track).name
                except Exception as e:
                    # Best effort: the name is cosmetic, so fall back to the id,
                    # but leave a trace instead of failing silently.
                    logger.debug(f"Could not look up track {track.id}: {e}")
                    track_names[track.id] = f"Track ID {track.id}"
            return track_names[track.id]

        def describe(conn: Any) -> Dict[str, Any]:
            # Same dict shape for input and output connections.
            return {
                "track_id": conn.track.id,
                "track_name": track_name(conn.track),
                "channel": conn.channel.channel,
                "port": conn.port,
                "raw_midi": conn.raw_midi
            }

        return {
            "inputs": [describe(conn) for conn in kbd_inputs.connections],
            "outputs": [describe(conn) for conn in kbd_outputs.connections]
        }
@mcp.tool()
async def get_midi_cc_connections() -> List[Dict[str, Any]]:
    """Get all MIDI CC connections."""
    with get_channel() as channel:
        midi_controller = sushi_rpc_pb2_grpc.MidiControllerStub(channel)
        # Stubs for the name lookups are created once, not per connection.
        audio_graph_controller = sushi_rpc_pb2_grpc.AudioGraphControllerStub(channel)
        parameter_controller = sushi_rpc_pb2_grpc.ParameterControllerStub(channel)
        cc_connections = midi_controller.GetAllCCInputConnections(sushi_rpc_pb2.GenericVoidValue())
        connections = []
        for conn in cc_connections.connections:
            processor_id = conn.parameter.processor_id
            parameter_id = conn.parameter.parameter_id
            processor_name = f"Processor ID {processor_id}"
            parameter_name = f"Parameter ID {parameter_id}"
            # Best effort: names are cosmetic. Keep whichever lookup succeeded
            # and fall back to the id placeholders otherwise.
            try:
                processor_name = audio_graph_controller.GetProcessorInfo(
                    sushi_rpc_pb2.ProcessorIdentifier(id=processor_id)
                ).name
                parameter_name = parameter_controller.GetParameterInfo(
                    sushi_rpc_pb2.ParameterIdentifier(
                        processor_id=processor_id,
                        parameter_id=parameter_id
                    )
                ).name
            except Exception as e:
                logger.debug(
                    f"Could not look up names for processor {processor_id} / "
                    f"parameter {parameter_id}: {e}"
                )
            connections.append({
                "processor_id": processor_id,
                "processor_name": processor_name,
                "parameter_id": parameter_id,
                "parameter_name": parameter_name,
                "channel": conn.channel.channel,
                "port": conn.port,
                "cc_number": conn.cc_number,
                "min_range": conn.min_range,
                "max_range": conn.max_range,
                "relative_mode": conn.relative_mode
            })
        return connections
@mcp.tool()
async def connect_midi_keyboard_to_track(
    track_id: int,
    midi_channel: int,
    port: int = 0,
    raw_midi: bool = False
) -> str:
    """Connect MIDI keyboard input to a track.

    Args:
        track_id: ID of the track
        midi_channel: MIDI channel (1-16, or 17 for omni)
        port: MIDI port number
        raw_midi: Whether to use raw MIDI (bypasses Sushi's MIDI mapping)
    """
    # 1-16 are regular channels, 17 is accepted as omni.
    if midi_channel < 1 or midi_channel > 17:
        return f"Invalid MIDI channel: {midi_channel}. Must be 1-16, or 17 for omni."
    # NOTE(review): midi_channel is passed to sushi_rpc_pb2.MidiChannel as-is;
    # confirm whether that enum numbers channel 1 as 1 (or as 0) and how omni
    # is encoded.
    with get_channel() as channel:
        connection = sushi_rpc_pb2.MidiKbdConnection(
            track=sushi_rpc_pb2.TrackIdentifier(id=track_id),
            channel=sushi_rpc_pb2.MidiChannel(channel=midi_channel),
            port=port,
            raw_midi=raw_midi,
        )
        sushi_rpc_pb2_grpc.MidiControllerStub(channel).ConnectKbdInputToTrack(connection)
    return f"Connected MIDI keyboard on port {port}, channel {midi_channel} to track ID {track_id}"
@mcp.tool()
async def connect_track_to_midi_keyboard(
    track_id: int,
    midi_channel: int,
    port: int = 0,
    raw_midi: bool = False
) -> str:
    """Connect a track to MIDI keyboard output.

    Args:
        track_id: ID of the track
        midi_channel: MIDI channel (1-16)
        port: MIDI port number
        raw_midi: Whether to use raw MIDI (bypasses Sushi's MIDI mapping)
    """
    # Output connections take a concrete channel only (no omni value).
    if midi_channel < 1 or midi_channel > 16:
        return f"Invalid MIDI channel: {midi_channel}. Must be 1-16."
    with get_channel() as channel:
        connection = sushi_rpc_pb2.MidiKbdConnection(
            track=sushi_rpc_pb2.TrackIdentifier(id=track_id),
            channel=sushi_rpc_pb2.MidiChannel(channel=midi_channel),
            port=port,
            raw_midi=raw_midi,
        )
        sushi_rpc_pb2_grpc.MidiControllerStub(channel).ConnectKbdOutputFromTrack(connection)
    return f"Connected track ID {track_id} to MIDI keyboard output on port {port}, channel {midi_channel}"
@mcp.tool()
async def connect_midi_cc_to_parameter(
    processor_id: int,
    parameter_id: int,
    cc_number: int,
    midi_channel: int = 1,
    port: int = 0,
    min_range: float = 0.0,
    max_range: float = 1.0,
    relative_mode: bool = False
) -> str:
    """Connect a MIDI CC to a parameter.

    Args:
        processor_id: ID of the processor
        parameter_id: ID of the parameter
        cc_number: MIDI CC number (0-127)
        midi_channel: MIDI channel (1-16, or 17 for omni)
        port: MIDI port number
        min_range: Minimum parameter value to map to CC 0
        max_range: Maximum parameter value to map to CC 127
        relative_mode: Whether to use relative mode (CC values adjust current value)
    """
    if midi_channel < 1 or midi_channel > 17:
        return f"Invalid MIDI channel: {midi_channel}. Must be 1-16, or 17 for omni."
    if cc_number < 0 or cc_number > 127:
        return f"Invalid CC number: {cc_number}. Must be 0-127."
    param_ref = sushi_rpc_pb2.ParameterIdentifier(
        processor_id=processor_id,
        parameter_id=parameter_id
    )
    with get_channel() as channel:
        midi_controller = sushi_rpc_pb2_grpc.MidiControllerStub(channel)
        midi_controller.ConnectCCToParameter(
            sushi_rpc_pb2.MidiCCConnection(
                parameter=param_ref,
                channel=sushi_rpc_pb2.MidiChannel(channel=midi_channel),
                port=port,
                cc_number=cc_number,
                min_range=min_range,
                max_range=max_range,
                relative_mode=relative_mode
            )
        )
        # Names for the confirmation message. Best effort: the connection is
        # already made, so a failed lookup only degrades the message to ids.
        processor_name = f"Processor ID {processor_id}"
        parameter_name = f"Parameter ID {parameter_id}"
        try:
            audio_graph_controller = sushi_rpc_pb2_grpc.AudioGraphControllerStub(channel)
            processor_name = audio_graph_controller.GetProcessorInfo(
                sushi_rpc_pb2.ProcessorIdentifier(id=processor_id)
            ).name
            parameter_controller = sushi_rpc_pb2_grpc.ParameterControllerStub(channel)
            parameter_name = parameter_controller.GetParameterInfo(param_ref).name
        except Exception as e:
            logger.debug(
                f"Could not look up names for processor {processor_id} / "
                f"parameter {parameter_id}: {e}"
            )
        return f"Connected MIDI CC {cc_number} on port {port}, channel {midi_channel} to parameter '{parameter_name}' of processor '{processor_name}'"
@mcp.tool()
async def disconnect_midi_cc(
    processor_id: int,
    parameter_id: int,
    cc_number: int,
    midi_channel: int = 1,
    port: int = 0
) -> str:
    """Disconnect a MIDI CC from a parameter.

    Args:
        processor_id: ID of the processor
        parameter_id: ID of the parameter
        cc_number: MIDI CC number (0-127)
        midi_channel: MIDI channel (1-16, or 17 for omni)
        port: MIDI port number
    """
    # Same argument checks as connect_midi_cc_to_parameter, so bad values are
    # reported clearly instead of being sent to Sushi.
    if midi_channel < 1 or midi_channel > 17:
        return f"Invalid MIDI channel: {midi_channel}. Must be 1-16, or 17 for omni."
    if cc_number < 0 or cc_number > 127:
        return f"Invalid CC number: {cc_number}. Must be 0-127."
    with get_channel() as channel:
        midi_controller = sushi_rpc_pb2_grpc.MidiControllerStub(channel)
        midi_controller.DisconnectCC(
            sushi_rpc_pb2.MidiCCConnection(
                parameter=sushi_rpc_pb2.ParameterIdentifier(
                    processor_id=processor_id,
                    parameter_id=parameter_id
                ),
                channel=sushi_rpc_pb2.MidiChannel(channel=midi_channel),
                port=port,
                cc_number=cc_number
            )
        )
    return f"Disconnected MIDI CC {cc_number} on port {port}, channel {midi_channel} from parameter ID {parameter_id} of processor ID {processor_id}"
@mcp.tool()
async def disconnect_all_midi_cc_from_processor(processor_id: int) -> str:
    """Disconnect all MIDI CC connections from a processor.

    Args:
        processor_id: ID of the processor
    """
    with get_channel() as channel:
        target = sushi_rpc_pb2.ProcessorIdentifier(id=processor_id)
        sushi_rpc_pb2_grpc.MidiControllerStub(channel).DisconnectAllCCFromProcessor(target)
    return f"Disconnected all MIDI CC connections from processor ID {processor_id}"
@mcp.tool()
async def send_note_on(
    track_id: int,
    note: int,
    velocity: float,
    channel: int = 1
) -> str:
    """Send a MIDI note-on message to a track.

    Args:
        track_id: ID of the track
        note: MIDI note number (0-127)
        velocity: Note velocity (0.0-1.0)
        channel: MIDI channel (1-16)
    """
    if channel < 1 or channel > 16:
        return f"Invalid MIDI channel: {channel}. Must be 1-16."
    if note < 0 or note > 127:
        return f"Invalid note number: {note}. Must be 0-127."
    if velocity < 0.0 or velocity > 1.0:
        return f"Invalid velocity: {velocity}. Must be 0.0-1.0."
    # `channel` is the MIDI channel argument, so the gRPC connection is called
    # `rpc` to keep the two apart.
    with get_channel() as rpc:
        request = sushi_rpc_pb2.NoteOnRequest(
            track=sushi_rpc_pb2.TrackIdentifier(id=track_id),
            channel=channel,
            note=note,
            velocity=velocity,
        )
        sushi_rpc_pb2_grpc.KeyboardControllerStub(rpc).SendNoteOn(request)
    return f"Sent note-on message for note {note} with velocity {velocity} to track ID {track_id}, channel {channel}"
@mcp.tool()
async def send_note_off(
    track_id: int,
    note: int,
    velocity: float,
    channel: int = 1
) -> str:
    """Send a MIDI note-off message to a track.

    Args:
        track_id: ID of the track
        note: MIDI note number (0-127)
        velocity: Release velocity (0.0-1.0)
        channel: MIDI channel (1-16)
    """
    if channel < 1 or channel > 16:
        return f"Invalid MIDI channel: {channel}. Must be 1-16."
    if note < 0 or note > 127:
        return f"Invalid note number: {note}. Must be 0-127."
    if velocity < 0.0 or velocity > 1.0:
        return f"Invalid velocity: {velocity}. Must be 0.0-1.0."
    # `channel` is the MIDI channel argument, so the gRPC connection is called
    # `rpc` to keep the two apart.
    with get_channel() as rpc:
        request = sushi_rpc_pb2.NoteOffRequest(
            track=sushi_rpc_pb2.TrackIdentifier(id=track_id),
            channel=channel,
            note=note,
            velocity=velocity,
        )
        sushi_rpc_pb2_grpc.KeyboardControllerStub(rpc).SendNoteOff(request)
    return f"Sent note-off message for note {note} with velocity {velocity} to track ID {track_id}, channel {channel}"
# -------------------- REAL-TIME UPDATES --------------------
# Background task for parameter update notifications
async def parameter_update_listener():
    """Consume Sushi's parameter-update stream while subscribed.

    Started as a background task by ``subscribe_to_parameter_updates``. For
    each update whose ``"processor_id:parameter_id"`` key is in
    ``parameter_subscriptions``, the payload is stored in
    ``notification_callbacks``, overwriting any earlier one for that key.

    The stream is read through gRPC's asyncio API so the event loop keeps
    serving other tools between messages. Iterating the blocking stub inside
    this coroutine froze the whole server for as long as the subscription
    lasted.

    The listener stops in any of these cases:
    - ``subscribed_to_updates`` is cleared (checked as each update arrives);
    - the task is cancelled;
    - the stream ends or fails.
    When it stops for any reason other than cancellation, the flag is left
    cleared so a later subscribe can start a fresh listener.
    """
    global subscribed_to_updates
    if not subscribed_to_updates:
        logger.info("Parameter update listener not started (not subscribed)")
        return
    logger.info("Starting parameter update listener")
    # Local import: only this listener needs the asyncio flavour of gRPC.
    from grpc import aio as grpc_aio

    # Same "host:port" target that get_channel() builds, but for an aio channel.
    address = f"{sushi_ip}:{sushi_port}"
    try:
        async with grpc_aio.insecure_channel(address) as channel:
            notification_controller = sushi_rpc_pb2_grpc.NotificationControllerStub(channel)
            # Nothing is blocklisted; filtering against the watch list is done
            # client-side below.
            blocklist = sushi_rpc_pb2.ParameterNotificationBlocklist(parameters=[])
            parameter_updates = notification_controller.SubscribeToParameterUpdates(blocklist)
            async for update in parameter_updates:
                if not subscribed_to_updates:
                    # Unsubscribed: end the server-side stream and stop.
                    parameter_updates.cancel()
                    break
                processor_id = update.parameter.processor_id
                parameter_id = update.parameter.parameter_id
                param_key = f"{processor_id}:{parameter_id}"
                if param_key in parameter_subscriptions:
                    # Keep only the latest update per watched parameter.
                    notification_callbacks[param_key] = {
                        "processor_id": processor_id,
                        "parameter_id": parameter_id,
                        "normalized_value": update.normalized_value,
                        "domain_value": update.domain_value,
                        "formatted_value": update.formatted_value,
                        "timestamp": time.time()
                    }
                    logger.info(f"Parameter update: {param_key} = {update.domain_value} ({update.formatted_value})")
    except Exception as e:
        logger.error(f"Error in parameter update listener: {e}")
    # Failed, ended, or unsubscribed: no listener is running any more.
    # (Cancellation raises CancelledError, which skips this line.)
    subscribed_to_updates = False
# Strong references to running listener tasks. asyncio only keeps weak
# references to tasks, so an unreferenced task can be garbage-collected
# before it finishes.
_listener_tasks: Set[asyncio.Task] = set()


@mcp.tool()
async def subscribe_to_parameter_updates(enable: bool = True) -> str:
    """Enable or disable real-time parameter update subscriptions.

    Args:
        enable: True to enable, False to disable
    """
    global subscribed_to_updates
    if enable and not subscribed_to_updates:
        subscribed_to_updates = True
        # Start the background listener and keep it referenced until it is done.
        task = asyncio.create_task(parameter_update_listener())
        _listener_tasks.add(task)
        task.add_done_callback(_listener_tasks.discard)
        return "Subscribed to parameter updates"
    elif not enable and subscribed_to_updates:
        subscribed_to_updates = False
        # Stop listeners now instead of waiting for them to notice the flag
        # when the next update arrives.
        for task in list(_listener_tasks):
            task.cancel()
        return "Unsubscribed from parameter updates"
    elif enable and subscribed_to_updates:
        return "Already subscribed to parameter updates"
    else:
        return "Already unsubscribed from parameter updates"
@mcp.tool()
async def watch_parameter(processor_id: int, parameter_id: int) -> str:
    """Add a parameter to the watch list for real-time updates.

    Args:
        processor_id: ID of the processor
        parameter_id: ID of the parameter
    """
    # Watching only selects which streamed updates are kept; the listener
    # itself is started by subscribe_to_parameter_updates.
    with get_channel() as channel:
        params = sushi_rpc_pb2_grpc.ParameterControllerStub(channel)
        graph = sushi_rpc_pb2_grpc.AudioGraphControllerStub(channel)
        try:
            # Both lookups must succeed before the key is added; they also
            # provide the names for the confirmation message.
            proc = graph.GetProcessorInfo(
                sushi_rpc_pb2.ProcessorIdentifier(id=processor_id)
            )
            param = params.GetParameterInfo(
                sushi_rpc_pb2.ParameterIdentifier(
                    processor_id=processor_id,
                    parameter_id=parameter_id,
                )
            )
        except Exception as e:
            return f"Failed to add parameter to watch list: {e}"
        parameter_subscriptions.add(f"{processor_id}:{parameter_id}")
        return f"Watching parameter '{param.name}' of processor '{proc.name}'"
@mcp.tool()
async def unwatch_parameter(processor_id: int, parameter_id: int) -> str:
    """Remove a parameter from the watch list for real-time updates.

    Args:
        processor_id: ID of the processor
        parameter_id: ID of the parameter
    """
    # Same "processor:parameter" key format that watch_parameter adds.
    key = f"{processor_id}:{parameter_id}"
    try:
        parameter_subscriptions.remove(key)
    except KeyError:
        return f"Parameter {parameter_id} of processor {processor_id} is not in the watch list"
    return f"Removed parameter {parameter_id} of processor {processor_id} from watch list"
@mcp.tool()
async def get_parameter_updates() -> Dict[str, List[Dict[str, Any]]]:
    """Get the latest update for each watched parameter since the last call.

    Only the most recent update per parameter is kept, so a parameter that
    changed several times is reported once, with its latest value.
    """
    global notification_callbacks
    # Swap in a fresh buffer for the listener and return what had accumulated
    # (in first-seen order per parameter).
    pending, notification_callbacks = notification_callbacks, {}
    return {"updates": list(pending.values())}
# -------------------- PRESET MANAGEMENT --------------------
@mcp.tool()
async def save_processor_preset(processor_id: int, name: str) -> str:
    """Save the current state of a processor as a preset.

    Args:
        processor_id: ID of the processor
        name: Name for the preset
    """
    # Characters replaced by "_" when building the preset filename. This is
    # applied to BOTH the user-supplied preset name and the processor name
    # reported by Sushi. Previously only spaces were replaced in the processor
    # name, so a slash or backslash in it could place the file outside
    # PRESETS_DIR.
    filename_unsafe = str.maketrans({" ": "_", "/": "_", "\\": "_"})
    with get_channel() as channel:
        audio_graph_controller = sushi_rpc_pb2_grpc.AudioGraphControllerStub(channel)
        parameter_controller = sushi_rpc_pb2_grpc.ParameterControllerStub(channel)
        processor_ref = sushi_rpc_pb2.ProcessorIdentifier(id=processor_id)
        try:
            processor_info = audio_graph_controller.GetProcessorInfo(processor_ref)
            processor_state = audio_graph_controller.GetProcessorState(processor_ref)
            preset_data = {
                "processor_info": {
                    "id": processor_id,
                    "name": processor_info.name,
                    "label": processor_info.label
                },
                "parameters": [],
                # Only recorded when the state reports a bypass value.
                "bypassed": processor_state.bypassed.value if processor_state.bypassed.has_value else None,
                "timestamp": datetime.now().isoformat()
            }
            # Snapshot every parameter's current value.
            parameters = parameter_controller.GetProcessorParameters(processor_ref)
            for param in parameters.parameters:
                value_response = parameter_controller.GetParameterValue(
                    sushi_rpc_pb2.ParameterIdentifier(
                        processor_id=processor_id,
                        parameter_id=param.id
                    )
                )
                preset_data["parameters"].append({
                    "id": param.id,
                    "name": param.name,
                    "value": value_response.value
                })
            sanitized_name = name.translate(filename_unsafe)
            processor_name = processor_info.name.translate(filename_unsafe)
            filename = f"{processor_name}_{sanitized_name}.json"
            filepath = os.path.join(PRESETS_DIR, filename)
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(preset_data, f, indent=2)
            return f"Saved preset '{name}' for processor '{processor_info.name}'"
        except Exception as e:
            # Best effort: report the failure to the client instead of raising.
            logger.error(f"Failed to save preset: {e}")
            return f"Failed to save preset: {e}"
@mcp.tool()
async def load_processor_preset(preset_file: str, processor_id: Optional[int] = None) -> str:
    """Load a preset onto a processor.

    Parameters are matched by *name* against the target processor, so a
    preset saved from one processor can be applied to another exposing the
    same parameter names. Preset parameters with no match on the target are
    logged and skipped.

    Args:
        preset_file: Filename of the preset to load; must name a file directly
            inside PRESETS_DIR.
        processor_id: ID of the target processor (if None, use the processor
            ID stored in the preset)

    Returns:
        "Loaded preset '...' onto processor '...'" on completion, or
        "Failed to load preset: ..." on an invalid filename or any error
        outside the per-parameter writes.
    """
    presets_root = os.path.realpath(PRESETS_DIR)
    filepath = os.path.realpath(os.path.join(presets_root, preset_file))
    # Reject absolute paths, ".." components and symlinks that resolve
    # outside PRESETS_DIR; presets are only ever written directly into it.
    if os.path.normcase(os.path.dirname(filepath)) != os.path.normcase(presets_root):
        return f"Failed to load preset: invalid preset filename '{preset_file}'"
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            preset_data = json.load(f)

        target_processor_id = (
            processor_id
            if processor_id is not None
            else preset_data["processor_info"]["id"]
        )

        with get_channel() as channel:
            audio_graph_controller = sushi_rpc_pb2_grpc.AudioGraphControllerStub(channel)
            parameter_controller = sushi_rpc_pb2_grpc.ParameterControllerStub(channel)

            processor_info = audio_graph_controller.GetProcessorInfo(
                sushi_rpc_pb2.ProcessorIdentifier(id=target_processor_id)
            )

            if preset_data.get("bypassed") is not None:
                audio_graph_controller.SetProcessorBypassState(
                    sushi_rpc_pb2.ProcessorBypassStateSetRequest(
                        processor=sushi_rpc_pb2.ProcessorIdentifier(id=target_processor_id),
                        value=preset_data["bypassed"],
                    )
                )

            # Fetch the target's parameter list once (rather than once per
            # preset entry) and index it by name. setdefault keeps the first
            # parameter when names repeat, as the original linear search did.
            target_parameters = parameter_controller.GetProcessorParameters(
                sushi_rpc_pb2.ProcessorIdentifier(id=target_processor_id)
            )
            param_ids_by_name: Dict[str, int] = {}
            for param_info in target_parameters.parameters:
                param_ids_by_name.setdefault(param_info.name, param_info.id)

            for param_data in preset_data["parameters"]:
                param_id = param_ids_by_name.get(param_data["name"])
                if param_id is None:
                    logger.warning(
                        f"Parameter {param_data['name']} not found on processor "
                        f"{target_processor_id}, skipping"
                    )
                    continue
                try:
                    parameter_controller.SetParameterValue(
                        sushi_rpc_pb2.ParameterValue(
                            parameter=sushi_rpc_pb2.ParameterIdentifier(
                                processor_id=target_processor_id,
                                parameter_id=param_id,
                            ),
                            value=param_data["value"],
                        )
                    )
                except Exception as e:
                    # Best effort: one rejected value should not abort the load.
                    logger.warning(f"Failed to set parameter {param_data['name']}: {e}")

        return f"Loaded preset '{preset_file}' onto processor '{processor_info.name}'"
    except Exception as e:
        logger.error(f"Failed to load preset: {e}")
        return f"Failed to load preset: {e}"
@mcp.tool()
async def list_presets() -> List[Dict[str, Any]]:
    """List all saved processor presets.

    Scans PRESETS_DIR (in sorted filename order) for ``.json`` files and
    returns a summary of each processor preset. Snapshot files, which
    ``snapshot_all_tracks`` writes into the same directory with a different
    schema (a top-level "tracks" list instead of "processor_info"), are
    skipped without a warning. Other unreadable files are logged and skipped.

    Returns:
        One dict per preset with keys "filename", "processor_name",
        "processor_id", "parameter_count" and "timestamp". The list is empty
        if PRESETS_DIR does not exist.
    """
    presets: List[Dict[str, Any]] = []
    try:
        filenames = sorted(os.listdir(PRESETS_DIR))
    except FileNotFoundError:
        # The directory is created at import time; if it has since been
        # removed, there are simply no presets to report.
        return presets

    for filename in filenames:
        if not filename.endswith(".json"):
            continue
        filepath = os.path.join(PRESETS_DIR, filename)
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                preset_data = json.load(f)
            if (
                isinstance(preset_data, dict)
                and "processor_info" not in preset_data
                and "tracks" in preset_data
            ):
                logger.debug(f"Skipping snapshot file {filename}")
                continue
            presets.append({
                "filename": filename,
                "processor_name": preset_data["processor_info"]["name"],
                "processor_id": preset_data["processor_info"]["id"],
                "parameter_count": len(preset_data["parameters"]),
                "timestamp": preset_data.get("timestamp", "Unknown"),
            })
        except Exception as e:
            logger.warning(f"Failed to read preset {filename}: {e}")
    return presets
@mcp.tool()
async def delete_preset(preset_file: str) -> str:
    """Delete a saved preset.

    Only regular files located directly inside PRESETS_DIR can be deleted;
    absolute paths, ".." components, and symlinks resolving elsewhere are
    rejected.

    Args:
        preset_file: Filename of the preset to delete

    Returns:
        A message saying whether the preset was deleted, was not found, was
        rejected as an invalid filename, or could not be removed.
    """
    presets_root = os.path.realpath(PRESETS_DIR)
    filepath = os.path.realpath(os.path.join(presets_root, preset_file))
    # Without this check, "../<anything>" would delete arbitrary files.
    if os.path.normcase(os.path.dirname(filepath)) != os.path.normcase(presets_root):
        return f"Invalid preset filename '{preset_file}'"

    # os.path.exists() is also true for directories, which os.remove()
    # cannot delete; only treat regular files as presets.
    if not os.path.isfile(filepath):
        return f"Preset '{preset_file}' not found"

    try:
        os.remove(filepath)
    except FileNotFoundError:
        # Removed concurrently between the check above and the remove call.
        return f"Preset '{preset_file}' not found"
    except OSError as e:
        logger.error(f"Failed to delete preset {preset_file}: {e}")
        return f"Failed to delete preset '{preset_file}': {e}"
    return f"Deleted preset '{preset_file}'"
def _capture_processor_snapshot(
    audio_graph_controller: Any,
    parameter_controller: Any,
    processor_id: int,
) -> Dict[str, Any]:
    """Read one processor's name, bypass state and parameter values from Sushi.

    Args:
        audio_graph_controller: AudioGraphControllerStub bound to an open channel.
        parameter_controller: ParameterControllerStub bound to the same channel.
        processor_id: ID of the processor to capture.

    Returns:
        A dict with keys "id", "name", "bypassed" and "parameters" (a list of
        {"id", "name", "value"} dicts), in that order.
    """
    processor_ref = sushi_rpc_pb2.ProcessorIdentifier(id=processor_id)
    processor_info = audio_graph_controller.GetProcessorInfo(processor_ref)
    processor_state = audio_graph_controller.GetProcessorState(processor_ref)
    parameters_response = parameter_controller.GetProcessorParameters(processor_ref)

    parameters = []
    for param in parameters_response.parameters:
        value_response = parameter_controller.GetParameterValue(
            sushi_rpc_pb2.ParameterIdentifier(
                processor_id=processor_id,
                parameter_id=param.id,
            )
        )
        parameters.append({
            "id": param.id,
            "name": param.name,
            "value": value_response.value,
        })

    return {
        "id": processor_id,
        "name": processor_info.name,
        # A missing bypass value is recorded as False (not bypassed).
        "bypassed": (
            processor_state.bypassed.value
            if processor_state.bypassed.has_value
            else False
        ),
        "parameters": parameters,
    }


@mcp.tool()
async def snapshot_all_tracks(name: str) -> str:
    """Save a snapshot of all tracks and their processors.

    Captures every track's processors (bypass state and all parameter values)
    and writes them to ``PRESETS_DIR/snapshot_<name>_<YYYYmmdd_HHMMSS>.json``.

    Args:
        name: Name for the snapshot

    Returns:
        "Saved snapshot '...' to <filename>" on success, or
        "Failed to save snapshot: ..." if any gRPC call or the file write raises.
    """
    try:
        # NOTE(review): every stub call below is synchronous, so the event
        # loop is blocked for the whole capture; consider asyncio.to_thread.
        with get_channel() as channel:
            audio_graph_controller = sushi_rpc_pb2_grpc.AudioGraphControllerStub(channel)
            parameter_controller = sushi_rpc_pb2_grpc.ParameterControllerStub(channel)
            tracks_response = audio_graph_controller.GetAllTracks(sushi_rpc_pb2.GenericVoidValue())

            # Use one capture time for both the stored "timestamp" and the
            # filename so the two always describe the same moment.
            captured_at = datetime.now()
            snapshot_data = {
                "name": name,
                "timestamp": captured_at.isoformat(),
                "tracks": [],
            }

            for track in tracks_response.tracks:
                processors_response = audio_graph_controller.GetTrackProcessors(
                    sushi_rpc_pb2.TrackIdentifier(id=track.id)
                )
                snapshot_data["tracks"].append({
                    "id": track.id,
                    "name": track.name,
                    "processors": [
                        _capture_processor_snapshot(
                            audio_graph_controller, parameter_controller, processor.id
                        )
                        for processor in processors_response.processors
                    ],
                })

        # Spaces and path separators are replaced so the name cannot add
        # directory components to the output path.
        sanitized_name = name.replace(" ", "_").replace("/", "_").replace("\\", "_")
        filename = f"snapshot_{sanitized_name}_{captured_at.strftime('%Y%m%d_%H%M%S')}.json"
        filepath = os.path.join(PRESETS_DIR, filename)
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(snapshot_data, f, indent=2)
        return f"Saved snapshot '{name}' to {filename}"
    except Exception as e:
        logger.error(f"Failed to save snapshot: {e}")
        return f"Failed to save snapshot: {e}"
@mcp.tool()
async def load_snapshot(snapshot_file: str) -> str:
    """Load a snapshot of all tracks and their processors.

    Restores bypass states and parameter values by track/processor/parameter
    ID. Anything that cannot be applied is logged and skipped, and the rest of
    the snapshot is still restored. This covers missing tracks or processors,
    rejected bypass states, and rejected parameter values.

    Args:
        snapshot_file: Filename of the snapshot to load; must name a file
            directly inside PRESETS_DIR.

    Returns:
        "Loaded snapshot '<name>'" on completion (falling back to the filename
        when the snapshot has no "name"), or "Failed to load snapshot: ..." on
        an invalid filename, unreadable file or malformed snapshot structure.
    """
    presets_root = os.path.realpath(PRESETS_DIR)
    filepath = os.path.realpath(os.path.join(presets_root, snapshot_file))
    # Reject absolute paths, ".." components and symlinks that resolve
    # outside PRESETS_DIR; snapshots are only ever written directly into it.
    if os.path.normcase(os.path.dirname(filepath)) != os.path.normcase(presets_root):
        return f"Failed to load snapshot: invalid snapshot filename '{snapshot_file}'"
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            snapshot_data = json.load(f)

        with get_channel() as channel:
            audio_graph_controller = sushi_rpc_pb2_grpc.AudioGraphControllerStub(channel)
            parameter_controller = sushi_rpc_pb2_grpc.ParameterControllerStub(channel)

            for track_data in snapshot_data["tracks"]:
                track_id = track_data["id"]
                try:
                    audio_graph_controller.GetTrackInfo(
                        sushi_rpc_pb2.TrackIdentifier(id=track_id)
                    )
                except Exception:
                    logger.warning(f"Track ID {track_id} not found, skipping")
                    continue

                for processor_data in track_data["processors"]:
                    processor_id = processor_data["id"]
                    try:
                        audio_graph_controller.GetProcessorInfo(
                            sushi_rpc_pb2.ProcessorIdentifier(id=processor_id)
                        )
                    except Exception:
                        logger.warning(f"Processor ID {processor_id} not found, skipping")
                        continue

                    # Tolerate a missing or rejected bypass state the same way
                    # parameter failures are tolerated, instead of aborting the
                    # restore halfway through.
                    bypassed = processor_data.get("bypassed")
                    if bypassed is not None:
                        try:
                            audio_graph_controller.SetProcessorBypassState(
                                sushi_rpc_pb2.ProcessorBypassStateSetRequest(
                                    processor=sushi_rpc_pb2.ProcessorIdentifier(id=processor_id),
                                    value=bypassed,
                                )
                            )
                        except Exception as e:
                            logger.warning(
                                f"Failed to set bypass state for processor {processor_id}: {e}"
                            )

                    for param_data in processor_data["parameters"]:
                        try:
                            parameter_controller.SetParameterValue(
                                sushi_rpc_pb2.ParameterValue(
                                    parameter=sushi_rpc_pb2.ParameterIdentifier(
                                        processor_id=processor_id,
                                        parameter_id=param_data["id"],
                                    ),
                                    value=param_data["value"],
                                )
                            )
                        except Exception as e:
                            logger.warning(f"Failed to set parameter {param_data['id']}: {e}")

        # .get(): a snapshot without "name" has already been applied by now,
        # so it should not be reported as a failure.
        return f"Loaded snapshot '{snapshot_data.get('name', snapshot_file)}'"
    except Exception as e:
        logger.error(f"Failed to load snapshot: {e}")
        return f"Failed to load snapshot: {e}"
# Script entry point: start the MCP server when this file is executed directly.
if __name__ == "__main__":
    # Exchange MCP messages with the client over stdin/stdout.
    transport_mode = 'stdio'
    mcp.run(transport=transport_mode)