| | | 1 | | using System.Reflection; |
| | | 2 | | using Microsoft.AspNetCore.RateLimiting; |
| | | 3 | | namespace Kestrun.Utilities; |
/// <summary>
/// Provides extension methods for copying rate limiter options and policies.
/// </summary>
public static class RateLimiterOptionsExtensions
{
    // Binding flags for the internal instance members that hold the named-policy registries.
    private const BindingFlags InternalInstance = BindingFlags.Instance | BindingFlags.NonPublic;

    // Internal member holding named policies registered as ready-to-use instances.
    private const string PolicyMapName = "PolicyMap";

    // Internal member holding named policy factories that are activated later (with an IServiceProvider).
    private const string UnactivatedPolicyMapName = "UnactivatedPolicyMap";

    /// <summary>
    /// Copies all rate limiter options and policies from the source to the target <see cref="RateLimiterOptions"/>.
    /// </summary>
    /// <remarks>
    /// <para>
    /// The scalar settings <see cref="RateLimiterOptions.GlobalLimiter"/>, <see cref="RateLimiterOptions.OnRejected"/>
    /// and <see cref="RateLimiterOptions.RejectionStatusCode"/> are assigned directly.
    /// </para>
    /// <para>
    /// Named policies live in internal dictionaries with no public accessor, so they are copied via reflection.
    /// Entries from <paramref name="source"/> replace same-named entries in <paramref name="target"/>.
    /// Policy instances and factories are shared by reference, not cloned.
    /// If the internal dictionaries cannot be located (a framework version with a different layout),
    /// named policies are skipped and only the scalar settings are copied.
    /// </para>
    /// </remarks>
    /// <param name="target">The target <see cref="RateLimiterOptions"/> to copy to.</param>
    /// <param name="source">The source <see cref="RateLimiterOptions"/> to copy from.</param>
    /// <exception cref="ArgumentNullException"><paramref name="target"/> or <paramref name="source"/> is null.</exception>
    public static void CopyFrom(this RateLimiterOptions target, RateLimiterOptions source)
    {
        ArgumentNullException.ThrowIfNull(target);
        ArgumentNullException.ThrowIfNull(source);

        // Copying onto itself is a no-op; bailing out also avoids mutating a dictionary while enumerating it.
        if (ReferenceEquals(target, source))
        {
            return;
        }

        // ───── scalar props ─────
        target.GlobalLimiter = source.GlobalLimiter;
        target.OnRejected = source.OnRejected;
        target.RejectionStatusCode = source.RejectionStatusCode;

        // ───── named policies (activated instances and un-activated factories) ─────
        var sourcePolicies = GetInternalMap(source, PolicyMapName);
        var sourceFactories = GetInternalMap(source, UnactivatedPolicyMapName);
        var targetPolicies = GetInternalMap(target, PolicyMapName);
        var targetFactories = GetInternalMap(target, UnactivatedPolicyMapName);

        CopyEntries(sourcePolicies, targetPolicies, targetFactories);
        CopyEntries(sourceFactories, targetFactories, targetPolicies);
    }

    /// <summary>
    /// Copies every entry of <paramref name="from"/> into <paramref name="to"/>, overwriting same-named keys.
    /// </summary>
    /// <param name="from">Source dictionary; nothing is copied when null.</param>
    /// <param name="to">Destination dictionary; nothing is copied when null.</param>
    /// <param name="other">
    /// The target's other policy dictionary. A copied name is removed from it so that each policy name
    /// ends up registered in only one of the two maps, as <c>AddPolicy</c> enforces.
    /// </param>
    private static void CopyEntries(
        System.Collections.IDictionary? from,
        System.Collections.IDictionary? to,
        System.Collections.IDictionary? other)
    {
        if (from is null || to is null)
        {
            return;
        }

        foreach (System.Collections.DictionaryEntry entry in from)
        {
            other?.Remove(entry.Key);
            // Both dictionaries belong to RateLimiterOptions instances, so key/value types always match.
            to[entry.Key] = entry.Value;
        }
    }

    /// <summary>
    /// Reads one of the internal policy dictionaries of <paramref name="options"/> through reflection.
    /// </summary>
    /// <param name="options">The options instance to read from.</param>
    /// <param name="memberName">Name of the internal member (property or field).</param>
    /// <returns>The dictionary viewed as a non-generic <see cref="System.Collections.IDictionary"/>, or null if not found.</returns>
    private static System.Collections.IDictionary? GetInternalMap(RateLimiterOptions options, string memberName)
    {
        var optionsType = typeof(RateLimiterOptions);

        // Look for a property first: in current ASP.NET Core these maps are internal get-only
        // auto-properties, so a field-only lookup finds nothing. Fall back to a field for other layouts.
        var value = optionsType.GetProperty(memberName, InternalInstance)?.GetValue(options)
            ?? optionsType.GetField(memberName, InternalInstance)?.GetValue(options);

        // The generic value types (DefaultRateLimiterPolicy, Func<...>) are internal, so use the non-generic view.
        return value as System.Collections.IDictionary;
    }
}