using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text;
using Microsoft.SqlServer.Dac;
using Microsoft.SqlServer.Dac.CodeAnalysis;
using Microsoft.SqlServer.Dac.Extensibility;
using Microsoft.SqlServer.Dac.Model;
using Microsoft.SqlServer.TransactSql.ScriptDom;
using Microsoft.VisualStudio.TestTools.UnitTesting;
namespace SqlServer.Rules.Tests.Utils;
/// <summary>
/// Runs a test against the <see cref="CodeAnalysisService"/>: it initializes a model,
/// runs analysis and then performs a caller-supplied verification action. This class could be
/// extended to write a results file and compare it to a baseline.
/// </summary>
internal class RuleTest : IDisposable
{
    // Owns every disposable created by the test (models, test databases). Null once disposed.
    private DisposableList trash;

    /// <summary>
    /// What type of target the test runs against. Dacpacs are not backed by scripts, so the
    /// model generated from them differs from a scripted model. In particular, the
    /// <see cref="TSqlFragment"/>s returned by <see cref="SqlRuleExecutionContext.ScriptFragment"/>
    /// are generated from the model instead of representing the script contents. They may also
    /// be null if the <see cref="TSqlObject"/> is not a top-level type.
    /// </summary>
    public enum AnalysisTarget
    {
        /// <summary>Model built in memory directly from <see cref="TestScripts"/> (the default).</summary>
        PublicModel,

        /// <summary>Scripted model saved to <see cref="DacpacPath"/> and reloaded as a script-backed model.</summary>
        DacpacModel,

        /// <summary>Scripts deployed to the <see cref="DatabaseName"/> database and the model extracted from it.</summary>
        Database,
    }

    /// <summary>
    /// Creates a rule test.
    /// </summary>
    /// <param name="testScripts">Tuples of (script text, logical source name).</param>
    /// <param name="databaseOptions">Model options; a default <see cref="TSqlModelOptions"/> is used when null.</param>
    /// <param name="sqlVersion">SQL Server version the model targets.</param>
    public RuleTest(IList<Tuple<string, string>> testScripts, TSqlModelOptions databaseOptions, SqlServerVersion sqlVersion)
    {
        trash = new DisposableList();
        TestScripts = testScripts;
        DatabaseOptions = databaseOptions ?? new TSqlModelOptions();
        SqlVersion = sqlVersion;
    }

    /// <summary>
    /// Disposes every model and database created by this test. Safe to call more than once.
    /// </summary>
    public void Dispose()
    {
        trash?.Dispose();
        trash = null;
        GC.SuppressFinalize(this);
    }

    /// <summary>
    /// List of tuples representing scripts and the logical source name for those scripts.
    /// </summary>
    public IList<Tuple<string, string>> TestScripts { get; set; }

    /// <summary>
    /// Update the DatabaseOptions if you wish to test with different properties, such as a different collation.
    /// </summary>
    public TSqlModelOptions DatabaseOptions { get; set; }

    /// <summary>
    /// Version to target the model at. The model is compiled against that server version, and rules that do not
    /// support that version are ignored.
    /// </summary>
    public SqlServerVersion SqlVersion { get; set; }

    /// <summary>
    /// Which kind of model the analysis runs against. Defaults to <see cref="AnalysisTarget.PublicModel"/>.
    /// </summary>
    public AnalysisTarget Target { get; set; }

    /// <summary>
    /// The model that analysis runs against; populated by <see cref="CreateModelUsingTestScripts"/>.
    /// </summary>
    public TSqlModel ModelForAnalysis { get; set; }

    /// <summary>
    /// Output path of the dacpac. Required when <see cref="Target"/> is <see cref="AnalysisTarget.DacpacModel"/>.
    /// </summary>
    public string DacpacPath { get; set; }

    /// <summary>
    /// Name of the test database. Required when <see cref="Target"/> is <see cref="AnalysisTarget.Database"/>.
    /// </summary>
    public string DatabaseName { get; set; }

