from __future__ import annotations
import asyncio
import hashlib
import json
import os
import signal
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
class LSPError(RuntimeError):
    """Signals that an LSP operation could not be completed.

    Used for failures reported by the language server as well as for local
    problems such as the server process not running.
    """
@dataclass
class DocumentState:
    """Bookkeeping record for one text document synced to the language server."""

    # Document URI; also the key under which the record is stored by the client.
    uri: str
    # LSP language identifier sent with the document.
    language_id: str
    # Version number last sent to the server for this document.
    version: int
    # Hex digest of the text last sent, used to detect unchanged content.
    text_hash: str
class LSPClient:
    """Async JSON-RPC client for Language Server Protocol (LSP) servers.

    This client assumes stdio transport to the language server: the server is
    spawned as a child process and messages are exchanged over its
    stdin/stdout as ``Content-Length``-framed JSON bodies.

    The server is started lazily; :meth:`request` and :meth:`notify` both call
    :meth:`start` before sending anything.

    NOTE(review): :meth:`start` is not serialized. It returns as soon as a
    process is attached, so a coroutine that issues a request while another
    one is still inside the initialize handshake will send its message before
    the handshake has finished. Await ``start()`` once before fanning out
    concurrent requests.
    """

    def __init__(
        self,
        *,
        name: str,
        command: str,
        args: list[str],
        root_path: Path,
        initialization_options: dict[str, Any] | None = None,
        settings: dict[str, Any] | None = None,
        env: dict[str, str] | None = None,
        request_timeout_sec: float = 20.0,
    ) -> None:
        """Store the configuration; no process is spawned until :meth:`start`.

        Args:
            name: Label used in task names, error messages and stderr logs.
            command: Executable of the language server.
            args: Command-line arguments passed to ``command``.
            root_path: Workspace root; used as the server's working directory
                and reported as ``rootUri`` / workspace folder.
            initialization_options: Sent as ``initializationOptions``.
            settings: Sent via ``workspace/didChangeConfiguration`` and used
                to answer ``workspace/configuration`` requests.
            env: Extra environment variables layered over ``os.environ``.
            request_timeout_sec: How long :meth:`request` waits for a reply.
        """
        self.name = name
        self.command = command
        self.args = args
        self.root_path = root_path
        self.initialization_options = initialization_options or {}
        self.settings = settings or {}
        self.env = env or {}
        self.request_timeout_sec = request_timeout_sec
        self._proc: asyncio.subprocess.Process | None = None
        self._reader_task: asyncio.Task | None = None
        self._stderr_task: asyncio.Task | None = None
        self._next_id = 0
        self._pending: dict[int, asyncio.Future] = {}
        self._write_lock = asyncio.Lock()
        self._initialized = False
        self._documents: dict[str, DocumentState] = {}  # uri -> state
        self._diagnostics: dict[str, Any] = {}  # uri -> publishDiagnostics payload
        self._diagnostics_seq: dict[str, int] = {}
        self._diagnostics_events: dict[str, asyncio.Event] = {}

    # ---------------- lifecycle ----------------

    async def start(self) -> None:
        """Spawn the language server and run the initialize handshake.

        Does nothing when a server process is already attached.

        Raises:
            OSError: The server executable could not be launched.
            LSPError: The server rejected or never received ``initialize``.
            asyncio.TimeoutError: The server did not answer ``initialize``.
        """
        if self._proc is not None:
            return
        cmd = [self.command, *self.args]
        env = os.environ.copy()
        env.update(self.env)
        self._proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=str(self.root_path),
            env=env,
        )
        assert self._proc.stdout and self._proc.stdin and self._proc.stderr
        self._reader_task = asyncio.create_task(self._read_loop(), name=f"{self.name}-read")
        self._stderr_task = asyncio.create_task(self._stderr_loop(), name=f"{self.name}-stderr")
        try:
            await self._initialize()
        except Exception:
            # A server that never completed the handshake is unusable. Tear it
            # down so the child is not leaked and a later start() can retry.
            await self.stop()
            raise

    async def stop(self) -> None:
        """Shut the server down and reset all per-session state.

        Sends ``shutdown`` (only if the handshake completed) and ``exit`` on a
        best-effort basis, then terminates the process, escalating to a kill
        if it has not exited within two seconds. Requests still waiting for a
        reply fail with :class:`LSPError`.
        """
        proc = self._proc
        if proc is None:
            return
        try:
            # Politely request shutdown if possible. Failures are deliberately
            # ignored: the process is terminated below regardless.
            if self._initialized:
                try:
                    await self.request("shutdown", {})
                except Exception:
                    pass
            try:
                await self.notify("exit", {})
            except Exception:
                pass
            # Terminate process.
            if proc.returncode is None:
                try:
                    # SIGTERM on POSIX, TerminateProcess on Windows.
                    proc.terminate()
                except ProcessLookupError:
                    pass  # Exited between the returncode check and the signal.
                try:
                    await asyncio.wait_for(proc.wait(), timeout=2.0)
                except asyncio.TimeoutError:
                    try:
                        proc.kill()
                    except ProcessLookupError:
                        pass
                    # Reap the child so it does not linger as a zombie.
                    await proc.wait()
        finally:
            for task in (self._reader_task, self._stderr_task):
                if task:
                    task.cancel()
            self._reader_task = None
            self._stderr_task = None
            self._proc = None
            self._initialized = False
            self._fail_pending(f"{self.name}: client stopped")
            self._documents.clear()
            self._diagnostics.clear()

    def _fail_pending(self, message: str) -> None:
        """Fail every unanswered request with ``LSPError(message)`` and forget it."""
        for fut in self._pending.values():
            if not fut.done():
                fut.set_exception(LSPError(message))
        self._pending.clear()

    # ---------------- protocol helpers ----------------

    async def request(self, method: str, params: Any | None = None) -> Any:
        """Send an LSP request and wait for a response.

        Args:
            method: JSON-RPC method name.
            params: Request parameters; any falsy value is sent as ``{}``.

        Returns:
            The ``result`` member of the server's response.

        Raises:
            LSPError: The server answered with an error, or the connection
                went away before a reply arrived.
            asyncio.TimeoutError: No reply within ``request_timeout_sec``.
        """
        await self.start()
        self._next_id += 1
        msg_id = self._next_id
        loop = asyncio.get_running_loop()
        fut: asyncio.Future = loop.create_future()
        self._pending[msg_id] = fut
        payload = {"jsonrpc": "2.0", "id": msg_id, "method": method, "params": params or {}}
        try:
            # Sending happens inside the try so that a failed write does not
            # leave a stale entry in the pending table.
            await self._send(payload)
            return await asyncio.wait_for(fut, timeout=self.request_timeout_sec)
        finally:
            self._pending.pop(msg_id, None)

    async def notify(self, method: str, params: Any | None = None) -> None:
        """Send an LSP notification (no response).

        Args:
            method: JSON-RPC method name.
            params: Notification parameters; any falsy value is sent as ``{}``.
        """
        await self.start()
        payload = {"jsonrpc": "2.0", "method": method, "params": params or {}}
        await self._send(payload)

    async def _send(self, payload: dict[str, Any]) -> None:
        """Frame ``payload`` with a Content-Length header and write it to the server.

        Raises:
            LSPError: No server process is attached, or it has already exited.
        """
        proc = self._proc
        if proc is None or proc.stdin is None or proc.returncode is not None:
            raise LSPError(f"{self.name}: process is not running")
        body = json.dumps(payload, separators=(",", ":"), ensure_ascii=False).encode("utf-8")
        header = f"Content-Length: {len(body)}\r\n\r\n".encode("ascii")
        # One writer at a time so frames from concurrent senders never interleave.
        async with self._write_lock:
            proc.stdin.write(header + body)
            await proc.stdin.drain()

    # ---------------- initialization ----------------

    def _build_initialize_params(self) -> dict[str, Any]:
        """Build the parameter object for the ``initialize`` request."""
        # as_uri() rejects relative paths, so resolve to an absolute path first.
        root = Path(self.root_path).resolve()
        root_uri = root.as_uri()
        return {
            "processId": os.getpid(),
            "rootUri": root_uri,
            "capabilities": {
                "workspace": {
                    "workspaceFolders": True,
                    "configuration": True,
                },
                "textDocument": {
                    "synchronization": {
                        "dynamicRegistration": False,
                        "willSave": False,
                        "didSave": False,
                        "willSaveWaitUntil": False,
                    },
                    "definition": {"dynamicRegistration": False},
                    "references": {"dynamicRegistration": False},
                    "hover": {"dynamicRegistration": False},
                    "documentSymbol": {"dynamicRegistration": False},
                    "rename": {"dynamicRegistration": False},
                    "codeAction": {"dynamicRegistration": False},
                },
            },
            "workspaceFolders": [{"uri": root_uri, "name": root.name}],
            "clientInfo": {"name": "codex-lsp-bridge", "version": "0.1.0"},
            "initializationOptions": self.initialization_options or {},
        }

    async def _initialize(self) -> None:
        """Run the ``initialize`` / ``initialized`` handshake once per session.

        When settings were supplied they are pushed to the server with
        ``workspace/didChangeConfiguration`` right after the handshake.
        """
        if self._initialized:
            return
        await self.request("initialize", self._build_initialize_params())
        await self.notify("initialized", {})
        if self.settings:
            await self.notify("workspace/didChangeConfiguration", {"settings": self.settings})
        self._initialized = True

    # ---------------- file sync ----------------

    async def sync_document(self, *, uri: str, language_id: str, text: str) -> None:
        """Ensure the document is opened and up-to-date in the language server.

        Sends ``textDocument/didOpen`` the first time a URI is seen and a
        full-text ``textDocument/didChange`` when the content differs from
        what was last sent. Unchanged content sends nothing.
        """
        # Hash to avoid redundant didChange payloads.
        text_hash = hashlib.sha256(text.encode("utf-8", errors="replace")).hexdigest()
        previous = self._documents.get(uri)
        if previous is None:
            version = 1
            method = "textDocument/didOpen"
            params: dict[str, Any] = {
                "textDocument": {
                    "uri": uri,
                    "languageId": language_id,
                    "version": version,
                    "text": text,
                }
            }
        elif previous.text_hash == text_hash:
            return
        else:
            version = previous.version + 1
            method = "textDocument/didChange"
            params = {
                "textDocument": {"uri": uri, "version": version},
                "contentChanges": [{"text": text}],
            }

        # Record the state before awaiting so an overlapping call for the same
        # URI sees it; roll back below if the message could not be sent.
        current = DocumentState(
            uri=uri,
            language_id=language_id,
            version=version,
            text_hash=text_hash,
        )
        self._documents[uri] = current
        try:
            await self.notify(method, params)
        except Exception:
            # Undo only our own update; a newer state may have replaced it.
            if self._documents.get(uri) is current:
                if previous is None:
                    del self._documents[uri]
                else:
                    self._documents[uri] = previous
            raise

    # ---------------- high-level LSP calls ----------------

    @staticmethod
    def _position_params(uri: str, line: int, character: int) -> dict[str, Any]:
        """Build the ``textDocument`` + ``position`` object shared by many requests."""
        return {
            "textDocument": {"uri": uri},
            "position": {"line": line, "character": character},
        }

    async def definition(self, *, uri: str, line: int, character: int) -> Any:
        """Send ``textDocument/definition`` for the given position."""
        return await self.request(
            "textDocument/definition", self._position_params(uri, line, character)
        )

    async def type_definition(self, *, uri: str, line: int, character: int) -> Any:
        """Send ``textDocument/typeDefinition`` for the given position."""
        return await self.request(
            "textDocument/typeDefinition", self._position_params(uri, line, character)
        )

    async def references(
        self, *, uri: str, line: int, character: int, include_declaration: bool = True
    ) -> Any:
        """Send ``textDocument/references`` for the given position."""
        params = self._position_params(uri, line, character)
        params["context"] = {"includeDeclaration": include_declaration}
        return await self.request("textDocument/references", params)

    async def hover(self, *, uri: str, line: int, character: int) -> Any:
        """Send ``textDocument/hover`` for the given position."""
        return await self.request(
            "textDocument/hover", self._position_params(uri, line, character)
        )

    async def signature_help(self, *, uri: str, line: int, character: int) -> Any:
        """Send ``textDocument/signatureHelp`` for the given position."""
        return await self.request(
            "textDocument/signatureHelp", self._position_params(uri, line, character)
        )

    async def document_symbols(self, *, uri: str) -> Any:
        """Send ``textDocument/documentSymbol`` for the given document."""
        return await self.request("textDocument/documentSymbol", {"textDocument": {"uri": uri}})

    async def workspace_symbols(self, *, query: str) -> Any:
        """Send ``workspace/symbol`` with the given query string."""
        return await self.request("workspace/symbol", {"query": query})

    async def format_document(self, *, uri: str, options: dict[str, Any]) -> Any:
        """Send ``textDocument/formatting`` with the given formatting options."""
        return await self.request(
            "textDocument/formatting",
            {
                "textDocument": {"uri": uri},
                "options": options,
            },
        )

    async def code_action(
        self,
        *,
        uri: str,
        start_line: int,
        start_character: int,
        end_line: int,
        end_character: int,
        only: list[str] | None = None,
    ) -> Any:
        """Send ``textDocument/codeAction`` for a range.

        The request context always carries an empty diagnostics list; ``only``
        is included in the context when it is non-empty.
        """
        context: dict[str, Any] = {"diagnostics": []}
        if only:
            context["only"] = only
        return await self.request(
            "textDocument/codeAction",
            {
                "textDocument": {"uri": uri},
                "range": {
                    "start": {"line": start_line, "character": start_character},
                    "end": {"line": end_line, "character": end_character},
                },
                "context": context,
            },
        )

    async def rename(self, *, uri: str, line: int, character: int, new_name: str) -> Any:
        """Send ``textDocument/rename`` for the symbol at the given position."""
        params = self._position_params(uri, line, character)
        params["newName"] = new_name
        return await self.request("textDocument/rename", params)

    async def prepare_call_hierarchy(self, *, uri: str, line: int, character: int) -> Any:
        """Send ``textDocument/prepareCallHierarchy`` for the given position."""
        return await self.request(
            "textDocument/prepareCallHierarchy", self._position_params(uri, line, character)
        )

    async def call_hierarchy_incoming(self, *, item: dict[str, Any]) -> Any:
        """Send ``callHierarchy/incomingCalls`` for a call hierarchy item."""
        return await self.request("callHierarchy/incomingCalls", {"item": item})

    async def call_hierarchy_outgoing(self, *, item: dict[str, Any]) -> Any:
        """Send ``callHierarchy/outgoingCalls`` for a call hierarchy item."""
        return await self.request("callHierarchy/outgoingCalls", {"item": item})

    def latest_diagnostics(self, *, uri: str) -> Any:
        """Return the last ``publishDiagnostics`` payload for ``uri``, or ``None``."""
        return self._diagnostics.get(uri)

    async def await_diagnostics(self, *, uri: str, timeout_sec: float = 5.0) -> tuple[Any, bool]:
        """Wait for the next ``publishDiagnostics`` notification for ``uri``.

        Returns:
            ``(payload, updated)`` where ``payload`` is the most recent
            diagnostics payload (possibly ``None``) and ``updated`` tells
            whether a new notification arrived before the timeout.
        """
        start_seq = self._diagnostics_seq.get(uri, 0)
        event = self._diagnostics_events.setdefault(uri, asyncio.Event())
        # Drop any signal left over from a notification nobody waited for.
        event.clear()
        try:
            await asyncio.wait_for(event.wait(), timeout=timeout_sec)
        except asyncio.TimeoutError:
            return self._diagnostics.get(uri), False
        updated = self._diagnostics_seq.get(uri, 0) > start_seq
        return self._diagnostics.get(uri), updated

    # ---------------- background loops ----------------

    async def _stderr_loop(self) -> None:
        """Drain the server's stderr so the pipe cannot fill up and stall it."""
        proc = self._proc
        if proc is None or proc.stderr is None:
            return
        while True:
            try:
                line = await proc.stderr.readline()
            except ValueError:
                # The line exceeded the stream buffer limit; readline() has
                # already discarded it, so keep draining instead of dying.
                continue
            if not line:
                return
            # Keep stderr noise minimal; you can wire this into structured logging.
            # We intentionally do not print by default to avoid polluting Codex output.
            # If you want logs, set CODEX_LSP_BRIDGE_LOG_STDERR=1
            if os.getenv("CODEX_LSP_BRIDGE_LOG_STDERR") == "1":
                text = line.decode("utf-8", errors="replace").rstrip("\r\n")
                print(f"[{self.name} stderr] {text}")

    @staticmethod
    def _parse_content_length(header_bytes: bytes) -> int | None:
        """Extract the Content-Length value from a raw header block.

        Returns ``None`` when the header is missing, not an integer, or negative.
        """
        header_text = header_bytes.decode("ascii", errors="replace")
        content_length = None
        for line in header_text.split("\r\n"):
            if line.lower().startswith("content-length:"):
                try:
                    content_length = int(line.split(":", 1)[1].strip())
                except ValueError:
                    content_length = None
        if content_length is not None and content_length < 0:
            return None
        return content_length

    async def _read_loop(self) -> None:
        """Read framed messages from the server's stdout and dispatch them.

        The loop ends when the stream closes. Requests still waiting at that
        point fail immediately instead of running into their timeout.
        """
        proc = self._proc
        if proc is None or proc.stdout is None:
            return
        reader = proc.stdout
        try:
            while True:
                # 1) Read headers
                try:
                    header_bytes = await reader.readuntil(b"\r\n\r\n")
                except (asyncio.IncompleteReadError, asyncio.LimitOverrunError):
                    return
                content_length = self._parse_content_length(header_bytes)
                if content_length is None:
                    # Malformed message; attempt to continue.
                    continue
                # 2) Read JSON body
                try:
                    body = await reader.readexactly(content_length)
                except asyncio.IncompleteReadError:
                    return
                try:
                    msg = json.loads(body.decode("utf-8"))
                except (UnicodeDecodeError, json.JSONDecodeError):
                    continue
                await self._handle_message(msg)
        finally:
            # Only touch the pending table while this loop's process is still
            # the attached one; after stop() it may belong to a new session.
            if self._proc is proc:
                self._fail_pending(f"{self.name}: connection to language server closed")

    async def _handle_message(self, msg: Any) -> None:
        """Dispatch one decoded message: response, server request or notification."""
        if not isinstance(msg, dict):
            # e.g. a JSON ``null`` or a batch array; nothing we can route.
            return
        # Response
        if "id" in msg and ("result" in msg or "error" in msg):
            msg_id = msg.get("id")
            fut = self._pending.get(msg_id) if isinstance(msg_id, int) else None
            # A late reply can arrive after the waiter timed out and the
            # future was cancelled; resolving it again would raise.
            if fut is None or fut.done():
                return
            if "error" in msg and msg["error"] is not None:
                fut.set_exception(LSPError(f"{self.name} error: {msg['error']}"))
            else:
                fut.set_result(msg.get("result"))
            return
        # Server request
        if "id" in msg and "method" in msg:
            req_id = msg.get("id")
            method = msg.get("method")
            params = msg.get("params")
            try:
                result = await self._handle_server_request(method, params)
                await self._send({"jsonrpc": "2.0", "id": req_id, "result": result})
            except Exception as e:
                try:
                    await self._send(
                        {
                            "jsonrpc": "2.0",
                            "id": req_id,
                            "error": {
                                "code": -32603,
                                "message": f"Internal error handling {method}: {e}",
                            },
                        }
                    )
                except (LSPError, OSError):
                    # The transport is gone; there is nobody left to answer,
                    # and raising here would only kill the reader task.
                    pass
            return
        # Notification
        if "method" in msg and "id" not in msg:
            method = msg.get("method")
            params = msg.get("params")
            self._handle_notification(method, params)

    async def _handle_server_request(self, method: str, params: Any) -> Any:
        """Compute the result for a request sent by the server.

        ``workspace/configuration`` is answered with one entry per requested
        item: the matching top-level section of ``settings`` when the item
        names one that exists, otherwise the whole settings object. Every
        other request is answered with ``None``.
        """
        # Many servers request configuration after initialization.
        if method == "workspace/configuration":
            items = []
            if isinstance(params, dict):
                items = params.get("items") or []
            settings = self.settings
            answers = []
            for item in items:
                section = item.get("section") if isinstance(item, dict) else None
                if section and isinstance(settings, dict) and section in settings:
                    answers.append(settings.get(section))
                else:
                    answers.append(settings)
            return answers
        # Safe default responses for common bookkeeping requests
        # (window/workDoneProgress/create, client/registerCapability,
        # client/unregisterCapability) and for unknown requests alike: null.
        return None

    def _handle_notification(self, method: str, params: Any) -> None:
        """Record ``publishDiagnostics`` payloads and wake waiters; ignore the rest."""
        if method != "textDocument/publishDiagnostics" or not isinstance(params, dict):
            return
        uri = params.get("uri")
        if not isinstance(uri, str):
            return
        self._diagnostics[uri] = params
        self._diagnostics_seq[uri] = self._diagnostics_seq.get(uri, 0) + 1
        event = self._diagnostics_events.get(uri)
        if event:
            event.set()