import { Client as SSHClient } from 'ssh2'
import net from 'net'
import fs from 'fs/promises'
import path from 'path'
import { fileURLToPath } from 'url'
import type { SSHConfig, TunnelConfig } from './types.js'
// ES modules have no CommonJS __filename/__dirname globals, so derive them from import.meta.url.
// NOTE(review): __dirname is not referenced anywhere in the visible code — confirm it is still needed.
const __filename = fileURLToPath(import.meta.url)
const __dirname = path.dirname(__filename)
// Lock file paths.
// Shared directory holding the cross-process coordination files below.
// NOTE(review): hard-coded POSIX /tmp path; os.tmpdir() would be more portable but would
// move the lock location seen by already-running instances.
const LOCK_DIR = '/tmp/pg_mcp_tunnel'
// JSON-serialized TunnelState of the process that owns the SSH tunnel.
const TUNNEL_LOCK_FILE = path.join(LOCK_DIR, 'tunnel.lock')
// JSON map of TunnelManager instance id -> { pid, startTime } for every registered instance.
const REFS_FILE = path.join(LOCK_DIR, 'refs.json')
/** Contents of the tunnel lock file: who owns the tunnel and where it listens. */
type TunnelState = {
  /** PID of the process that took the lock and runs the tunnel. */
  pid: number
  /** Local port (on 127.0.0.1) that the tunnel listens on. */
  port: number
  /** Number of entries in the refs file when last updated (0 when the lock is first taken). */
  refCount: number
  /** Date.now() timestamp taken when the lock was written. */
  startTime: number
}
/**
 * Shares one SSH tunnel among every TunnelManager on this host, whether the
 * managers live in the same process or in different ones.
 *
 * Coordination is file based:
 *  - TUNNEL_LOCK_FILE holds the TunnelState of the process that owns the tunnel.
 *  - REFS_FILE maps each connected instance id to the pid that registered it.
 *
 * The instance that acquires the lock opens the SSH connection and listens on
 * 127.0.0.1:`tunnelConfig.srcPort`, forwarding every local connection to
 * `tunnelConfig.dstHost`:`tunnelConfig.dstPort`. Other instances only wait for
 * that port to accept connections. The owning instance closes the tunnel when it
 * disconnects and no live references remain.
 *
 * Lock and refs files are replaced atomically (temp file + link/rename), so
 * readers never observe a half-written file. The read-modify-write of the refs
 * file is still not serialized across processes.
 */
export class TunnelManager {
  /** Per-process counter used to build unique instance ids and temp file names. */
  private static sequence = 0

  private sshConfig: SSHConfig
  private tunnelConfig: TunnelConfig
  private sshConnection?: SSHClient
  private tunnelServer?: net.Server
  private instanceId: string

  constructor(sshConfig: SSHConfig, tunnelConfig: TunnelConfig) {
    this.sshConfig = sshConfig
    this.tunnelConfig = tunnelConfig
    // pid + time alone collide when one process creates two managers in the same millisecond.
    this.instanceId = `${process.pid}-${Date.now()}-${TunnelManager.sequence++}`
  }

  /**
   * Makes the shared tunnel available and registers this instance as one of its users.
   *
   * If no live process owns the tunnel, this instance opens it. Otherwise it waits
   * (up to ~30 s) for the owner's local port to accept connections.
   *
   * @throws if the lock directory or files cannot be written, if the SSH connection
   *   or local listener fails to start, or if an existing tunnel never becomes reachable.
   */
  async connect(): Promise<void> {
    await this.ensureLockDir()

    if (await this.acquireLock()) {
      // We're the first instance, or we're taking over from a dead owner.
      try {
        await this.startTunnel()
      } catch (error) {
        // Otherwise the lock would keep naming this (live) process and every other
        // instance would wait for a port that is never going to open.
        await this.removeFileQuietly(TUNNEL_LOCK_FILE)
        throw error
      }
    } else {
      // Another live process manages the tunnel.
      await this.waitForTunnel()
    }

    await this.addReference()
  }

