diff --git a/VpSharp.Commands/CommandsExtension.cs b/VpSharp.Commands/CommandsExtension.cs index a7f2fe5..d07a5b7 100644 --- a/VpSharp.Commands/CommandsExtension.cs +++ b/VpSharp.Commands/CommandsExtension.cs @@ -5,6 +5,7 @@ using VpSharp.Commands.Attributes; using VpSharp.Commands.Attributes.ExecutionChecks; using VpSharp.Entities; using VpSharp.EventData; +using VpSharp.Internal; namespace VpSharp.Commands; @@ -146,29 +147,9 @@ public sealed class CommandsExtension : VirtualParadiseClientExtension throw new ArgumentException($"Module type is not a subclass of {typeof(CommandModule)}"); } - ConstructorInfo[] constructors = moduleType.GetTypeInfo().DeclaredConstructors.Where(c => c.IsPublic).ToArray(); - if (constructors.Length != 1) - { - throw new ArgumentException( - $"Constructor for {moduleType} is not public, or {moduleType} has more than one public constructor."); - } - - ConstructorInfo constructor = constructors[0]; - ParameterInfo[] parameters = constructor.GetParameters(); - IServiceProvider? serviceProvider = _configuration.Services; - if (parameters.Length != 0 && serviceProvider is null) - { - throw new InvalidOperationException("No ServiceProvider has been registered!"); - } - - var args = new object[parameters.Length]; - for (var index = 0; index < args.Length; index++) - { - args[index] = serviceProvider!.GetRequiredService(parameters[index].ParameterType); - } - - if (Activator.CreateInstance(moduleType, args) is not CommandModule module) + object instance = DependencyInjectionUtility.CreateInstance(moduleType, serviceProvider); + if (instance is not CommandModule module) { throw new TypeInitializationException(moduleType.FullName, null); } @@ -180,7 +161,7 @@ public sealed class CommandsExtension : VirtualParadiseClientExtension } /// - protected override Task OnMessageReceived(MessageReceivedEventArgs args) + protected internal override Task OnMessageReceived(MessageReceivedEventArgs args) { ArgumentNullException.ThrowIfNull(args); VirtualParadiseMessage message = args.Message; diff --git a/VpSharp/src/Assembly.cs b/VpSharp/src/Assembly.cs index 560422b..09e3f50 100644 --- a/VpSharp/src/Assembly.cs +++ b/VpSharp/src/Assembly.cs @@ -1,5 +1,6 @@ using System.Runtime.CompilerServices; [assembly: CLSCompliant(true)] +[assembly: InternalsVisibleTo("VpSharp.Commands")] [assembly: InternalsVisibleTo("VpSharp.IntegrationTests")] [assembly: InternalsVisibleTo("VpSharp.Tests")] diff --git a/VpSharp/src/Internal/DependencyInjectionUtility.cs b/VpSharp/src/Internal/DependencyInjectionUtility.cs new file mode 100644 index 0000000..61b1112 --- /dev/null +++ b/VpSharp/src/Internal/DependencyInjectionUtility.cs @@ -0,0 +1,71 @@ +using System.Reflection; +using Microsoft.Extensions.DependencyInjection; + +namespace VpSharp.Internal; + +internal static class DependencyInjectionUtility +{ + public static T CreateInstance(VirtualParadiseClient client) + { + return CreateInstance(client.Services); + } + + public static T CreateInstance(IServiceProvider? serviceProvider = null) + { + return (T)CreateInstance(typeof(T), serviceProvider); + } + + public static object CreateInstance(Type type, VirtualParadiseClient client) + { + ArgumentNullException.ThrowIfNull(type); + ArgumentNullException.ThrowIfNull(client); + return CreateInstance(type, client.Services); + } + + public static object CreateInstance(Type type, IServiceProvider? serviceProvider = null) + { + ArgumentNullException.ThrowIfNull(type); + object? instance; + + TypeInfo typeInfo = type.GetTypeInfo(); + ConstructorInfo[] constructors = typeInfo.DeclaredConstructors.Where(c => c.IsPublic).ToArray(); + + if (constructors.Length != 1) + { + throw new InvalidOperationException($"{type} has no public constructors, or has more than one public constructor."); + } + + ConstructorInfo constructor = constructors[0]; + ParameterInfo[] parameters = constructor.GetParameters(); + + if (parameters.Length > 0 && serviceProvider is null) + { + throw new InvalidOperationException("No ServiceProvider has been registered!"); + } + + if (parameters.Length == 0) + { + instance = Activator.CreateInstance(type); + if (instance is null) + { + throw new TypeInitializationException(type.FullName, null); + } + + return instance; + } + + var args = new object[parameters.Length]; + for (var index = 0; index < parameters.Length; index++) + { + args[index] = serviceProvider!.GetRequiredService(parameters[index].ParameterType); + } + + instance = Activator.CreateInstance(type, args); + if (instance is null) + { + throw new TypeInitializationException(type.FullName, null); + } + + return instance; + } +}