    /// <summary>
    /// Builds <see cref="ModelForAnalysis"/> from <see cref="TestScripts"/> according to <see cref="Target"/>.
    /// The resulting model is disposed when this test is disposed.
    /// </summary>
    /// <exception cref="ObjectDisposedException">The test has already been disposed.</exception>
    protected void CreateModelUsingTestScripts()
    {
        ThrowIfDisposed();
        ModelForAnalysis = Target switch
        {
            AnalysisTarget.Database => CreateDatabaseModel(),
            AnalysisTarget.DacpacModel => CreateDacpacModelFromScripts(),
            _ => CreateScriptedModel(),
        };
        trash.Add(ModelForAnalysis);
    }

    private void ThrowIfDisposed()
    {
        if (trash is null)
        {
            throw new ObjectDisposedException(GetType().Name);
        }
    }

    /// <summary>
    /// Builds a scripted model, writes it to a dacpac and reloads it. The intermediate scripted
    /// model is always disposed, even if building or loading the dacpac fails.
    /// </summary>
    private TSqlModel CreateDacpacModelFromScripts()
    {
        using var scriptedModel = CreateScriptedModel();
        return CreateDacpacModel(scriptedModel);
    }

    /// <summary>
    /// Creates an in-memory model from <see cref="TestScripts"/> and asserts that it has no errors.
    /// If population or validation fails, the model is disposed before the failure propagates.
    /// </summary>
    private TSqlModel CreateScriptedModel()
    {
        // Note: to get a model loaded with LoadAsScriptBackedModel, round-trip through a dacpac
        // (DacPackageExtensions.BuildPackage + TSqlModel.LoadFromDacpac), as CreateDacpacModel does.
        var model = new TSqlModel(SqlVersion, DatabaseOptions);
        try
        {
            AddScriptsToModel(model);
            AssertModelValid(model);
        }
        catch
        {
            model.Dispose();
            throw;
        }

        return model;
    }

    /// <summary>
    /// Writes every validation message to the console and fails the test if any of them is an error.
    /// </summary>
    private static void AssertModelValid(TSqlModel model)
    {
        var breakingIssuesFound = false;
        var validationMessages = model.Validate();
        if (validationMessages.Count > 0)
        {
            Console.WriteLine("Issues found during model build:");
            foreach (var message in validationMessages)
            {
                Console.WriteLine("\t" + message.Message);
                breakingIssuesFound = breakingIssuesFound || message.MessageType == DacMessageType.Error;
            }
        }

        Assert.IsFalse(breakingIssuesFound, "Cannot run analysis if there are model errors");
    }

    /// <summary>
    /// Deploys test scripts to a database and creates a model directly against this DB.
    /// Since this is a RuleTest, the model is loaded as script backed. This ensures we have file names and
    /// source code positions, and that programmability objects (stored procedures, views) have a full SQLDOM
    /// syntax tree instead of just a snippet.
    /// </summary>
    private TSqlModel CreateDatabaseModel()
    {
        ArgumentValidation.CheckForEmptyString(DatabaseName, nameof(DatabaseName));
        var db = TestUtils.CreateTestDatabase(TestUtils.DefaultInstanceInfo, DatabaseName);
        trash.Add(db);

        var batches = TestScripts.Select(t => t.Item1).SelectMany(TestUtils.GetBatches).ToList();
        TestUtils.ExecuteNonQuery(db, batches);

        var model = TSqlModel.LoadFromDatabase(db.BuildConnectionString(), new ModelExtractOptions { LoadAsScriptBackedModel = true });
        try
        {
            AssertModelValid(model);
        }
        catch
        {
            // The model is not yet tracked by trash, so dispose it here instead of leaking it.
            model.Dispose();
            throw;
        }

        return model;
    }

    /// <summary>
    /// Builds a dacpac at <see cref="DacpacPath"/> and returns that path.
    /// An existing file at that path is deleted first, and a missing parent directory is created.
    /// </summary>
    private string BuildDacpacFromModel(TSqlModel model)
    {
        var path = DacpacPath;
        Assert.IsFalse(string.IsNullOrWhiteSpace(path), "DacpacPath must be set if target for analysis is a Dacpac");
        if (File.Exists(path))
        {
            File.Delete(path);
        }

        // GetDirectoryName returns "" (not null) for a bare file name, and CreateDirectory("") throws,
        // so only create a directory when there actually is one. CreateDirectory is a no-op if it exists.
        var dacpacDir = Path.GetDirectoryName(path);
        if (!string.IsNullOrEmpty(dacpacDir))
        {
            Directory.CreateDirectory(dacpacDir);
        }

        DacPackageExtensions.BuildPackage(path, model, new PackageMetadata());
        return path;
    }