  /**
   * Unregisters this instance. If no live references remain and this instance owns
   * the tunnel, it closes the local listener and SSH session and deletes the lock files.
   */
  async disconnect(): Promise<void> {
    const shouldShutdownTunnel = await this.removeReference()
    if (shouldShutdownTunnel && this.tunnelServer) {
      console.error('Last instance disconnecting, shutting down SSH tunnel')
      this.tunnelServer.close()
      this.sshConnection?.end()
      // Forget the handles so a repeated disconnect() does not try to close them again.
      this.tunnelServer = undefined
      this.sshConnection = undefined
      await this.releaseLock()
    }
  }

  /**
   * Creates LOCK_DIR if needed. With `recursive: true`, mkdir already succeeds when
   * the directory exists, so any error here is real (e.g. EACCES). It is rethrown
   * rather than resurfacing later as a misleading "No tunnel lock file found".
   */
  private async ensureLockDir(): Promise<void> {
    await fs.mkdir(LOCK_DIR, { recursive: true })
  }

  /**
   * Tries to become the tunnel owner.
   *
   * @returns true if this process now holds TUNNEL_LOCK_FILE, false if a live
   *   process (possibly this one, via another instance) already holds it.
   */
  private async acquireLock(): Promise<boolean> {
    const state: TunnelState = {
      pid: process.pid,
      port: this.tunnelConfig.srcPort,
      refCount: 0,
      startTime: Date.now()
    }

    if (await this.createLockFile(state)) {
      return true
    }

    const existing = await this.readLockFile()
    if (existing && await this.isProcessRunning(existing.pid)) {
      return false
    }

    // The lock is stale (owner dead) or unreadable: remove it and make one more
    // exclusive attempt. If another process wins that race, we defer to it.
    await this.removeFileQuietly(TUNNEL_LOCK_FILE)
    return this.createLockFile(state)
  }

  /**
   * Atomically creates TUNNEL_LOCK_FILE containing `state`.
   *
   * @returns false if the lock file already exists.
   * @throws on any filesystem error other than EEXIST.
   */
  private async createLockFile(state: TunnelState): Promise<boolean> {
    const tempFile = await this.writeTempFile(JSON.stringify(state, null, 2))
    try {
      // link() never overwrites (EEXIST instead), and the lock appears with its full
      // contents, so concurrent readers can't mistake a half-written lock for a stale one.
      await fs.link(tempFile, TUNNEL_LOCK_FILE)
      return true
    } catch (error) {
      if (this.errorCode(error) === 'EEXIST') {
        return false
      }
      throw error
    } finally {
      await this.removeFileQuietly(tempFile)
    }
  }

  /** Deletes the lock and refs files. Best-effort: each file is removed independently. */
  private async releaseLock(): Promise<void> {
    await this.removeFileQuietly(TUNNEL_LOCK_FILE)
    await this.removeFileQuietly(REFS_FILE)
  }

  /** Returns the parsed lock file, or null if it is missing, unparseable or malformed. */
  private async readLockFile(): Promise<TunnelState | null> {
    try {
      const parsed: unknown = JSON.parse(await fs.readFile(TUNNEL_LOCK_FILE, 'utf-8'))
      return this.isTunnelState(parsed) ? parsed : null
    } catch {
      return null
    }
  }

  /** Replaces TUNNEL_LOCK_FILE atomically with `state`. */
  private async writeLockFile(state: TunnelState): Promise<void> {
    await this.writeFileAtomically(TUNNEL_LOCK_FILE, JSON.stringify(state, null, 2))
  }

  /** Checks that a parsed lock file has every TunnelState field as a number. */
  private isTunnelState(value: unknown): value is TunnelState {
    if (typeof value !== 'object' || value === null) {
      return false
    }
    const fields = value as Record<string, unknown>
    return ['pid', 'port', 'refCount', 'startTime'].every((key) => typeof fields[key] === 'number')
  }

