using System.Text;
using System.Text.RegularExpressions;
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using SqlAugur.Configuration;
using static SqlAugur.Services.SchemaQueryHelper;
namespace SqlAugur.Services;
public sealed partial class DiagramService : IDiagramService
{
/// <summary>Resolved options, read for server lookup and the command timeout.</summary>
private readonly SqlAugurOptions _options;

/// <summary>Logger used to record each diagram request.</summary>
private readonly ILogger<DiagramService> _logger;

/// <summary>
/// Creates the service from its options and logger.
/// </summary>
/// <param name="options">Options wrapper; its <c>Value</c> is captured once here.</param>
/// <param name="logger">Logger for diagram generation messages.</param>
public DiagramService(
    IOptions<SqlAugurOptions> options,
    ILogger<DiagramService> logger)
    => (_options, _logger) = (options.Value, logger);
/// <summary>
/// One column of a table included in the diagram, as read by <c>QueryColumnsAsync</c>.
/// </summary>
/// <param name="Schema">Schema of the owning table (INFORMATION_SCHEMA TABLE_SCHEMA).</param>
/// <param name="TableName">Name of the owning table.</param>
/// <param name="ColumnName">Column name.</param>
/// <param name="DataType">Data type name (INFORMATION_SCHEMA DATA_TYPE).</param>
/// <param name="MaxLength">CHARACTER_MAXIMUM_LENGTH, or 0 when the server reports NULL.</param>
/// <param name="Precision">NUMERIC_PRECISION, or 0 when the server reports NULL.</param>
/// <param name="Scale">NUMERIC_SCALE, or 0 when the server reports NULL.</param>
/// <param name="IsNullable">True when the column allows NULL.</param>
/// <param name="IsPrimaryKey">True when the column is part of the table's primary-key index.</param>
/// <param name="IsIdentity">True when the column is an identity column.</param>
internal sealed record ColumnInfo(
string Schema, string TableName, string ColumnName, string DataType,
int MaxLength, byte Precision, byte Scale, bool IsNullable,
bool IsPrimaryKey, bool IsIdentity);
/// <summary>
/// One column pair of a foreign key, as read by <c>QueryForeignKeysAsync</c>.
/// A composite foreign key yields one instance per column, all sharing the same <paramref name="FkName"/>.
/// </summary>
/// <param name="FkName">Name of the foreign-key constraint.</param>
/// <param name="FkSchema">Schema of the referencing (child) table.</param>
/// <param name="FkTable">Name of the referencing (child) table.</param>
/// <param name="FkColumn">Referencing column.</param>
/// <param name="RefSchema">Schema of the referenced (parent) table.</param>
/// <param name="RefTable">Name of the referenced (parent) table.</param>
/// <param name="RefColumn">Referenced column.</param>
/// <param name="IsNullable">True when the referencing column allows NULL.</param>
/// <param name="IsUnique">True when the referencing column is the only column of some unique index.</param>
internal sealed record ForeignKeyInfo(
string FkName, string FkSchema, string FkTable, string FkColumn,
string RefSchema, string RefTable, string RefColumn,
bool IsNullable, bool IsUnique);
/// <summary>
/// Opens a connection to the named server, switches to the requested database, reads the
/// filtered table set with its columns and foreign keys, and renders it as PlantUML.
/// </summary>
/// <param name="serverName">Server key passed to the options' <c>ResolveServer</c>.</param>
/// <param name="databaseName">Database to switch the connection to.</param>
/// <param name="includeSchemas">Schemas to include; null or empty is logged as "all".</param>
/// <param name="excludeSchemas">Schemas to exclude; null or empty is logged as "none".</param>
/// <param name="includeTables">Tables to include; null or empty is logged as "all".</param>
/// <param name="excludeTables">Tables to exclude; null or empty is logged as "none".</param>
/// <param name="maxTables">Table limit forwarded to the table query and the renderer.</param>
/// <param name="cancellationToken">Token flowed to every database call.</param>
/// <param name="compact">Forwarded to <c>BuildPlantUml</c> to select the compact layout.</param>
/// <returns>The PlantUML text, or the empty-diagram text when no table matches.</returns>
public async Task<string> GenerateDiagramAsync(string serverName, string databaseName,
    IReadOnlyList<string>? includeSchemas, IReadOnlyList<string>? excludeSchemas,
    IReadOnlyList<string>? includeTables, IReadOnlyList<string>? excludeTables,
    int maxTables, CancellationToken cancellationToken, bool compact = false)
{
    // Comma-joined filter for the log line, or a placeholder when no filter was given.
    static string Describe(IReadOnlyList<string>? names, string whenEmpty)
        => names is { Count: > 0 } ? string.Join(",", names) : whenEmpty;

    var resolvedServer = _options.ResolveServer(serverName);

    _logger.LogInformation(
        "Generating diagram on server {Server} database {Database} schemas {Schemas} excluding {ExcludeSchemas} tables {Tables} excluding {ExcludeTables}",
        serverName,
        databaseName,
        Describe(includeSchemas, "all"),
        Describe(excludeSchemas, "none"),
        Describe(includeTables, "all"),
        Describe(excludeTables, "none"));

    await using var sqlConnection = new SqlConnection(resolvedServer.ConnectionString);
    await sqlConnection.OpenAsync(cancellationToken);
    await sqlConnection.ChangeDatabaseAsync(databaseName, cancellationToken);

    var matchedTables = await QueryTablesAsync(
        sqlConnection, includeSchemas, excludeSchemas, includeTables, excludeTables,
        maxTables, _options.CommandTimeoutSeconds, cancellationToken);

    if (matchedTables.Count > 0)
    {
        var tableColumns = await QueryColumnsAsync(sqlConnection, matchedTables, cancellationToken);
        var tableForeignKeys = await QueryForeignKeysAsync(sqlConnection, matchedTables, cancellationToken);
        return BuildPlantUml(serverName, databaseName, includeSchemas, maxTables,
            matchedTables, tableColumns, tableForeignKeys, compact);
    }

    return GenerateEmptyDiagram(serverName, databaseName, includeSchemas);
}
/// <summary>
/// Reads every column of the given tables, in table then ordinal order, including
/// primary-key and identity flags.
/// </summary>
/// <param name="connection">Open connection already switched to the target database.</param>
/// <param name="tables">Tables to restrict the query to (via the table_filter CTE).</param>
/// <param name="cancellationToken">Token flowed to the command and the reader.</param>
/// <returns>One <see cref="ColumnInfo"/> per column returned by the server.</returns>
private async Task<List<ColumnInfo>> QueryColumnsAsync(SqlConnection connection,
    List<TableInfo> tables, CancellationToken cancellationToken)
{
    var (filterCte, filterParameters) = BuildTableFilterCte(tables);
    var commandText = filterCte + """
        SELECT
        c.TABLE_SCHEMA,
        c.TABLE_NAME,
        c.COLUMN_NAME,
        c.DATA_TYPE,
        COALESCE(c.CHARACTER_MAXIMUM_LENGTH, 0) AS MaxLength,
        CAST(COALESCE(c.NUMERIC_PRECISION, 0) AS tinyint) AS [Precision],
        CAST(COALESCE(c.NUMERIC_SCALE, 0) AS tinyint) AS Scale,
        CASE WHEN c.IS_NULLABLE = 'YES' THEN 1 ELSE 0 END AS IsNullable,
        CASE WHEN ixc.column_id IS NOT NULL THEN 1 ELSE 0 END AS IsPrimaryKey,
        sc.is_identity AS IsIdentity
        FROM INFORMATION_SCHEMA.COLUMNS c
        INNER JOIN table_filter dt ON dt.SchemaName = c.TABLE_SCHEMA AND dt.TableName = c.TABLE_NAME
        INNER JOIN sys.columns sc
        ON sc.object_id = OBJECT_ID(QUOTENAME(c.TABLE_SCHEMA) + '.' + QUOTENAME(c.TABLE_NAME))
        AND sc.name = c.COLUMN_NAME
        LEFT JOIN (
        SELECT ic.object_id, ic.column_id
        FROM sys.index_columns ic
        INNER JOIN sys.indexes ix ON ix.object_id = ic.object_id AND ix.index_id = ic.index_id
        WHERE ix.is_primary_key = 1
        ) ixc ON ixc.object_id = OBJECT_ID(QUOTENAME(c.TABLE_SCHEMA) + '.' + QUOTENAME(c.TABLE_NAME))
        AND ixc.column_id = sc.column_id
        ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
        """;

    await using var command = new SqlCommand(commandText, connection);
    command.CommandTimeout = _options.CommandTimeoutSeconds;
    command.Parameters.AddRange(filterParameters);

    var result = new List<ColumnInfo>();
    await using var rows = await command.ExecuteReaderAsync(cancellationToken);
    while (await rows.ReadAsync(cancellationToken))
    {
        // Ordinals follow the SELECT list above; the CASE expressions come back as int.
        var schema = rows.GetString(0);
        var tableName = rows.GetString(1);
        var columnName = rows.GetString(2);
        var dataType = rows.GetString(3);
        var maxLength = rows.GetInt32(4);
        var precision = rows.GetByte(5);
        var scale = rows.GetByte(6);
        var isNullable = rows.GetInt32(7) == 1;
        var isPrimaryKey = rows.GetInt32(8) == 1;
        var isIdentity = rows.GetBoolean(9);

        result.Add(new ColumnInfo(
            schema, tableName, columnName, dataType,
            maxLength, precision, scale, isNullable,
            isPrimaryKey, isIdentity));
    }

    return result;
}
/// <summary>
/// Reads the foreign keys whose referencing and referenced tables are both in
/// <paramref name="tables"/>, one row per foreign-key column, ordered by constraint
/// name and column position.
/// </summary>
/// <remarks>
/// IsUnique is true when some unique index has the referencing column as its only key
/// column. INCLUDE columns are ignored on both sides of that test: they are listed in
/// sys.index_columns but take no part in uniqueness, so counting them would report a
/// unique index with included columns as non-unique.
/// </remarks>
/// <param name="connection">Open connection already switched to the target database.</param>
/// <param name="tables">Tables to restrict both ends of each foreign key to.</param>
/// <param name="cancellationToken">Token flowed to the command and the reader.</param>
/// <returns>One <see cref="ForeignKeyInfo"/> per foreign-key column returned by the server.</returns>
private async Task<List<ForeignKeyInfo>> QueryForeignKeysAsync(SqlConnection connection,
    List<TableInfo> tables, CancellationToken cancellationToken)
{
    var (cteSql, cteParams) = BuildTableFilterCte(tables);

    // fk_col already is the referencing column, so its is_nullable is read directly
    // instead of joining sys.columns a second time.
    var sql = cteSql + """
        SELECT
        fk.name AS FkName,
        SCHEMA_NAME(fk_tab.schema_id) AS FkSchema,
        fk_tab.name AS FkTable,
        fk_col.name AS FkColumn,
        SCHEMA_NAME(ref_tab.schema_id) AS RefSchema,
        ref_tab.name AS RefTable,
        ref_col.name AS RefColumn,
        fk_col.is_nullable AS IsNullable,
        CASE WHEN EXISTS (
        SELECT 1 FROM sys.index_columns uic
        INNER JOIN sys.indexes uix ON uix.object_id = uic.object_id AND uix.index_id = uic.index_id
        WHERE uix.is_unique = 1
        AND uic.is_included_column = 0
        AND uic.object_id = fkc.parent_object_id
        AND uic.column_id = fkc.parent_column_id
        AND NOT EXISTS (
        SELECT 1 FROM sys.index_columns uic2
        WHERE uic2.object_id = uic.object_id AND uic2.index_id = uic.index_id
        AND uic2.is_included_column = 0
        AND uic2.column_id <> uic.column_id
        )
        ) THEN 1 ELSE 0 END AS IsUnique
        FROM sys.foreign_keys fk
        INNER JOIN sys.foreign_key_columns fkc ON fkc.constraint_object_id = fk.object_id
        INNER JOIN sys.tables fk_tab ON fk_tab.object_id = fkc.parent_object_id
        INNER JOIN sys.columns fk_col ON fk_col.object_id = fkc.parent_object_id AND fk_col.column_id = fkc.parent_column_id
        INNER JOIN sys.tables ref_tab ON ref_tab.object_id = fkc.referenced_object_id
        INNER JOIN sys.columns ref_col ON ref_col.object_id = fkc.referenced_object_id AND ref_col.column_id = fkc.referenced_column_id
        INNER JOIN table_filter dt_fk
        ON dt_fk.SchemaName = SCHEMA_NAME(fk_tab.schema_id) AND dt_fk.TableName = fk_tab.name
        INNER JOIN table_filter dt_ref
        ON dt_ref.SchemaName = SCHEMA_NAME(ref_tab.schema_id) AND dt_ref.TableName = ref_tab.name
        ORDER BY fk.name, fkc.constraint_column_id
        """;

    await using var cmd = new SqlCommand(sql, connection)
    {
        CommandTimeout = _options.CommandTimeoutSeconds
    };
    cmd.Parameters.AddRange(cteParams);

    await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
    var foreignKeys = new List<ForeignKeyInfo>();
    while (await reader.ReadAsync(cancellationToken))
    {
        // Ordinals follow the SELECT list above; is_nullable is a bit, IsUnique an int CASE.
        foreignKeys.Add(new ForeignKeyInfo(
            FkName: reader.GetString(0),
            FkSchema: reader.GetString(1),
            FkTable: reader.GetString(2),
            FkColumn: reader.GetString(3),
            RefSchema: reader.GetString(4),
            RefTable: reader.GetString(5),
            RefColumn: reader.GetString(6),
            IsNullable: reader.GetBoolean(7),
            IsUnique: reader.GetInt32(8) == 1));
    }

    return foreignKeys;
}
/// <summary>
/// Renders tables, columns and foreign keys as a PlantUML entity-relationship diagram.
/// </summary>
/// <remarks>
/// <para>
/// Tables in the <c>dbo</c> schema are displayed by bare name, all others as
/// <c>schema.table</c>. Primary-key columns are listed first and marked <c>&lt;&lt;PK&gt;&gt;</c>;
/// a primary-key column that is also a foreign key carries <c>&lt;&lt;FK&gt;&gt;</c> in both
/// layouts. In compact mode only primary-key and foreign-key columns are listed, without types.
/// </para>
/// <para>
/// One relationship line is emitted per foreign key. Rows are deduplicated on
/// (schema, constraint name) because constraint names are only unique within a schema;
/// the first row of a composite key decides the cardinality.
/// </para>
/// </remarks>
/// <param name="serverName">Server name shown in the header comment.</param>
/// <param name="databaseName">Database name shown in the header comment.</param>
/// <param name="includeSchemas">Schema filter shown in the header comment; null or empty prints "all".</param>
/// <param name="maxTables">Table limit; a truncation warning is appended when <paramref name="tables"/> reaches it.</param>
/// <param name="tables">Tables to render, in output order.</param>
/// <param name="columns">Columns of those tables; tables without any column render as empty entities.</param>
/// <param name="foreignKeys">Foreign-key rows, one per key column.</param>
/// <param name="compact">True to list only key columns without data types.</param>
/// <returns>The complete diagram text from <c>@startuml</c> to <c>@enduml</c>.</returns>
internal static string BuildPlantUml(string serverName, string databaseName,
    IReadOnlyList<string>? includeSchemas, int maxTables, List<TableInfo> tables,
    List<ColumnInfo> columns, List<ForeignKeyInfo> foreignKeys, bool compact = false)
{
    var sb = new StringBuilder();
    sb.AppendLine("@startuml");
    sb.AppendLine($"' ER Diagram: {SanitizePlantUmlText(databaseName)} on {SanitizePlantUmlText(serverName)}");
    sb.Append("' Schema: ");
    sb.Append(SanitizePlantUmlText(includeSchemas is { Count: > 0 } ? string.Join(",", includeSchemas) : "all"));
    sb.AppendLine($" | Tables: {tables.Count}");
    sb.AppendLine();
    sb.AppendLine("skinparam linetype ortho");
    sb.AppendLine();

    // Lookup of referencing columns, used to mark columns as <<FK>>.
    var fkColumns = new HashSet<(string Schema, string Table, string Column)>();
    foreach (var fk in foreignKeys)
    {
        fkColumns.Add((fk.FkSchema, fk.FkTable, fk.FkColumn));
    }

    bool IsForeignKey(ColumnInfo c) => fkColumns.Contains((c.Schema, c.TableName, c.ColumnName));

    // Stereotypes for a primary-key column; identical in compact and full layouts.
    string PrimaryKeyStereotypes(ColumnInfo c)
    {
        var stereotypes = "<<PK>>";
        if (IsForeignKey(c))
        {
            stereotypes += " <<FK>>";
        }

        if (c.IsIdentity)
        {
            stereotypes += " <<IDENTITY>>";
        }

        return stereotypes;
    }

    var columnsByTable = columns
        .GroupBy(c => new TableInfo(c.Schema, c.TableName))
        .ToDictionary(g => g.Key, g => g.ToList());

    foreach (var table in tables)
    {
        var displayName = table.Schema == "dbo"
            ? SanitizePlantUmlText(table.Name)
            : $"{SanitizePlantUmlText(table.Schema)}.{SanitizePlantUmlText(table.Name)}";
        var alias = SanitizeAlias($"{table.Schema}_{table.Name}");
        sb.AppendLine($"entity \"{displayName}\" as {alias} {{");

        if (!columnsByTable.TryGetValue(table, out var tableCols))
        {
            // No column rows for this table: emit an empty entity body.
            sb.AppendLine("}");
            sb.AppendLine();
            continue;
        }

        var pkCols = tableCols.Where(c => c.IsPrimaryKey).ToList();
        var nonPkCols = tableCols.Where(c => !c.IsPrimaryKey).ToList();

        if (compact)
        {
            // PK columns: name and stereotypes only, no data type.
            foreach (var col in pkCols)
            {
                sb.AppendLine($" * {SanitizePlantUmlText(col.ColumnName)} {PrimaryKeyStereotypes(col)}");
            }

            // Of the remaining columns only foreign keys are listed.
            var fkNonPkCols = nonPkCols.Where(IsForeignKey).ToList();
            if (pkCols.Count > 0 && fkNonPkCols.Count > 0)
            {
                sb.AppendLine(" --");
            }

            foreach (var col in fkNonPkCols)
            {
                var prefix = col.IsNullable ? " " : " *";
                sb.AppendLine($"{prefix} {SanitizePlantUmlText(col.ColumnName)} <<FK>>");
            }
        }
        else
        {
            // PK columns above the separator.
            foreach (var col in pkCols)
            {
                var dataType = FormatDataType(col.DataType, col.MaxLength, col.Precision, col.Scale);
                sb.AppendLine($" * {SanitizePlantUmlText(col.ColumnName)} : {dataType} {PrimaryKeyStereotypes(col)}");
            }

            if (pkCols.Count > 0 || nonPkCols.Count > 0)
            {
                sb.AppendLine(" --");
            }

            // Remaining columns; the asterisk marks NOT NULL.
            foreach (var col in nonPkCols)
            {
                var prefix = col.IsNullable ? " " : " *";
                var stereotype = IsForeignKey(col) ? " <<FK>>" : "";
                var dataType = FormatDataType(col.DataType, col.MaxLength, col.Precision, col.Scale);
                sb.AppendLine($"{prefix} {SanitizePlantUmlText(col.ColumnName)} : {dataType}{stereotype}");
            }
        }

        sb.AppendLine("}");
        sb.AppendLine();
    }

    // Relationships: one line per constraint. The key includes the schema so that
    // same-named constraints in different schemas are not collapsed into one.
    var emittedFks = new HashSet<(string Schema, string Name)>();
    foreach (var fk in foreignKeys)
    {
        if (!emittedFks.Add((fk.FkSchema, fk.FkName)))
        {
            continue; // further column of a composite FK that was already emitted
        }

        var refAlias = SanitizeAlias($"{fk.RefSchema}_{fk.RefTable}");
        var fkAlias = SanitizeAlias($"{fk.FkSchema}_{fk.FkTable}");

        const string refSide = "||"; // referenced (parent) side is always exactly one
        string fkSide;
        if (fk.IsUnique)
        {
            fkSide = fk.IsNullable ? "o|" : "||"; // one-to-one
        }
        else
        {
            fkSide = fk.IsNullable ? "o{" : "|{"; // one-to-many
        }

        sb.AppendLine($"{refAlias} {refSide}--{fkSide} {fkAlias} : \"{SanitizePlantUmlText(fk.FkName)}\"");
    }

    if (tables.Count >= maxTables)
    {
        sb.AppendLine();
        sb.AppendLine($"' WARNING: Output truncated at {maxTables} tables. Increase maxTables to see more.");
    }

    sb.AppendLine();
    sb.AppendLine("@enduml");
    return sb.ToString();
}
/// <summary>
/// Opens a connection to the named server, switches to the requested database, reads the
/// filtered table set with its columns and foreign keys, and renders it as a Mermaid erDiagram.
/// </summary>
/// <param name="serverName">Server key passed to the options' <c>ResolveServer</c>.</param>
/// <param name="databaseName">Database to switch the connection to.</param>
/// <param name="includeSchemas">Schemas to include; null or empty is logged as "all".</param>
/// <param name="excludeSchemas">Schemas to exclude; null or empty is logged as "none".</param>
/// <param name="includeTables">Tables to include; null or empty is logged as "all".</param>
/// <param name="excludeTables">Tables to exclude; null or empty is logged as "none".</param>
/// <param name="maxTables">Table limit forwarded to the table query and the renderer.</param>
/// <param name="cancellationToken">Token flowed to every database call.</param>
/// <param name="compact">Forwarded to <c>BuildMermaid</c> to select the compact layout.</param>
/// <returns>The Mermaid text, or the empty-diagram text when no table matches.</returns>
public async Task<string> GenerateMermaidDiagramAsync(string serverName, string databaseName,
    IReadOnlyList<string>? includeSchemas, IReadOnlyList<string>? excludeSchemas,
    IReadOnlyList<string>? includeTables, IReadOnlyList<string>? excludeTables,
    int maxTables, CancellationToken cancellationToken, bool compact = false)
{
    var target = _options.ResolveServer(serverName);

    // Filter lists are flattened for the log line; a missing filter gets a placeholder word.
    string schemasText = includeSchemas is { Count: > 0 } ? string.Join(",", includeSchemas) : "all";
    string excludedSchemasText = excludeSchemas is { Count: > 0 } ? string.Join(",", excludeSchemas) : "none";
    string tablesText = includeTables is { Count: > 0 } ? string.Join(",", includeTables) : "all";
    string excludedTablesText = excludeTables is { Count: > 0 } ? string.Join(",", excludeTables) : "none";

    _logger.LogInformation(
        "Generating Mermaid diagram on server {Server} database {Database} schemas {Schemas} excluding {ExcludeSchemas} tables {Tables} excluding {ExcludeTables}",
        serverName, databaseName, schemasText, excludedSchemasText, tablesText, excludedTablesText);

    await using var db = new SqlConnection(target.ConnectionString);
    await db.OpenAsync(cancellationToken);
    await db.ChangeDatabaseAsync(databaseName, cancellationToken);

    var selectedTables = await QueryTablesAsync(
        db, includeSchemas, excludeSchemas, includeTables, excludeTables,
        maxTables, _options.CommandTimeoutSeconds, cancellationToken);

    if (selectedTables.Count > 0)
    {
        var selectedColumns = await QueryColumnsAsync(db, selectedTables, cancellationToken);
        var selectedForeignKeys = await QueryForeignKeysAsync(db, selectedTables, cancellationToken);
        return BuildMermaid(serverName, databaseName, includeSchemas, maxTables,
            selectedTables, selectedColumns, selectedForeignKeys, compact);
    }

    return GenerateEmptyMermaidDiagram(serverName, databaseName, includeSchemas);
}
/// <summary>
/// Renders tables, columns and foreign keys as a Mermaid <c>erDiagram</c>.
/// </summary>
/// <remarks>
/// <para>
/// Entity names are sanitized to letters, digits and underscores; <c>dbo</c> tables use the
/// bare table name, all others <c>schema__table</c>. In compact mode only primary-key and
/// foreign-key columns are listed, with <c>_</c> standing in for the type token.
/// </para>
/// <para>
/// One relationship line is emitted per foreign key. Rows are deduplicated on
/// (schema, constraint name) because constraint names are only unique within a schema;
/// the first row of a composite key decides the cardinality.
/// </para>
/// </remarks>
/// <param name="serverName">Server name shown in the diagram title.</param>
/// <param name="databaseName">Database name shown in the diagram title.</param>
/// <param name="includeSchemas">Not used by this renderer; kept for signature parity with <c>BuildPlantUml</c>.</param>
/// <param name="maxTables">Table limit; a truncation warning is appended when <paramref name="tables"/> reaches it.</param>
/// <param name="tables">Tables to render, in output order.</param>
/// <param name="columns">Columns of those tables; tables without any column render as empty entities.</param>
/// <param name="foreignKeys">Foreign-key rows, one per key column.</param>
/// <param name="compact">True to list only key columns without data types.</param>
/// <returns>The complete Mermaid diagram text, including the title front matter.</returns>
internal static string BuildMermaid(string serverName, string databaseName,
    IReadOnlyList<string>? includeSchemas, int maxTables, List<TableInfo> tables,
    List<ColumnInfo> columns, List<ForeignKeyInfo> foreignKeys, bool compact = false)
{
    // Single definition of the entity identifier so entity blocks and relationship
    // lines can never disagree on how a table is named.
    static string EntityName(string schema, string table) => schema == "dbo"
        ? SanitizeMermaidEntity(table)
        : $"{SanitizeMermaidEntity(schema)}__{SanitizeMermaidEntity(table)}";

    var sb = new StringBuilder();
    sb.AppendLine("---");
    sb.AppendLine($"title: \"ER Diagram: {SanitizeMermaidText(databaseName)} on {SanitizeMermaidText(serverName)}\"");
    sb.AppendLine("---");
    sb.AppendLine("erDiagram");

    // Lookup of referencing columns, used to mark columns as FK.
    var fkColumns = new HashSet<(string Schema, string Table, string Column)>();
    foreach (var fk in foreignKeys)
    {
        fkColumns.Add((fk.FkSchema, fk.FkTable, fk.FkColumn));
    }

    bool IsForeignKey(ColumnInfo c) => fkColumns.Contains((c.Schema, c.TableName, c.ColumnName));

    var columnsByTable = columns
        .GroupBy(c => new TableInfo(c.Schema, c.TableName))
        .ToDictionary(g => g.Key, g => g.ToList());

    foreach (var table in tables)
    {
        sb.AppendLine($" {EntityName(table.Schema, table.Name)} {{");

        // A table without column rows still gets an (empty) entity block.
        if (columnsByTable.TryGetValue(table, out var tableCols))
        {
            if (compact)
            {
                // PK columns first; "_" is a placeholder because Mermaid requires a type token.
                foreach (var col in tableCols.Where(c => c.IsPrimaryKey))
                {
                    var marker = IsForeignKey(col) ? "PK,FK" : "PK";
                    sb.AppendLine($" _ {SanitizeMermaidEntity(col.ColumnName)} {marker}");
                }

                // Then the foreign-key columns that are not part of the primary key.
                foreach (var col in tableCols.Where(c => !c.IsPrimaryKey && IsForeignKey(c)))
                {
                    sb.AppendLine($" _ {SanitizeMermaidEntity(col.ColumnName)} FK");
                }
            }
            else
            {
                foreach (var col in tableCols)
                {
                    var dataType = FormatDataType(col.DataType, col.MaxLength, col.Precision, col.Scale);
                    // Mermaid doesn't support parentheses in types: fold them into underscores.
                    dataType = dataType.Replace("(", "_").Replace(")", "").Replace(",", "_");

                    var marker = (col.IsPrimaryKey, IsForeignKey(col)) switch
                    {
                        (true, true) => " PK,FK",
                        (true, false) => " PK",
                        (false, true) => " FK",
                        _ => "",
                    };

                    sb.AppendLine($" {SanitizeMermaidEntity(dataType)} {SanitizeMermaidEntity(col.ColumnName)}{marker}");
                }
            }
        }

        sb.AppendLine(" }");
    }

    // Relationships: one line per constraint. The key includes the schema so that
    // same-named constraints in different schemas are not collapsed into one.
    var emittedFks = new HashSet<(string Schema, string Name)>();
    foreach (var fk in foreignKeys)
    {
        if (!emittedFks.Add((fk.FkSchema, fk.FkName)))
        {
            continue; // further column of a composite FK that was already emitted
        }

        var refEntity = EntityName(fk.RefSchema, fk.RefTable);
        var fkEntity = EntityName(fk.FkSchema, fk.FkTable);

        const string refSide = "||"; // referenced (parent) side is always exactly one
        string fkSide;
        if (fk.IsUnique)
        {
            fkSide = fk.IsNullable ? "o|" : "||"; // one-to-one
        }
        else
        {
            fkSide = fk.IsNullable ? "o{" : "|{"; // one-to-many
        }

        sb.AppendLine($" {refEntity} {refSide}--{fkSide} {fkEntity} : \"{SanitizeMermaidText(fk.FkName)}\"");
    }

    if (tables.Count >= maxTables)
    {
        sb.AppendLine();
        sb.AppendLine($" %% WARNING: Output truncated at {maxTables} tables. Increase maxTables to see more.");
    }

    return sb.ToString();
}
/// <summary>
/// Builds the Mermaid diagram returned when no table matches: the title front matter,
/// the <c>erDiagram</c> keyword and two comment lines stating the schema filter and
/// that nothing was found.
/// </summary>
/// <param name="serverName">Server name shown in the title.</param>
/// <param name="databaseName">Database name shown in the title.</param>
/// <param name="includeSchemas">Schema filter; null or empty prints "all".</param>
/// <returns>The Mermaid text for an empty diagram.</returns>
internal static string GenerateEmptyMermaidDiagram(string serverName, string databaseName, IReadOnlyList<string>? includeSchemas)
{
    string schemaLabel;
    if (includeSchemas is { Count: > 0 })
    {
        schemaLabel = string.Join(",", includeSchemas);
    }
    else
    {
        schemaLabel = "all";
    }

    var titleLine = $"title: \"ER Diagram: {SanitizeMermaidText(databaseName)} on {SanitizeMermaidText(serverName)}\"";
    var schemaLine = $" %% Schema: {SanitizeMermaidText(schemaLabel)} | Tables: 0";

    return new StringBuilder()
        .AppendLine("---")
        .AppendLine(titleLine)
        .AppendLine("---")
        .AppendLine("erDiagram")
        .AppendLine(schemaLine)
        .AppendLine(" %% No tables found")
        .ToString();
}
/// <summary>
/// Builds the PlantUML diagram returned when no table matches: the header comments
/// followed by a single "No tables found" note.
/// </summary>
/// <param name="serverName">Server name shown in the header comment.</param>
/// <param name="databaseName">Database name shown in the header comment.</param>
/// <param name="includeSchemas">Schema filter; null or empty prints "all".</param>
/// <returns>The PlantUML text for an empty diagram.</returns>
internal static string GenerateEmptyDiagram(string serverName, string databaseName, IReadOnlyList<string>? includeSchemas)
{
    string schemaLabel;
    if (includeSchemas is { Count: > 0 })
    {
        schemaLabel = string.Join(",", includeSchemas);
    }
    else
    {
        schemaLabel = "all";
    }

    var headerLine = $"' ER Diagram: {SanitizePlantUmlText(databaseName)} on {SanitizePlantUmlText(serverName)}";
    var schemaLine = $"' Schema: {SanitizePlantUmlText(schemaLabel)} | Tables: 0";

    return new StringBuilder()
        .AppendLine("@startuml")
        .AppendLine(headerLine)
        .AppendLine(schemaLine)
        .AppendLine()
        .AppendLine("note \"No tables found\" as N1")
        .AppendLine()
        .AppendLine("@enduml")
        .ToString();
}
/// <summary>
/// Produces a PlantUML alias by replacing every character matched by <c>AliasRegex</c>
/// with an underscore.
/// </summary>
/// <param name="input">Raw alias text, e.g. <c>schema_table</c>.</param>
/// <returns>The alias with each matched character replaced by <c>_</c>.</returns>
internal static string SanitizeAlias(string input)
{
    var pattern = AliasRegex();
    return pattern.Replace(input, "_");
}
/// <summary>
/// Removes the characters matched by <c>PlantUmlUnsafeChars</c> so the text cannot break
/// out of a PlantUML comment or quoted label, or inject a directive.
/// </summary>
/// <param name="input">Text destined for PlantUML output.</param>
/// <returns>The input with every matched character deleted.</returns>
internal static string SanitizePlantUmlText(string input)
{
    var unsafeChars = PlantUmlUnsafeChars();
    return unsafeChars.Replace(input, "");
}
/// <summary>
/// Removes the characters matched by <c>MermaidUnsafeChars</c> so the text cannot break
/// Mermaid syntax.
/// </summary>
/// <param name="input">Text destined for a Mermaid title, label or comment.</param>
/// <returns>The input with every matched character deleted.</returns>
internal static string SanitizeMermaidText(string input)
{
    var unsafeChars = MermaidUnsafeChars();
    return unsafeChars.Replace(input, "");
}
/// <summary>
/// Turns arbitrary text into a Mermaid identifier by replacing every character matched by
/// <c>MermaidEntityUnsafeChars</c> (spaces and special characters) with an underscore.
/// </summary>
/// <param name="input">Raw entity, attribute or type name.</param>
/// <returns>The input with each matched character replaced by <c>_</c>.</returns>
internal static string SanitizeMermaidEntity(string input)
{
    var unsafeChars = MermaidEntityUnsafeChars();
    return unsafeChars.Replace(input, "_");
}
/// <summary>
/// Matches a dot, any whitespace character or a hyphen; used by <c>SanitizeAlias</c>.
/// </summary>
// NOTE(review): other punctuation (parentheses, quotes, braces) passes through into the
// unquoted alias; confirm whether PlantUML accepts such aliases.
[GeneratedRegex(@"[.\s\-]")]
private static partial Regex AliasRegex();
/// <summary>
/// Matches carriage return, line feed, <c>@</c>, double quote, braces and <c>!</c>;
/// used by <c>SanitizePlantUmlText</c>.
/// </summary>
[GeneratedRegex(@"[\r\n@""{}!]")]
private static partial Regex PlantUmlUnsafeChars();
/// <summary>
/// Matches carriage return, line feed, double quote, braces, <c>%</c> and <c>;</c>;
/// used by <c>SanitizeMermaidText</c>.
/// </summary>
[GeneratedRegex(@"[\r\n""{}%;]")]
private static partial Regex MermaidUnsafeChars();
/// <summary>
/// Matches any character other than an ASCII letter, an ASCII digit or an underscore;
/// used by <c>SanitizeMermaidEntity</c>.
/// </summary>
[GeneratedRegex(@"[^a-zA-Z0-9_]")]
private static partial Regex MermaidEntityUnsafeChars();
}