VirtualDesktopService.cs•7.81 kB
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
using Azure.ResourceManager.DesktopVirtualization;
using AzureMcp.Areas.VirtualDesktop.Models;
using AzureMcp.Core.Options;
using AzureMcp.Core.Services.Azure.Subscription;
namespace AzureMcp.Areas.VirtualDesktop.Services;
public class VirtualDesktopService(ISubscriptionService subscriptionService) : IVirtualDesktopService
{
    private readonly ISubscriptionService _subscriptionService = subscriptionService;
    public async Task<IReadOnlyList<HostPool>> ListHostpoolsAsync(string subscription, string? tenant = null, RetryPolicyOptions? retryPolicy = null)
    {
        var sub = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy);
        var hostpools = new List<HostPool>();
        await foreach (HostPoolResource resource in sub.GetHostPoolsAsync())
        {
            hostpools.Add(new HostPool(resource));
        }
        return hostpools;
    }
    public async Task<IReadOnlyList<HostPool>> ListHostpoolsByResourceGroupAsync(string subscription, string resourceGroup, string? tenant = null, RetryPolicyOptions? retryPolicy = null)
    {
        var sub = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy);
        var hostpools = new List<HostPool>();
        var resourceGroupResource = await sub.GetResourceGroupAsync(resourceGroup);
        await foreach (HostPoolResource resource in resourceGroupResource.Value.GetHostPools().GetAllAsync())
        {
            hostpools.Add(new HostPool(resource));
        }
        return hostpools;
    }
    public async Task<IReadOnlyList<SessionHost>> ListSessionHostsAsync(string subscription, string hostPoolName, string? tenant = null, RetryPolicyOptions? retryPolicy = null)
    {
        var sub = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy);
        var sessionHosts = new List<SessionHost>();
        await foreach (HostPoolResource resource in sub.GetHostPoolsAsync())
        {
            if (resource.Data.Name == hostPoolName)
            {
                var armClient = sub.GetCachedClient(client => client);
                var hostPool = armClient.GetHostPoolResource(resource.Id);
                await foreach (SessionHostResource sessionHost in hostPool.GetSessionHosts().GetAllAsync())
                {
                    sessionHosts.Add(new SessionHost(sessionHost));
                }
                break; // Found the host pool, no need to continue
            }
        }
        return sessionHosts;
    }
    public async Task<IReadOnlyList<UserSession>> ListUserSessionsAsync(string subscription, string hostPoolName, string sessionHostName, string? tenant = null, RetryPolicyOptions? retryPolicy = null)
    {
        var sub = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy);
        var userSessions = new List<UserSession>();
        await foreach (HostPoolResource resource in sub.GetHostPoolsAsync())
        {
            if (resource.Data.Name == hostPoolName)
            {
                var armClient = sub.GetCachedClient(client => client);
                var hostPool = armClient.GetHostPoolResource(resource.Id);
                await foreach (SessionHostResource sessionHost in hostPool.GetSessionHosts().GetAllAsync())
                {
                    if (sessionHost.Data.Name == sessionHostName || sessionHost.Data.Name == $"{hostPoolName}/{sessionHostName}")
                    {
                        await foreach (UserSessionResource userSession in sessionHost.GetUserSessions().GetAllAsync())
                        {
                            userSessions.Add(new UserSession(userSession));
                        }
                        break; // Found the session host, no need to continue
                    }
                }
                break; // Found the host pool, no need to continue
            }
        }
        return userSessions;
    }
    public async Task<IReadOnlyList<SessionHost>> ListSessionHostsByResourceIdAsync(string subscription, string hostPoolResourceId, string? tenant = null, RetryPolicyOptions? retryPolicy = null)
    {
        var sub = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy);
        var sessionHosts = new List<SessionHost>();
        var armClient = sub.GetCachedClient(client => client);
        var hostPool = armClient.GetHostPoolResource(Azure.Core.ResourceIdentifier.Parse(hostPoolResourceId));
        await foreach (SessionHostResource sessionHost in hostPool.GetSessionHosts().GetAllAsync())
        {
            sessionHosts.Add(new SessionHost(sessionHost));
        }
        return sessionHosts;
    }
    public async Task<IReadOnlyList<UserSession>> ListUserSessionsByResourceIdAsync(string subscription, string hostPoolResourceId, string sessionHostName, string? tenant = null, RetryPolicyOptions? retryPolicy = null)
    {
        var sub = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy);
        var userSessions = new List<UserSession>();
        var armClient = sub.GetCachedClient(client => client);
        var hostPool = armClient.GetHostPoolResource(Azure.Core.ResourceIdentifier.Parse(hostPoolResourceId));
        await foreach (SessionHostResource sessionHost in hostPool.GetSessionHosts().GetAllAsync())
        {
            if (sessionHost.Data.Name == sessionHostName || sessionHost.Data.Name.EndsWith($"/{sessionHostName}"))
            {
                await foreach (UserSessionResource userSession in sessionHost.GetUserSessions().GetAllAsync())
                {
                    userSessions.Add(new UserSession(userSession));
                }
                break; // Found the session host, no need to continue
            }
        }
        return userSessions;
    }
    public async Task<IReadOnlyList<SessionHost>> ListSessionHostsByResourceGroupAsync(string subscription, string resourceGroup, string hostPoolName, string? tenant = null, RetryPolicyOptions? retryPolicy = null)
    {
        var sub = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy);
        var sessionHosts = new List<SessionHost>();
        var resourceGroupResource = await sub.GetResourceGroupAsync(resourceGroup);
        var hostPool = await resourceGroupResource.Value.GetHostPoolAsync(hostPoolName);
        await foreach (SessionHostResource sessionHost in hostPool.Value.GetSessionHosts().GetAllAsync())
        {
            sessionHosts.Add(new SessionHost(sessionHost));
        }
        return sessionHosts;
    }
    public async Task<IReadOnlyList<UserSession>> ListUserSessionsByResourceGroupAsync(string subscription, string resourceGroup, string hostPoolName, string sessionHostName, string? tenant = null, RetryPolicyOptions? retryPolicy = null)
    {
        var sub = await _subscriptionService.GetSubscription(subscription, tenant, retryPolicy);
        var userSessions = new List<UserSession>();
        var resourceGroupResource = await sub.GetResourceGroupAsync(resourceGroup);
        var hostPool = await resourceGroupResource.Value.GetHostPoolAsync(hostPoolName);
        await foreach (SessionHostResource sessionHost in hostPool.Value.GetSessionHosts().GetAllAsync())
        {
            if (sessionHost.Data.Name == sessionHostName || sessionHost.Data.Name.EndsWith($"/{sessionHostName}"))
            {
                await foreach (UserSessionResource userSession in sessionHost.GetUserSessions().GetAllAsync())
                {
                    userSessions.Add(new UserSession(userSession));
                }
                break; // Found the session host, no need to continue
            }
        }
        return userSessions;
    }
}