  /** Reports whether a process with this pid currently exists on the host. */
  private async isProcessRunning(pid: number): Promise<boolean> {
    // kill() with pid 0 or a negative pid targets a process group, not one process.
    if (!Number.isInteger(pid) || pid <= 0) {
      return false
    }
    try {
      // Signal 0 performs the existence/permission check without delivering a signal.
      process.kill(pid, 0)
      return true
    } catch (error) {
      // EPERM: the process exists but belongs to a user we may not signal.
      return this.errorCode(error) === 'EPERM'
    }
  }

  /**
   * Polls 127.0.0.1:`port` until it accepts a TCP connection.
   *
   * @param maxAttempts number of probes, spaced one second apart.
   * @returns true once a probe connects, false if every attempt failed.
   */
  private async isTunnelPortOpen(port: number, maxAttempts = 30): Promise<boolean> {
    for (let attempt = 1; attempt <= maxAttempts; attempt++) {
      if (await this.probePort(port)) {
        return true
      }
      // No point sleeping after the final attempt.
      if (attempt < maxAttempts) {
        await new Promise(resolve => setTimeout(resolve, 1000))
      }
    }
    return false
  }

  /** Makes one connection attempt to 127.0.0.1:`port`. Resolves to whether it succeeded. */
  private probePort(port: number): Promise<boolean> {
    return new Promise<boolean>((resolve) => {
      const socket = new net.Socket()
      socket.once('connect', () => {
        socket.end()
        resolve(true)
      })
      socket.once('error', () => {
        resolve(false)
      })
      try {
        socket.connect(port, '127.0.0.1')
      } catch {
        // e.g. RangeError for an out-of-range port.
        resolve(false)
      }
    })
  }

  /**
   * Opens the SSH connection and starts the local listener on
   * 127.0.0.1:`tunnelConfig.srcPort`.
   *
   * Resolves once the listener is bound. On failure it rejects after closing the SSH
   * session, so nothing is leaked. Errors that occur after a successful start are only logged.
   */
  private startTunnel(): Promise<void> {
    return new Promise((resolve, reject) => {
      const ssh = new SSHClient()
      let settled = false

      const fail = (err: Error): void => {
        if (settled) {
          return
        }
        settled = true
        this.tunnelServer = undefined
        ssh.end()
        reject(err)
      }

      ssh.on('ready', () => {
        console.error('SSH connection established (tunnel manager)')
        const server = net.createServer((socket) => this.forwardConnection(ssh, socket))
        this.tunnelServer = server

        server.on('error', (err) => {
          console.error('Tunnel server error:', err)
          fail(err)
        })

        server.listen(this.tunnelConfig.srcPort, '127.0.0.1', () => {
          console.error('SSH tunnel established on port', this.tunnelConfig.srcPort)
          this.sshConnection = ssh
          settled = true
          resolve()
        })
      })

      ssh.on('error', (err) => {
        console.error('SSH error:', err)
        fail(err)
      })

      ssh.connect(this.sshConfig)
    })
  }

  /** Pipes one local client connection through a new SSH channel to dstHost:dstPort. */
  private forwardConnection(ssh: SSHClient, socket: net.Socket): void {
    // Without an 'error' listener, a client reset (e.g. ECONNRESET) is thrown by the
    // EventEmitter and crashes the whole process, taking the shared tunnel down with it.
    socket.on('error', (err) => {
      console.error('Tunnel client socket error:', err)
    })

    ssh.forwardOut(
      socket.remoteAddress ?? '',
      socket.remotePort ?? 0,
      this.tunnelConfig.dstHost,
      this.tunnelConfig.dstPort,
      (err, stream) => {
        if (err) {
          socket.end()
          console.error('Forward error:', err)
          return
        }
        stream.on('error', (streamErr: Error) => {
          console.error('Tunnel channel error:', streamErr)
          socket.destroy()
        })
        // pipe() does not end the channel when the client disconnects abruptly.
        socket.once('close', () => {
          stream.end()
        })
        socket.pipe(stream).pipe(socket)
      }
    )
  }

