import { WebSocket } from 'ws';
import { BaseProtocol } from '../core/BaseProtocol.js';
import {
ConsoleSession,
SessionOptions,
ConsoleType,
ConsoleOutput,
} from '../types/index.js';
import {
ProtocolCapabilities,
SessionState as BaseSessionState,
ErrorContext,
} from '../core/IProtocol.js';
// Azure SDK imports - made optional to handle missing dependencies
let DefaultAzureCredential: any,
ClientSecretCredential: any,
ManagedIdentityCredential: any,
ChainedTokenCredential: any,
AzureCliCredential: any,
InteractiveBrowserCredential: any,
ClientCertificateCredential: any;
try {
const identityModule = require('@azure/identity');
DefaultAzureCredential = identityModule.DefaultAzureCredential;
ClientSecretCredential = identityModule.ClientSecretCredential;
ManagedIdentityCredential = identityModule.ManagedIdentityCredential;
ChainedTokenCredential = identityModule.ChainedTokenCredential;
AzureCliCredential = identityModule.AzureCliCredential;
InteractiveBrowserCredential = identityModule.InteractiveBrowserCredential;
ClientCertificateCredential = identityModule.ClientCertificateCredential;
} catch (error) {
console.warn(
'@azure/identity not available, Azure identity functionality will be disabled'
);
}
let ComputeManagementClient: any,
NetworkManagementClient: any,
SecretClient: any;
try {
const computeModule = require('@azure/arm-compute');
ComputeManagementClient = computeModule.ComputeManagementClient;
} catch (error) {
console.warn(
'@azure/arm-compute not available, Compute management functionality will be disabled'
);
}
try {
const networkModule = require('@azure/arm-network');
NetworkManagementClient = networkModule.NetworkManagementClient;
} catch (error) {
console.warn(
'@azure/arm-network not available, Network management functionality will be disabled'
);
}
try {
const keyVaultModule = require('@azure/keyvault-secrets');
SecretClient = keyVaultModule.SecretClient;
} catch (error) {
console.warn(
'@azure/keyvault-secrets not available, Key Vault functionality will be disabled'
);
}
import {
AzureConnectionOptions,
AzureCloudShellSession,
AzureBastionSession,
AzureArcSession,
AzureTokenInfo,
AzureResourceInfo,
} from '../types/index.js';
// Azure API Response Interfaces
interface AzureResourceResponse {
id: string;
type: string;
location: string;
name: string;
properties: any;
}
interface AzureCloudShellCreateResponse {
properties: {
uri: string;
};
}
interface AzureCloudShellConnectResponse {
properties: {
socketUri: string;
};
}
interface AzureBastionConnectionResponse {
value: Array<{
bsl: string;
}>;
}
interface AzureArcConnectionResponse {
properties: {
connectionDetails: {
socketUri: string;
};
};
}
interface AzureHybridConnectionResponse {
hybridConnectionString: string;
}
// Type interfaces for Azure SDK types when not available
interface AzureCredentialLike {
getToken(scope: string | string[]): Promise<{
token: string;
expiresOnTimestamp: number;
}>;
}
interface ComputeManagementClientLike {
virtualMachines: {
get(resourceGroupName: string, vmName: string): Promise<any>;
};
}
interface NetworkManagementClientLike {
networkInterfaces: {
get(resourceGroupName: string, networkInterfaceName: string): Promise<any>;
};
}
interface SecretClientLike {
getSecret(secretName: string): Promise<{ value?: string }>;
}
// Azure-specific session state interface
interface AzureSessionState extends BaseSessionState {
tokenExpiry?: Date;
tokenValid?: boolean;
sessionType?: 'cloud-shell' | 'bastion' | 'arc';
webSocketConnected?: boolean;
}
export class AzureProtocol extends BaseProtocol {
public readonly type: ConsoleType = 'azure-shell';
public readonly capabilities: ProtocolCapabilities;
private azureSessions: Map<
string,
AzureCloudShellSession | AzureBastionSession | AzureArcSession
> = new Map();
private webSockets: Map<string, WebSocket> = new Map();
private credentials: Map<string, AzureCredentialLike> = new Map();
private computeClients: Map<string, ComputeManagementClientLike> = new Map();
private networkClients: Map<string, NetworkManagementClientLike> = new Map();
private secretClients: Map<string, SecretClientLike> = new Map();
private reconnectTimers: Map<string, NodeJS.Timeout> = new Map();
private tokenRefreshTimers: Map<string, NodeJS.Timeout> = new Map();
constructor() {
super('azure');
this.capabilities = {
supportsStreaming: true,
supportsFileTransfer: false,
supportsX11Forwarding: false,
supportsPortForwarding: true,
supportsAuthentication: true,
supportsEncryption: true,
supportsCompression: false,
supportsMultiplexing: false,
supportsKeepAlive: true,
supportsReconnection: true,
supportsBinaryData: false,
supportsCustomEnvironment: false,
supportsWorkingDirectory: false,
supportsSignals: false,
supportsResizing: true,
supportsPTY: true,
maxConcurrentSessions: 50,
defaultTimeout: 30000,
supportedEncodings: ['utf-8'],
supportedAuthMethods: ['oauth', 'service-principal', 'managed-identity'],
platformSupport: {
windows: true,
linux: true,
macos: true,
freebsd: true,
},
};
}
async initialize(): Promise<void> {
if (this.isInitialized) return;
try {
// Verify Azure SDK availability
if (!DefaultAzureCredential) {
this.logger.warn(
'Azure SDK not fully available, some features may be disabled'
);
}
this.isInitialized = true;
this.logger.info(
'Azure protocol initialized with session management fixes'
);
} catch (error: any) {
this.logger.error('Failed to initialize Azure protocol', error);
throw error;
}
}
async createSession(options: SessionOptions): Promise<ConsoleSession> {
if (!this.isInitialized) {
await this.initialize();
}
const sessionId = `azure-${Date.now()}-${Math.random().toString(36).substring(2, 11)}`;
const azureOptions = options as AzureConnectionOptions;
// Determine Azure service type
if (azureOptions.cloudShellType) {
const azureSession = await this.createCloudShellSession(
sessionId,
azureOptions
);
return await this.createAzureSessionWithDetection(
sessionId,
options,
azureSession
);
} else if (azureOptions.bastionResourceId) {
const azureSession = await this.createBastionSession(
sessionId,
azureOptions
);
return await this.createAzureSessionWithDetection(
sessionId,
options,
azureSession
);
} else if (azureOptions.arcResourceId) {
const azureSession = await this.createArcSession(sessionId, azureOptions);
return await this.createAzureSessionWithDetection(
sessionId,
options,
azureSession
);
} else {
throw new Error('Azure connection type not specified');
}
}
getSessionState(sessionId: string): Promise<BaseSessionState> {
const azureSession = this.azureSessions.get(sessionId);
const session = this.sessions.get(sessionId);
if (!azureSession || !session) {
return Promise.resolve({
sessionId,
status: 'failed',
isOneShot: false,
isPersistent: true,
createdAt: new Date(),
lastActivity: new Date(),
metadata: { error: 'Session not found' },
});
}
const webSocket = this.webSockets.get(sessionId);
const connected = webSocket?.readyState === WebSocket.OPEN;
const state: BaseSessionState = {
sessionId,
status: connected ? 'running' : 'stopped',
isOneShot: false,
isPersistent: true,
createdAt: session.createdAt,
lastActivity: session.lastActivity || new Date(),
metadata: {
tokenExpiry: azureSession.tokenExpiry,
tokenValid: azureSession.tokenExpiry.getTime() > Date.now(),
sessionType:
'webSocketUrl' in azureSession
? 'cloud-shell'
: 'bastionResourceId' in azureSession
? 'bastion'
: 'arc',
webSocketConnected: connected,
},
};
return Promise.resolve(state);
}
getActiveSessions(): ConsoleSession[] {
return Array.from(this.sessions.values());
}
async attemptReconnection(context: ErrorContext): Promise<boolean> {
try {
const sessionId = context.sessionId;
if (!sessionId) {
return false;
}
const azureSession = this.azureSessions.get(sessionId);
if (!azureSession) {
return false;
}
// Close existing WebSocket if present
const existingWs = this.webSockets.get(sessionId);
if (existingWs) {
existingWs.close();
this.webSockets.delete(sessionId);
}
// Re-establish WebSocket connection
await this.connectWebSocket(sessionId, azureSession);
return true;
} catch (error) {
this.logger.error(
`Failed to reconnect Azure session ${context.sessionId}:`,
error
);
return false;
}
}
async dispose(): Promise<void> {
await this.cleanup();
}
async executeCommand(
sessionId: string,
command: string,
args?: string[]
): Promise<void> {
const fullCommand =
args && args.length > 0 ? `${command} ${args.join(' ')}` : command;
await this.sendInput(sessionId, fullCommand + '\n');
}
async doCreateSession(
sessionId: string,
options: SessionOptions
): Promise<ConsoleSession> {
return await this.createSession(options);
}
async sendInput(sessionId: string, input: string): Promise<void> {
const webSocket = this.webSockets.get(sessionId);
if (!webSocket) {
throw new Error(
`No WebSocket connection found for session: ${sessionId}`
);
}
if (webSocket.readyState === WebSocket.OPEN) {
const message = {
type: 'input',
data: input,
};
webSocket.send(JSON.stringify(message));
this.logger.debug(
`Sent input to session ${sessionId}: ${input.substring(0, 100)}`
);
} else {
throw new Error(
`WebSocket connection is not open for session: ${sessionId}`
);
}
}
/**
* Helper method to create session with type detection and BaseProtocol integration
*/
private async createAzureSessionWithDetection(
sessionId: string,
options: SessionOptions,
azureSession: AzureCloudShellSession | AzureBastionSession | AzureArcSession
): Promise<ConsoleSession> {
this.azureSessions.set(sessionId, azureSession);
// Create BaseProtocol session
const session: ConsoleSession = {
id: sessionId,
command: 'azure-session',
args: [],
cwd: '',
env: {},
createdAt: new Date(),
lastActivity: new Date(),
status: 'initializing',
type: this.type,
executionState: 'idle',
activeCommands: new Map(),
};
this.sessions.set(sessionId, session);
// Set up event handlers for Azure-specific events
this.setupAzureEventHandlers(sessionId);
return session;
}
/**
* Set up Azure-specific event handlers
*/
private setupAzureEventHandlers(sessionId: string): void {
// Set up handlers for token refresh, WebSocket events, etc.
// This replaces the EventEmitter pattern with direct callback management
}
/**
* Create a new Azure Cloud Shell session
*/
async createCloudShellSession(
sessionId: string,
options: AzureConnectionOptions
): Promise<AzureCloudShellSession> {
try {
this.logger.info(`Creating Azure Cloud Shell session: ${sessionId}`);
// Get authentication credentials
const credential = await this.getCredential(options);
const token = await this.getAccessToken(
credential,
'https://management.azure.com/'
);
// Create Cloud Shell session via Azure API
const cloudShellUrl = await this.createCloudShellInstance(options, token);
const webSocketUrl = await this.getCloudShellWebSocketUrl(
cloudShellUrl,
token
);
const session: AzureCloudShellSession = {
sessionId,
webSocketUrl,
accessToken: token.accessToken,
refreshToken: token.refreshToken,
tokenExpiry: token.expiresOn,
shellType: options.cloudShellType || 'bash',
subscription: options.subscriptionId || '',
resourceGroup:
options.resourceGroupName || 'cloud-shell-storage-eastus',
location: options.region || 'eastus',
storageAccount: options.storageAccountName
? {
name: options.storageAccountName,
resourceGroup:
options.resourceGroupName || 'cloud-shell-storage-eastus',
fileShare:
options.fileShareName || 'cs-' + sessionId.substring(0, 8),
}
: undefined,
metadata: {},
};
this.azureSessions.set(sessionId, session);
this.credentials.set(sessionId, credential);
// Establish WebSocket connection
await this.connectWebSocket(sessionId, session);
// Schedule token refresh
this.scheduleTokenRefresh(sessionId, session);
this.logger.info(
`Azure Cloud Shell session created successfully: ${sessionId}`
);
return session;
} catch (error) {
this.logger.error(
`Failed to create Azure Cloud Shell session: ${sessionId}`,
error
);
throw error;
}
}
/**
* Create a new Azure Bastion session
*/
async createBastionSession(
sessionId: string,
options: AzureConnectionOptions
): Promise<AzureBastionSession> {
try {
this.logger.info(`Creating Azure Bastion session: ${sessionId}`);
const credential = await this.getCredential(options);
const token = await this.getAccessToken(
credential,
'https://management.azure.com/'
);
// Get Bastion and VM resource information
const bastionInfo = await this.getBastionResourceInfo(options, token);
const vmInfo = await this.getVmResourceInfo(options, token);
// Create Bastion connection
const connectionUrl = await this.createBastionConnection(
bastionInfo,
vmInfo,
options,
token
);
const session: AzureBastionSession = {
sessionId,
bastionResourceId: options.bastionResourceId || '',
targetVmResourceId: options.targetVmResourceId || '',
targetVmName: options.targetVmName || '',
protocol: options.protocol || 'ssh',
connectionUrl,
accessToken: token.accessToken,
tokenExpiry: token.expiresOn,
metadata: {
bastionName: bastionInfo.name,
vmName: vmInfo.name,
location: bastionInfo.location,
},
};
this.azureSessions.set(sessionId, session);
this.credentials.set(sessionId, credential);
// For SSH connections through Bastion, establish tunnel
if (options.protocol === 'ssh') {
await this.establishBastionTunnel(sessionId, session, options);
}
this.scheduleTokenRefresh(sessionId, session);
this.logger.info(
`Azure Bastion session created successfully: ${sessionId}`
);
return session;
} catch (error) {
this.logger.error(
`Failed to create Azure Bastion session: ${sessionId}`,
error
);
throw error;
}
}
/**
* Create a new Azure Arc session
*/
async createArcSession(
sessionId: string,
options: AzureConnectionOptions
): Promise<AzureArcSession> {
try {
this.logger.info(`Creating Azure Arc session: ${sessionId}`);
const credential = await this.getCredential(options);
const token = await this.getAccessToken(
credential,
'https://management.azure.com/'
);
// Get Arc resource information
const arcInfo = await this.getArcResourceInfo(options, token);
// Create hybrid connection
const connectionEndpoint = await this.createArcConnection(
arcInfo,
options,
token
);
const hybridConnectionString = await this.getHybridConnectionString(
arcInfo,
token
);
const session: AzureArcSession = {
sessionId,
arcResourceId: options.arcResourceId || '',
connectionEndpoint,
accessToken: token.accessToken,
tokenExpiry: token.expiresOn,
hybridConnectionString,
targetMachine: {
name: arcInfo.name,
osType: arcInfo.properties?.osName?.includes('Windows')
? 'Windows'
: 'Linux',
version: arcInfo.properties?.osVersion,
},
metadata: {
location: arcInfo.location,
resourceGroup: arcInfo.resourceGroup,
},
};
this.azureSessions.set(sessionId, session);
this.credentials.set(sessionId, credential);
// Establish Arc connection
await this.connectArcSession(sessionId, session);
this.scheduleTokenRefresh(sessionId, session);
this.logger.info(`Azure Arc session created successfully: ${sessionId}`);
return session;
} catch (error) {
this.logger.error(
`Failed to create Azure Arc session: ${sessionId}`,
error
);
throw error;
}
}
/**
* Resize terminal for a session
*/
async resizeTerminal(
sessionId: string,
rows: number,
cols: number
): Promise<void> {
const webSocket = this.webSockets.get(sessionId);
if (!webSocket) {
throw new Error(
`No WebSocket connection found for session: ${sessionId}`
);
}
if (webSocket.readyState === WebSocket.OPEN) {
const message = {
type: 'resize',
rows,
cols,
};
webSocket.send(JSON.stringify(message));
this.logger.debug(
`Resized terminal for session ${sessionId}: ${rows}x${cols}`
);
}
}
/**
* Close a session (overrides BaseProtocol)
*/
async closeSession(sessionId: string): Promise<void> {
try {
this.logger.info(`Closing Azure session: ${sessionId}`);
// Close WebSocket connection
const webSocket = this.webSockets.get(sessionId);
if (webSocket) {
webSocket.close();
this.webSockets.delete(sessionId);
}
// Clear timers
const reconnectTimer = this.reconnectTimers.get(sessionId);
if (reconnectTimer) {
clearTimeout(reconnectTimer);
this.reconnectTimers.delete(sessionId);
}
const refreshTimer = this.tokenRefreshTimers.get(sessionId);
if (refreshTimer) {
clearTimeout(refreshTimer);
this.tokenRefreshTimers.delete(sessionId);
}
// Clean up Azure session data
this.azureSessions.delete(sessionId);
this.credentials.delete(sessionId);
// Remove from base session map
this.sessions.delete(sessionId);
this.logger.info(`Azure session closed: ${sessionId}`);
} catch (error) {
this.logger.error(`Failed to close Azure session: ${sessionId}`, error);
throw error;
}
}
/**
* Get Azure session information
*/
getAzureSession(
sessionId: string
):
| AzureCloudShellSession
| AzureBastionSession
| AzureArcSession
| undefined {
return this.azureSessions.get(sessionId);
}
/**
* List all active Azure sessions
*/
listAzureSessions(): string[] {
return Array.from(this.azureSessions.keys());
}
/**
* Check if session is connected
*/
isConnected(sessionId: string): boolean {
const webSocket = this.webSockets.get(sessionId);
return webSocket?.readyState === WebSocket.OPEN;
}
/**
* Private methods
*/
private async getCredential(
options: AzureConnectionOptions
): Promise<AzureCredentialLike> {
if (!DefaultAzureCredential) {
throw new Error('@azure/identity package is required but not available');
}
if (options.managedIdentity) {
if (!ManagedIdentityCredential) {
throw new Error('ManagedIdentityCredential not available');
}
return new ManagedIdentityCredential(options.clientId);
}
if (options.clientId && options.clientSecret && options.tenantId) {
if (!ClientSecretCredential) {
throw new Error('ClientSecretCredential not available');
}
return new ClientSecretCredential(
options.tenantId,
options.clientId,
options.clientSecret
);
}
if (options.clientCertificatePath && options.tenantId && options.clientId) {
if (!ClientCertificateCredential) {
throw new Error('ClientCertificateCredential not available');
}
return new ClientCertificateCredential(
options.tenantId,
options.clientId,
options.clientCertificatePath
);
}
// Use chained credential for maximum compatibility
if (
!ChainedTokenCredential ||
!AzureCliCredential ||
!ManagedIdentityCredential
) {
// Fallback to DefaultAzureCredential if chained components not available
return new DefaultAzureCredential();
}
return new ChainedTokenCredential(
new AzureCliCredential(),
new ManagedIdentityCredential(),
new DefaultAzureCredential()
);
}
private async getAccessToken(
credential: AzureCredentialLike,
scope: string
): Promise<AzureTokenInfo> {
const tokenResponse = await credential.getToken(scope);
return {
accessToken: tokenResponse.token,
tokenType: 'Bearer',
expiresIn: Math.floor(
(tokenResponse.expiresOnTimestamp - Date.now()) / 1000
),
expiresOn: new Date(tokenResponse.expiresOnTimestamp),
scope: [scope],
tenantId: (credential as any).tenantId || '',
resource: scope,
authority: 'https://login.microsoftonline.com/',
};
}
private async createCloudShellInstance(
options: AzureConnectionOptions,
token: AzureTokenInfo
): Promise<string> {
const subscriptionId = options.subscriptionId;
const shellType = options.cloudShellType || 'bash';
const location = options.region || 'eastus';
// Azure Cloud Shell API endpoint
const cloudShellApiUrl = `https://management.azure.com/subscriptions/${subscriptionId}/providers/Microsoft.Portal/consoles/default`;
const response = await fetch(cloudShellApiUrl, {
method: 'PUT',
headers: {
Authorization: `Bearer ${token.accessToken}`,
'Content-Type': 'application/json',
'x-ms-console-preferred-location': location,
},
body: JSON.stringify({
properties: {
osType: shellType === 'powershell' ? 'Windows' : 'Linux',
consoleDefinition: {
type: shellType,
},
},
}),
});
if (!response.ok) {
throw new Error(
`Failed to create Cloud Shell instance: ${response.statusText}`
);
}
const data = (await response.json()) as AzureCloudShellCreateResponse;
return data.properties.uri;
}
private async getCloudShellWebSocketUrl(
consoleUri: string,
token: AzureTokenInfo
): Promise<string> {
const response = await fetch(`${consoleUri}/connect`, {
method: 'POST',
headers: {
Authorization: `Bearer ${token.accessToken}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
properties: {
connectParams: {
osType: 'Linux',
shellType: 'bash',
},
},
}),
});
if (!response.ok) {
throw new Error(`Failed to get WebSocket URL: ${response.statusText}`);
}
const data = (await response.json()) as AzureCloudShellConnectResponse;
return data.properties.socketUri;
}
private async getBastionResourceInfo(
options: AzureConnectionOptions,
token: AzureTokenInfo
): Promise<AzureResourceInfo> {
const bastionResourceId = options.bastionResourceId;
if (!bastionResourceId) {
throw new Error('Bastion resource ID is required');
}
const response = await fetch(
`https://management.azure.com${bastionResourceId}?api-version=2021-02-01`,
{
headers: {
Authorization: `Bearer ${token.accessToken}`,
},
}
);
if (!response.ok) {
throw new Error(
`Failed to get Bastion resource info: ${response.statusText}`
);
}
const data = (await response.json()) as AzureResourceResponse;
return {
resourceId: data.id,
resourceType: data.type,
resourceGroup: data.id.split('/')[4],
subscriptionId: data.id.split('/')[2],
location: data.location,
name: data.name,
properties: data.properties,
};
}
private async getVmResourceInfo(
options: AzureConnectionOptions,
token: AzureTokenInfo
): Promise<AzureResourceInfo> {
const vmResourceId = options.targetVmResourceId;
if (!vmResourceId) {
throw new Error('Target VM resource ID is required');
}
const response = await fetch(
`https://management.azure.com${vmResourceId}?api-version=2021-03-01`,
{
headers: {
Authorization: `Bearer ${token.accessToken}`,
},
}
);
if (!response.ok) {
throw new Error(`Failed to get VM resource info: ${response.statusText}`);
}
const data = (await response.json()) as AzureResourceResponse;
return {
resourceId: data.id,
resourceType: data.type,
resourceGroup: data.id.split('/')[4],
subscriptionId: data.id.split('/')[2],
location: data.location,
name: data.name,
properties: data.properties,
};
}
private async createBastionConnection(
bastionInfo: AzureResourceInfo,
vmInfo: AzureResourceInfo,
options: AzureConnectionOptions,
token: AzureTokenInfo
): Promise<string> {
const protocol = options.protocol || 'ssh';
const bastionEndpoint = `https://management.azure.com${bastionInfo.resourceId}/createShareableLinks`;
const response = await fetch(bastionEndpoint, {
method: 'POST',
headers: {
Authorization: `Bearer ${token.accessToken}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
vms: [
{
vm: {
id: vmInfo.resourceId,
},
},
],
}),
});
if (!response.ok) {
throw new Error(
`Failed to create Bastion connection: ${response.statusText}`
);
}
const data = (await response.json()) as AzureBastionConnectionResponse;
return data.value[0]?.bsl || '';
}
private async getArcResourceInfo(
options: AzureConnectionOptions,
token: AzureTokenInfo
): Promise<AzureResourceInfo> {
const arcResourceId = options.arcResourceId;
if (!arcResourceId) {
throw new Error('Arc resource ID is required');
}
const response = await fetch(
`https://management.azure.com${arcResourceId}?api-version=2020-08-02`,
{
headers: {
Authorization: `Bearer ${token.accessToken}`,
},
}
);
if (!response.ok) {
throw new Error(
`Failed to get Arc resource info: ${response.statusText}`
);
}
const data = (await response.json()) as AzureResourceResponse;
return {
resourceId: data.id,
resourceType: data.type,
resourceGroup: data.id.split('/')[4],
subscriptionId: data.id.split('/')[2],
location: data.location,
name: data.name,
properties: data.properties,
};
}
private async createArcConnection(
arcInfo: AzureResourceInfo,
options: AzureConnectionOptions,
token: AzureTokenInfo
): Promise<string> {
// Create Arc connection endpoint
const endpoint = `https://management.azure.com${arcInfo.resourceId}/providers/Microsoft.HybridConnectivity/endpoints/default`;
const response = await fetch(endpoint, {
method: 'PUT',
headers: {
Authorization: `Bearer ${token.accessToken}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
properties: {
type: 'default',
resourceId: arcInfo.resourceId,
},
}),
});
if (!response.ok) {
throw new Error(
`Failed to create Arc connection: ${response.statusText}`
);
}
const data = (await response.json()) as AzureArcConnectionResponse;
return data.properties.connectionDetails.socketUri;
}
private async getHybridConnectionString(
arcInfo: AzureResourceInfo,
token: AzureTokenInfo
): Promise<string> {
const endpoint = `https://management.azure.com${arcInfo.resourceId}/providers/Microsoft.HybridConnectivity/endpoints/default/listCredentials`;
const response = await fetch(endpoint, {
method: 'POST',
headers: {
Authorization: `Bearer ${token.accessToken}`,
'Content-Type': 'application/json',
},
});
if (!response.ok) {
throw new Error(
`Failed to get hybrid connection string: ${response.statusText}`
);
}
const data = (await response.json()) as AzureHybridConnectionResponse;
return data.hybridConnectionString;
}
private async connectWebSocket(
sessionId: string,
session: AzureCloudShellSession | AzureBastionSession | AzureArcSession
): Promise<void> {
return new Promise((resolve, reject) => {
try {
const webSocketUrl =
(session as AzureCloudShellSession).webSocketUrl ||
(session as AzureArcSession).connectionEndpoint;
const webSocket = new WebSocket(webSocketUrl, {
headers: {
Authorization: `Bearer ${session.accessToken}`,
},
});
webSocket.on('open', () => {
this.logger.info(`WebSocket connected for session: ${sessionId}`);
this.webSockets.set(sessionId, webSocket);
this.emit('connected', sessionId);
resolve();
});
webSocket.on('message', (data: Buffer | ArrayBuffer | Buffer[]) => {
try {
const message = Buffer.isBuffer(data)
? data.toString()
: data instanceof ArrayBuffer
? Buffer.from(data).toString()
: Buffer.concat(data as Buffer[]).toString();
const output: ConsoleOutput = {
sessionId,
type: 'stdout',
data: message,
timestamp: new Date(),
raw: message,
};
this.addToOutputBuffer(sessionId, output);
} catch (error) {
this.logger.error(
`Error processing WebSocket message for session ${sessionId}:`,
error
);
}
});
webSocket.on('error', (error) => {
this.logger.error(`WebSocket error for session ${sessionId}:`, error);
this.emit('error', { sessionId, error });
reject(error);
});
webSocket.on('close', (code, reason) => {
this.logger.info(
`WebSocket closed for session ${sessionId}: ${code} - ${reason}`
);
this.webSockets.delete(sessionId);
// Attempt reconnection if not intentionally closed
if (code !== 1000 && this.azureSessions.has(sessionId)) {
this.scheduleReconnect(sessionId);
}
});
} catch (error) {
reject(error);
}
});
}
private async establishBastionTunnel(
sessionId: string,
session: AzureBastionSession,
options: AzureConnectionOptions
): Promise<void> {
// For Bastion SSH connections, we need to establish a tunnel
// This is a simplified implementation - in practice, you'd use the Bastion API
// to create a proper SSH tunnel
const localPort = 2222 + Math.floor(Math.random() * 1000);
session.tunnelEndpoint = `localhost:${localPort}`;
session.portForwarding = {
localPort,
remotePort: 22,
remoteHost: session.targetVmName,
};
this.logger.info(
`Bastion SSH tunnel established for session ${sessionId} on port ${localPort}`
);
}
private async connectArcSession(
sessionId: string,
session: AzureArcSession
): Promise<void> {
// For Arc sessions, establish the hybrid connection
await this.connectWebSocket(sessionId, session);
}
private scheduleTokenRefresh(
sessionId: string,
session: AzureCloudShellSession | AzureBastionSession | AzureArcSession
): void {
const refreshTime =
session.tokenExpiry.getTime() - Date.now() - 5 * 60 * 1000; // 5 minutes before expiry
if (refreshTime > 0) {
const timer = setTimeout(async () => {
try {
await this.refreshToken(sessionId);
} catch (error) {
this.logger?.error(
`Failed to refresh token for session ${sessionId}:`,
error
);
this.emit('error', sessionId, error as Error);
}
}, refreshTime);
this.tokenRefreshTimers.set(sessionId, timer);
}
}
private async refreshToken(sessionId: string): Promise<void> {
const session = this.azureSessions.get(sessionId);
const credential = this.credentials.get(sessionId);
if (!session || !credential) {
throw new Error(`Session or credential not found: ${sessionId}`);
}
const newToken = await this.getAccessToken(
credential,
'https://management.azure.com/'
);
// Update session with new token
session.accessToken = newToken.accessToken;
session.tokenExpiry = newToken.expiresOn;
// Update session activity in BaseProtocol
const baseSession = this.sessions.get(sessionId);
if (baseSession) {
baseSession.lastActivity = new Date();
}
this.scheduleTokenRefresh(sessionId, session);
this.logger.info(`Token refreshed for session: ${sessionId}`);
}
private scheduleReconnect(sessionId: string, attempt: number = 1): void {
const maxAttempts = 5;
const baseDelay = 1000;
const delay = baseDelay * Math.pow(2, attempt - 1); // Exponential backoff
if (attempt > maxAttempts) {
this.logger.error(
`Max reconnection attempts reached for session: ${sessionId}`
);
this.emit('error', {
sessionId,
error: new Error('Max reconnection attempts reached'),
});
return;
}
this.logger.info(
`Attempting reconnection for session ${sessionId}, attempt ${attempt}`
);
const timer = setTimeout(async () => {
try {
const session = this.azureSessions.get(sessionId);
if (session) {
await this.connectWebSocket(sessionId, session);
this.logger.info(
`Reconnected session ${sessionId} on attempt ${attempt}`
);
}
} catch (error) {
this.logger.error(
`Reconnection attempt ${attempt} failed for session ${sessionId}:`,
error
);
this.scheduleReconnect(sessionId, attempt + 1);
}
}, delay);
this.reconnectTimers.set(sessionId, timer);
}
/**
* Get secrets from Azure Key Vault
*/
private async getKeyVaultSecret(
keyVaultUrl: string,
secretName: string,
credential: AzureCredentialLike
): Promise<string> {
if (!SecretClient) {
throw new Error(
'@azure/keyvault-secrets package is required but not available'
);
}
const secretClient = new SecretClient(keyVaultUrl, credential);
const secret = await secretClient.getSecret(secretName);
return secret.value || '';
}
/**
* Health check for Azure sessions
*/
async healthCheck(sessionId: string): Promise<boolean> {
const session = this.azureSessions.get(sessionId);
if (!session) {
return false;
}
const webSocket = this.webSockets.get(sessionId);
if (!webSocket || webSocket.readyState !== WebSocket.OPEN) {
return false;
}
// Check if token is still valid
const tokenExpiry = session.tokenExpiry.getTime();
const now = Date.now();
if (tokenExpiry <= now) {
try {
await this.refreshToken(sessionId);
} catch {
return false;
}
}
return true;
}
/**
* Get session metrics
*/
getSessionMetrics(sessionId: string): Record<string, any> {
const session = this.azureSessions.get(sessionId);
const webSocket = this.webSockets.get(sessionId);
if (!session) {
return {};
}
return {
sessionId,
connected: webSocket?.readyState === WebSocket.OPEN,
tokenExpiry: session.tokenExpiry,
tokenValid: session.tokenExpiry.getTime() > Date.now(),
sessionType:
'webSocketUrl' in session
? 'cloud-shell'
: 'bastionResourceId' in session
? 'bastion'
: 'arc',
metadata: session.metadata,
};
}
/**
* Cleanup all sessions
*/
async cleanup(): Promise<void> {
const sessionIds = Array.from(this.azureSessions.keys());
await Promise.all(sessionIds.map((id) => this.closeSession(id)));
this.azureSessions.clear();
this.webSockets.clear();
this.credentials.clear();
this.computeClients.clear();
this.networkClients.clear();
this.secretClients.clear();
this.reconnectTimers.clear();
this.tokenRefreshTimers.clear();
}
}