    /// <summary>
    /// Creates a new dacpac file on disk from <paramref name="model"/> and returns the model loaded from it.
    /// If the file already exists it is deleted. The caller owns the returned model;
    /// <see cref="CreateModelUsingTestScripts"/> registers it for disposal with this test.
    /// </summary>
    private TSqlModel CreateDacpacModel(TSqlModel model)
    {
        var dacpacPath = BuildDacpacFromModel(model);

        // Note: when running Code Analysis most rules expect a scripted model. Use the
        // static factory method on TSqlModel class to ensure you have scripts. If you
        // didn't do this, some rules would still work as expected, some would not, and
        // a warning message would be included in the AnalysisErrors in the result.
        return TSqlModel.LoadFromDacpac(
            dacpacPath,
            new ModelLoadOptions(DacSchemaModelStorageType.Memory, loadAsScriptBackedModel: true));
    }

    /// <summary>
    /// Adds each test script to <paramref name="model"/>, using its logical source name.
    /// </summary>
    protected void AddScriptsToModel(TSqlModel model)
    {
        foreach (var tuple in TestScripts)
        {
            // Item1 = script, Item2 = (logical) source file name
            model.AddOrUpdateObjects(tuple.Item1, tuple.Item2, new TSqlObjectOptions());
        }
    }

    /// <summary>
    /// Builds the model from <see cref="TestScripts"/>, runs a single rule against it and invokes
    /// <paramref name="verify"/> with the analysis result and a sorted, human-readable dump of the problems.
    /// </summary>
    /// <param name="fullId">ID of the single rule to be run. All other rules will be disabled.</param>
    /// <param name="verify">Action that runs verification on the result of analysis.</param>
    /// <exception cref="ArgumentNullException"><paramref name="fullId"/> or <paramref name="verify"/> is null.</exception>
    /// <exception cref="ObjectDisposedException">The test has already been disposed.</exception>
    public virtual void RunTest(string fullId, Action<CodeAnalysisResult, string> verify)
    {
        ArgumentNullException.ThrowIfNull(fullId);
        ArgumentNullException.ThrowIfNull(verify);
        CreateModelUsingTestScripts();
        var service = CreateCodeAnalysisService(fullId);
        RunRulesAndVerifyResult(service, verify);
    }

    /// <summary>
    /// Sets up the service and disables all rules except the rule you wish to test.
    ///
    /// If you want all rules to run, do not change the
    /// <see cref="CodeAnalysisRuleSettings.DisableRulesNotInSettings"/> flag. It is "false" by default,
    /// which ensures that all rules are run.
    ///
    /// To run some (but not all) of the built-in rules, query the
    /// <see cref="CodeAnalysisService.GetRules"/> method to get a list of all the rules. Then set their
    /// <see cref="RuleConfiguration.Enabled"/> and other flags as needed. Alternatively, call the
    /// <see cref="CodeAnalysisService.ApplyRuleSettings"/> method to apply whatever rule settings you wish.
    /// </summary>
    private CodeAnalysisService CreateCodeAnalysisService(string ruleIdToRun)
    {
        var factory = new CodeAnalysisServiceFactory();
        var ruleSettings = new CodeAnalysisRuleSettings
        {
            new RuleConfiguration(ruleIdToRun),
        };
        ruleSettings.DisableRulesNotInSettings = true;

        var service = factory.CreateAnalysisService(ModelForAnalysis.Version, new CodeAnalysisServiceSettings
        {
            RuleSettings = ruleSettings,
        });

        DumpErrors(service.GetRuleLoadErrors());
        Assert.IsTrue(
            service.GetRules().Any(rule => rule.RuleId.Equals(ruleIdToRun, StringComparison.OrdinalIgnoreCase)),
            "Expected rule '{0}' not found by the service",
            ruleIdToRun);
        return service;
    }