  /**
   * Waits for the tunnel owned by another live process to accept connections.
   *
   * @throws if there is no valid lock file, or if the port never opens.
   */
  private async waitForTunnel(): Promise<void> {
    console.error('Waiting for existing SSH tunnel...')
    const lockData = await this.readLockFile()
    if (!lockData) {
      throw new Error('No tunnel lock file found')
    }
    const isOpen = await this.isTunnelPortOpen(lockData.port)
    if (!isOpen) {
      throw new Error('Tunnel port is not open after waiting')
    }
    console.error('Connected to existing SSH tunnel on port', lockData.port)
  }

  /** Records this instance in REFS_FILE and mirrors the count into the lock file. */
  private async addReference(): Promise<void> {
    const refs = await this.readReferences()
    refs[this.instanceId] = {
      pid: process.pid,
      startTime: Date.now()
    }
    await this.writeReferences(refs)

    const lockData = await this.readLockFile()
    if (lockData) {
      lockData.refCount = Object.keys(refs).length
      await this.writeLockFile(lockData)
    }
  }

  /**
   * Removes this instance, plus entries of processes that exited without
   * disconnecting, from REFS_FILE.
   *
   * @returns true if no references remain afterwards.
   */
  private async removeReference(): Promise<boolean> {
    const refs = await this.readReferences()
    delete refs[this.instanceId]

    for (const [id, ref] of Object.entries(refs)) {
      if (!await this.isProcessRunning(ref.pid)) {
        delete refs[id]
      }
    }
    await this.writeReferences(refs)

    const lockData = await this.readLockFile()
    if (lockData) {
      lockData.refCount = Object.keys(refs).length
      await this.writeLockFile(lockData)
    }

    return Object.keys(refs).length === 0
  }

  /**
   * Reads REFS_FILE. A missing, unparseable or non-object file reads as empty.
   * Entries without a numeric pid are dropped.
   */
  private async readReferences(): Promise<Record<string, { pid: number; startTime: number }>> {
    let parsed: unknown
    try {
      parsed = JSON.parse(await fs.readFile(REFS_FILE, 'utf-8'))
    } catch {
      return {}
    }
    if (typeof parsed !== 'object' || parsed === null || Array.isArray(parsed)) {
      return {}
    }
    // fromEntries defines own properties, so even a "__proto__" key stays a plain entry.
    return Object.fromEntries(
      Object.entries(parsed).filter(
        ([, ref]) => typeof ref === 'object' && ref !== null && typeof ref.pid === 'number'
      )
    )
  }

  /** Replaces REFS_FILE atomically with `refs`. */
  private async writeReferences(refs: Record<string, { pid: number; startTime: number }>): Promise<void> {
    await this.writeFileAtomically(REFS_FILE, JSON.stringify(refs, null, 2))
  }

  /**
   * Replaces `target` via temp file + rename(), so concurrent readers see either the
   * old or the new contents, never a truncated file.
   */
  private async writeFileAtomically(target: string, contents: string): Promise<void> {
    const tempFile = await this.writeTempFile(contents)
    try {
      await fs.rename(tempFile, target)
    } catch (error) {
      await this.removeFileQuietly(tempFile)
      throw error
    }
  }

  /** Writes `contents` to a new, uniquely named file inside LOCK_DIR and returns its path. */
  private async writeTempFile(contents: string): Promise<string> {
    const tempFile = path.join(
      LOCK_DIR,
      `.tmp-${process.pid}-${Date.now()}-${TunnelManager.sequence++}`
    )
    await fs.writeFile(tempFile, contents, { flag: 'wx' })
    return tempFile
  }

  /** Best-effort unlink. Errors are ignored (the file may already be gone). */
  private async removeFileQuietly(file: string): Promise<void> {
    try {
      await fs.unlink(file)
    } catch {
      // Nothing to clean up, or not removable; callers treat removal as optional.
    }
  }

  /** Extracts a Node-style `code` (e.g. 'EEXIST', 'EPERM') from an unknown thrown value. */
  private errorCode(error: unknown): unknown {
    return typeof error === 'object' && error !== null
      ? (error as { code?: unknown }).code
      : undefined
  }
}