| | | 1 | | using System.Collections.Concurrent; |
| | | 2 | | using System.Collections.Immutable; |
| | | 3 | | using System.Net.Security; |
| | | 4 | | using System.Reflection; |
| | | 5 | | using System.Security.Cryptography.X509Certificates; |
| | | 6 | | using System.Text; |
| | | 7 | | using Kestrun.Hosting; |
| | | 8 | | using Kestrun.Scripting; |
| | | 9 | | using Microsoft.CodeAnalysis; |
| | | 10 | | using Microsoft.CodeAnalysis.CSharp; |
| | | 11 | | using Microsoft.CodeAnalysis.VisualBasic; |
| | | 12 | | |
| | | 13 | | namespace Kestrun.Certificates; |
| | | 14 | | |
| | | 15 | | /// <summary> |
| | | 16 | | /// Compiles C# or VB.NET code into a TLS client certificate validation callback. |
| | | 17 | | /// </summary> |
| | | 18 | | /// <remarks> |
| | | 19 | | /// This is intended for advanced scenarios where a pure .NET delegate is required (e.g. Kestrel TLS handshake callbacks |
| | | 20 | | /// The compiled delegate executes inside the Kestrel TLS handshake path, so it must be fast and thread-safe. |
| | | 21 | | /// </remarks> |
| | | 22 | | public static class ClientCertificateValidationCompiler |
| | | 23 | | { |
| | 1 | 24 | | private static readonly ConcurrentDictionary<string, Lazy<Func<X509Certificate2, X509Chain, SslPolicyErrors, bool>>> |
| | | 25 | | |
| | | 26 | | /// <summary> |
| | | 27 | | /// Compiles code into a TLS client certificate validation callback. |
| | | 28 | | /// </summary> |
| | | 29 | | /// <param name="host">The Kestrun host (used for logging).</param> |
| | | 30 | | /// <param name="code"> |
| | | 31 | | /// The code that forms the body of a method returning <c>bool</c>. |
| | | 32 | | /// The method signature is: |
| | | 33 | | /// <c>bool Validate(X509Certificate2 certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors)</c>. |
| | | 34 | | /// </param> |
| | | 35 | | /// <param name="language">The language used for <paramref name="code"/>.</param> |
| | | 36 | | /// <returns>A compiled callback delegate.</returns> |
| | | 37 | | public static Func<X509Certificate2, X509Chain, SslPolicyErrors, bool> Compile( |
| | | 38 | | KestrunHost host, |
| | | 39 | | string code, |
| | | 40 | | ScriptLanguage language = ScriptLanguage.CSharp) |
| | | 41 | | { |
| | 11 | 42 | | ArgumentNullException.ThrowIfNull(host); |
| | 10 | 43 | | if (string.IsNullOrWhiteSpace(code)) |
| | | 44 | | { |
| | 4 | 45 | | throw new ArgumentNullException(nameof(code), "Client certificate validation code cannot be null or whitespa |
| | | 46 | | } |
| | | 47 | | |
| | 6 | 48 | | var cacheKey = BuildCacheKey(language, code); |
| | 11 | 49 | | var lazy = Cache.GetOrAdd(cacheKey, _ => new Lazy<Func<X509Certificate2, X509Chain, SslPolicyErrors, bool>>( |
| | 16 | 50 | | () => CompileCore(host, code, language), isThreadSafe: true)); |
| | | 51 | | |
| | 6 | 52 | | return lazy.Value; |
| | | 53 | | } |
| | | 54 | | |
/// <summary>
/// Produces the cache key for a (language, code) pair: the numeric language id,
/// formatted with the invariant culture, followed by a colon and the raw code text.
/// </summary>
/// <param name="language">The script language; its integer value prefixes the key.</param>
/// <param name="code">The validation code, appended verbatim.</param>
/// <returns>A string of the form <c>"{languageId}:{code}"</c>.</returns>
private static string BuildCacheKey(ScriptLanguage language, string code)
{
    var languageId = ((int)language).ToString(System.Globalization.CultureInfo.InvariantCulture);
    return string.Concat(languageId, ":", code);
}
| | | 57 | | |
| | | 58 | | private static Func<X509Certificate2, X509Chain, SslPolicyErrors, bool> CompileCore( |
| | | 59 | | KestrunHost host, |
| | | 60 | | string code, |
| | | 61 | | ScriptLanguage language) |
| | | 62 | | { |
| | 5 | 63 | | return language switch |
| | 5 | 64 | | { |
| | 3 | 65 | | ScriptLanguage.CSharp => CompileCSharp(host, code), |
| | 1 | 66 | | ScriptLanguage.VBNet => CompileVbNet(host, code), |
| | 1 | 67 | | _ => throw new NotSupportedException($"ClientCertificateValidation supports only CSharp and VBNet, not {lang |
| | 5 | 68 | | }; |
| | | 69 | | } |
| | | 70 | | |
/// <summary>
/// Compiles user-supplied C# code into a client certificate validation delegate.
/// </summary>
/// <param name="host">The Kestrun host; forwarded to the error reporter for logging compile errors.</param>
/// <param name="code">The body of the generated <c>Validate</c> method.</param>
/// <returns>A delegate bound to the compiled <c>ClientCertValidationScript.Validate</c> method.</returns>
/// <exception cref="CompilationErrorException">Thrown when the code produces compilation errors.</exception>
private static Func<X509Certificate2, X509Chain, SslPolicyErrors, bool> CompileCSharp(KestrunHost host, string code)
{
    var source = WrapCSharp(code);
    var startLine = GetStartLine(source, "// ---- User code starts here ----");

    var parseOptions = new CSharpParseOptions(Microsoft.CodeAnalysis.CSharp.LanguageVersion.CSharp12);
    var tree = CSharpSyntaxTree.ParseText(source, parseOptions);

    var refs = BuildMetadataReferences(includeVisualBasicRuntime: false);

    // The compiled delegate runs on the TLS handshake path. Roslyn's default OptimizationLevel.Debug
    // emits an assembly marked as non-optimizable for the JIT; Release lets the callback run optimized.
    var compilationOptions = new CSharpCompilationOptions(
        OutputKind.DynamicallyLinkedLibrary,
        optimizationLevel: OptimizationLevel.Release);

    var compilation = CSharpCompilation.Create(
        assemblyName: $"TlsClientCertValidation_{Guid.NewGuid():N}",
        syntaxTrees: [tree],
        references: refs,
        options: compilationOptions);

    using var ms = new MemoryStream();
    var emit = compilation.Emit(ms);
    ThrowIfErrors(emit.Diagnostics, startLine, host, languageLabel: "C#");

    // MemoryStream.ToArray copies the whole buffer regardless of Position, so no rewind is needed.
    return LoadCallbackDelegate(ms.ToArray(), typeName: "ClientCertValidationScript", methodName: "Validate");
}
| | | 93 | | |
| | | 94 | | private static Func<X509Certificate2, X509Chain, SslPolicyErrors, bool> CompileVbNet(KestrunHost host, string code) |
| | | 95 | | { |
| | 1 | 96 | | var source = WrapVbNet(code); |
| | 1 | 97 | | var startLine = GetStartLine(source, "' ---- User code starts here ----"); |
| | | 98 | | |
| | 1 | 99 | | var parseOptions = new VisualBasicParseOptions(Microsoft.CodeAnalysis.VisualBasic.LanguageVersion.VisualBasic16_ |
| | 1 | 100 | | var tree = VisualBasicSyntaxTree.ParseText(source, parseOptions); |
| | | 101 | | |
| | 1 | 102 | | var refs = BuildMetadataReferences(includeVisualBasicRuntime: true); |
| | 1 | 103 | | var compilation = VisualBasicCompilation.Create( |
| | 1 | 104 | | assemblyName: $"TlsClientCertValidation_{Guid.NewGuid():N}", |
| | 1 | 105 | | syntaxTrees: [tree], |
| | 1 | 106 | | references: refs, |
| | 1 | 107 | | options: new VisualBasicCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); |
| | | 108 | | |
| | 1 | 109 | | using var ms = new MemoryStream(); |
| | 1 | 110 | | var emit = compilation.Emit(ms); |
| | 1 | 111 | | ThrowIfErrors(emit.Diagnostics, startLine, host, languageLabel: "VB.NET"); |
| | | 112 | | |
| | 1 | 113 | | ms.Position = 0; |
| | 1 | 114 | | return LoadCallbackDelegate(ms.ToArray(), typeName: "ClientCertValidationScript", methodName: "Validate"); |
| | 1 | 115 | | } |
| | | 116 | | |
| | | 117 | | private static IEnumerable<MetadataReference> BuildMetadataReferences(bool includeVisualBasicRuntime) |
| | | 118 | | { |
| | | 119 | | // Baseline references (includes X509 and many common assemblies) |
| | 4 | 120 | | var baseRefs = DelegateBuilder.BuildBaselineReferences(); |
| | | 121 | | |
| | | 122 | | // Ensure the assembly containing SslPolicyErrors is referenced. |
| | 4 | 123 | | var netSecurityAsm = typeof(SslPolicyErrors).Assembly; |
| | 4 | 124 | | var netSecurityRef = string.IsNullOrWhiteSpace(netSecurityAsm.Location) |
| | 4 | 125 | | ? null |
| | 4 | 126 | | : MetadataReference.CreateFromFile(netSecurityAsm.Location); |
| | | 127 | | |
| | | 128 | | // Add already-loaded assemblies to improve binding for "using" namespaces. |
| | 4 | 129 | | var loaded = AppDomain.CurrentDomain.GetAssemblies() |
| | 940 | 130 | | .Where(a => !a.IsDynamic && SafeHasLocation(a)) |
| | 919 | 131 | | .Select(a => MetadataReference.CreateFromFile(a.Location)); |
| | | 132 | | |
| | 4 | 133 | | IEnumerable<MetadataReference> refs = baseRefs; |
| | 4 | 134 | | if (netSecurityRef is not null) |
| | | 135 | | { |
| | 4 | 136 | | refs = refs.Append(netSecurityRef); |
| | | 137 | | } |
| | | 138 | | |
| | 4 | 139 | | refs = refs.Concat(loaded); |
| | | 140 | | |
| | 4 | 141 | | if (includeVisualBasicRuntime) |
| | | 142 | | { |
| | 1 | 143 | | refs = refs.Append(MetadataReference.CreateFromFile(typeof(Microsoft.VisualBasic.Constants).Assembly.Locatio |
| | | 144 | | } |
| | | 145 | | |
| | 4 | 146 | | return refs; |
| | | 147 | | } |
| | | 148 | | |
/// <summary>
/// Reports whether an assembly has a non-empty on-disk location that currently exists.
/// Any exception raised while probing (for example, reading <see cref="Assembly.Location"/>)
/// is treated as "no usable location".
/// </summary>
/// <param name="a">The assembly to probe.</param>
/// <returns><c>true</c> when the assembly's file path is non-empty and the file exists; otherwise <c>false</c>.</returns>
private static bool SafeHasLocation(Assembly a)
{
    try
    {
        return a.Location is { Length: > 0 } path && File.Exists(path);
    }
    catch
    {
        // Best-effort probe: per the Assembly.Location docs, reading it can throw
        // (e.g. NotSupportedException for dynamic assemblies). Such assemblies are skipped.
        return false;
    }
}
| | | 161 | | |
| | | 162 | | private static Func<X509Certificate2, X509Chain, SslPolicyErrors, bool> LoadCallbackDelegate(byte[] asmBytes, string |
| | | 163 | | { |
| | 3 | 164 | | var asm = Assembly.Load(asmBytes); |
| | 3 | 165 | | var method = asm.GetType(typeName, throwOnError: true)! |
| | 3 | 166 | | .GetMethod(methodName, BindingFlags.Public | BindingFlags.Static) |
| | 3 | 167 | | ?? throw new MissingMethodException(typeName, methodName); |
| | | 168 | | |
| | 3 | 169 | | return (Func<X509Certificate2, X509Chain, SslPolicyErrors, bool>)method |
| | 3 | 170 | | .CreateDelegate(typeof(Func<X509Certificate2, X509Chain, SslPolicyErrors, bool>)); |
| | | 171 | | } |
| | | 172 | | |
| | | 173 | | private static void ThrowIfErrors(ImmutableArray<Diagnostic> diagnostics, int startLine, KestrunHost host, string la |
| | | 174 | | { |
| | 8 | 175 | | var errors = diagnostics.Where(d => d.Severity == DiagnosticSeverity.Error).ToArray(); |
| | 4 | 176 | | if (errors.Length == 0) |
| | | 177 | | { |
| | 3 | 178 | | return; |
| | | 179 | | } |
| | | 180 | | |
| | 1 | 181 | | host.Logger.Error("{Lang} client certificate validation compilation completed with {Count} error(s).", languageL |
| | | 182 | | |
| | 1 | 183 | | var sb = new StringBuilder(); |
| | 1 | 184 | | _ = sb.AppendLine($"{languageLabel} client certificate validation compilation failed:"); |
| | 4 | 185 | | foreach (var error in errors) |
| | | 186 | | { |
| | 1 | 187 | | var location = error.Location.IsInSource |
| | 1 | 188 | | ? $" at line {error.Location.GetLineSpan().StartLinePosition.Line - startLine + 1}" |
| | 1 | 189 | | : string.Empty; |
| | 1 | 190 | | var msg = $" Error [{error.Id}]: {error.GetMessage()}{location}"; |
| | 1 | 191 | | host.Logger.Error(msg); |
| | 1 | 192 | | _ = sb.AppendLine(msg); |
| | | 193 | | } |
| | | 194 | | |
| | 1 | 195 | | throw new CompilationErrorException(sb.ToString().TrimEnd(), diagnostics); |
| | | 196 | | } |
| | | 197 | | |
/// <summary>
/// Finds the zero-based line index at which user code begins in a wrapped source string.
/// The user code is spliced onto the line immediately following <paramref name="marker"/>.
/// </summary>
/// <param name="source">The full generated source text.</param>
/// <param name="marker">The marker comment that precedes the user code.</param>
/// <returns>
/// The zero-based index of the line after the marker, or <c>0</c> when the marker is not present.
/// Callers convert a zero-based diagnostic line <c>L</c> to a 1-based user line via <c>L - startLine + 1</c>.
/// </returns>
private static int GetStartLine(string source, string marker)
{
    var markerIndex = source.IndexOf(marker, StringComparison.Ordinal);
    if (markerIndex < 0)
    {
        return 0;
    }

    // Count newlines before the marker to get the marker's own zero-based line.
    var markerLine = 0;
    for (var i = 0; i < markerIndex; i++)
    {
        if (source[i] == '\n')
        {
            markerLine++;
        }
    }

    // User code starts on the next line; returning the marker line itself made
    // reported error lines one too high (user line 1 was reported as line 2).
    return markerLine + 1;
}
| | | 217 | | |
| | | 218 | | private static string WrapCSharp(string code) |
| | 3 | 219 | | => $$""" |
| | 3 | 220 | | using System; |
| | 3 | 221 | | using System.Net.Security; |
| | 3 | 222 | | using System.Security.Cryptography.X509Certificates; |
| | 3 | 223 | | |
| | 3 | 224 | | public static class ClientCertValidationScript |
| | 3 | 225 | | { |
| | 3 | 226 | | public static bool Validate(X509Certificate2 certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) |
| | 3 | 227 | | { |
| | 3 | 228 | | // ---- User code starts here ---- |
| | 3 | 229 | | {{Indent(code, 8)}} |
| | 3 | 230 | | } |
| | 3 | 231 | | } |
| | 3 | 232 | | """; |
| | | 233 | | |
| | | 234 | | private static string WrapVbNet(string code) |
| | 1 | 235 | | => $$""" |
| | 1 | 236 | | Imports System |
| | 1 | 237 | | Imports System.Net.Security |
| | 1 | 238 | | Imports System.Security.Cryptography.X509Certificates |
| | 1 | 239 | | |
| | 1 | 240 | | Public Module ClientCertValidationScript |
| | 1 | 241 | | Public Function Validate(certificate As X509Certificate2, chain As X509Chain, sslPolicyErrors As SslPolicyErrors) As |
| | 1 | 242 | | ' ---- User code starts here ---- |
| | 1 | 243 | | {{Indent(code, 8)}} |
| | 1 | 244 | | End Function |
| | 1 | 245 | | End Module |
| | 1 | 246 | | """; |
| | | 247 | | |
/// <summary>
/// Normalizes line endings (CRLF and lone CR become LF) and prefixes every line,
/// including empty ones, with the given number of spaces.
/// </summary>
/// <param name="code">The text to indent.</param>
/// <param name="spaces">How many space characters to prepend to each line.</param>
/// <returns>The indented text, with lines joined by <c>\n</c>.</returns>
private static string Indent(string code, int spaces)
{
    var prefix = new string(' ', spaces);

    var normalized = code
        .Replace("\r\n", "\n", StringComparison.Ordinal)
        .Replace("\r", "\n", StringComparison.Ordinal);

    var indentedLines = Array.ConvertAll(normalized.Split('\n'), line => prefix + line);
    return string.Join('\n', indentedLines);
}
| | | 255 | | } |