    private void RunRulesAndVerifyResult(CodeAnalysisService service, Action<CodeAnalysisResult, string> verify)
    {
        var analysisResult = service.Analyze(ModelForAnalysis);

        // Only analysis errors are considered for now. Initialization and suppression errors could be added later.
        DumpErrors(analysisResult.AnalysisErrors);
        var problemsString = DumpProblemsToString(analysisResult.Problems);
        verify(analysisResult, problemsString);
    }

    /// <summary>
    /// Fails the test with a message listing every error (prefixed by document/line/column when known).
    /// Does nothing when <paramref name="errors"/> is empty.
    /// </summary>
    private static void DumpErrors(IList<ExtensibilityError> errors)
    {
        if (errors.Count == 0)
        {
            return;
        }

        var errorMessage = new StringBuilder();
        errorMessage.AppendLine("Errors found:");
        foreach (var error in errors)
        {
            if (error.Document != null)
            {
                errorMessage.AppendFormat(CultureInfo.InvariantCulture, "{0}({1}, {2}): ", error.Document, error.Line, error.Column);
            }

            errorMessage.AppendLine(error.Message);
        }

        Assert.Fail(errorMessage.ToString());
    }

    /// <summary>
    /// Renders the problems, sorted by source name, line and column, into a stable multi-line string
    /// suitable for baseline comparison.
    /// </summary>
    private string DumpProblemsToString(IEnumerable<SqlRuleProblem> problems)
    {
        var displayServices = ModelForAnalysis.DisplayServices;
        List<SqlRuleProblem> problemList = [.. problems];
        SortProblemsByFileName(problemList);

        var sb = new StringBuilder();
        foreach (var problem in problemList)
        {
            AppendOneProblemItem(sb, "Problem description", problem.Description);
            AppendOneProblemItem(sb, "FullID", problem.RuleId);
            AppendOneProblemItem(sb, "Severity", problem.Severity.ToString());
            AppendOneProblemItem(sb, "Model element", displayServices.GetElementName(problem.ModelElement, ElementNameStyle.FullyQualifiedName));

            // Path.GetFileName matches FileInfo.Name but does not throw on an empty source name.
            var fileName = problem.SourceName != null ? Path.GetFileName(problem.SourceName) : string.Empty;
            AppendOneProblemItem(sb, "Script file", fileName);
            AppendOneProblemItem(sb, "Start line", problem.StartLine.ToString(CultureInfo.InvariantCulture));
            AppendOneProblemItem(sb, "Start column", problem.StartColumn.ToString(CultureInfo.InvariantCulture));
            sb.Append("========end of problem========\r\n\r\n");
        }

        return sb.ToString();
    }

    private static void AppendOneProblemItem(StringBuilder sb, string name, string content)
    {
        sb.Append(name).Append(": ").AppendLine(content);
    }

    /// <summary>
    /// Sorts <paramref name="problemList"/> in place by source name (case-insensitive), then start line,
    /// then start column. The sort is stable: problems with identical positions keep their original
    /// relative order, so the output is deterministic.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="problemList"/> is null.</exception>
    public static void SortProblemsByFileName(List<SqlRuleProblem> problemList)
    {
        ArgumentNullException.ThrowIfNull(problemList);

        // List<T>.Sort is an unstable sort; Enumerable.OrderBy is stable.
        var sorted = problemList.OrderBy(p => p, new ProblemComparer()).ToList();
        problemList.Clear();
        problemList.AddRange(sorted);
    }

    private sealed class ProblemComparer : IComparer<SqlRuleProblem>
    {
        public int Compare(SqlRuleProblem x, SqlRuleProblem y)
        {
            if (ReferenceEquals(x, y))
            {
                return 0;
            }

            if (x is null)
            {
                return -1;
            }

            if (y is null)
            {
                return 1;
            }

            var compare = string.Compare(x.SourceName, y.SourceName, StringComparison.OrdinalIgnoreCase);
            if (compare != 0)
            {
                return compare;
            }

            // CompareTo instead of subtraction: no overflow for extreme values.
            compare = x.StartLine.CompareTo(y.StartLine);
            return compare != 0 ? compare : x.StartColumn.CompareTo(y.StartColumn);
        }
    }
}