Monday, October 5, 2020

Dependency Injection with Unity

It's time I looked at the most commonly used Dependency Injection framework in the Visual Studio world - Unity. I started by reading an excellent tutorial on the subject at tutorialsteacher.com and then modified my standard application -> configuration -> data provider -> logging provider application to use Unity.

Before we start today, I want to discuss which is better...

  1. Pass individual dependencies to constructors
  2. Pass the dependency container and allow the constructor to extract what it needs
Without dependency injection frameworks we only have option 1. But when we are collecting all our dependencies together we have the option of passing the collection and allowing the constructor to grab the dependencies it wants. This means the instantiating code needs to know less about the class being instantiated. This is a good thing.

If you only want to instantiate the class in a dependency injection framework and you are confident that framework will not be replaced, use option 2. Otherwise use option 1.
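To make the two options concrete, here is a minimal sketch, meant to be built as its own console project. `Logger`, `ReportOption1` and `ReportOption2` are hypothetical, trimmed-down stand-ins for the providers in this post. Only `UnityContainer`, `RegisterInstance` and `Resolve` come from the Unity package.

```csharp
using System;
using Unity;

namespace OptionsSketch
{
    // Hypothetical dependency standing in for BaseLoggingProvider.
    public class Logger
    {
        public string Name => "Sketch logger";
    }

    // Option 1 (hypothetical class): the constructor asks for exactly the dependency it needs.
    // The caller must know about Logger, but the class never sees Unity.
    public class ReportOption1
    {
        private readonly Logger logger;
        public ReportOption1(Logger logger) { this.logger = logger; }
        public string LoggerName => logger.Name;
    }

    // Option 2 (hypothetical class): the constructor receives the container and pulls out what it needs.
    // The caller only passes the container, but the class is now tied to Unity.
    public class ReportOption2
    {
        private readonly Logger logger;
        public ReportOption2(IUnityContainer container) { logger = container.Resolve<Logger>(); }
        public string LoggerName => logger.Name;
    }

    public static class Program
    {
        public static string[] Run()
        {
            IUnityContainer container = new UnityContainer();
            container.RegisterInstance<Logger>(new Logger());

            // Option 1: the caller extracts the dependency and passes it in.
            var report1 = new ReportOption1(container.Resolve<Logger>());
            // Option 2: the caller passes the container and the class extracts the dependency.
            var report2 = new ReportOption2(container);

            return new[] { report1.LoggerName, report2.LoggerName };
        }

        public static void Main()
        {
            foreach (string line in Run())
                Console.WriteLine(line);
        }
    }
}
```

Both produce the same result; the difference is who has to know about `Logger`, and whether the class takes a dependency on `IUnityContainer`.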

Although Unity makes registration of classes with non-default constructors simple, I like the idea of extension methods that encapsulate this functionality so it doesn't have to be repeated in different places. I'm going to keep using this technique with Unity, even though I don't strictly have to.
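For comparison, Unity's own way of handling non-default constructors is constructor injection: map the abstract type to a concrete type with `RegisterType` and let Unity supply the constructor arguments from what is already registered. This is a minimal sketch of that alternative, not the approach this post uses. `BaseLogger`, `SketchLogger`, `BaseData` and `SketchData` are hypothetical names. Note that `RegisterType` without a lifetime manager creates a new instance on every `Resolve`, unlike `RegisterInstance`.

```csharp
using System;
using Unity;

namespace ConstructorInjectionSketch
{
    // Hypothetical stand-ins for BaseLoggingProvider / ProductionLoggingProvider.
    public abstract class BaseLogger { public abstract string GetName(); }
    public class SketchLogger : BaseLogger { public override string GetName() => "Sketch logger"; }

    // Hypothetical stand-in for BaseDataProvider.
    public abstract class BaseData
    {
        public BaseLogger Logger { get; }
        protected BaseData(BaseLogger logger) { Logger = logger; }
    }

    // Option 1 style constructor: takes the logger directly and knows nothing about Unity.
    public class SketchData : BaseData
    {
        public SketchData(BaseLogger logger) : base(logger) { }
    }

    public static class Program
    {
        public static BaseData Run()
        {
            IUnityContainer container = new UnityContainer();
            container.RegisterInstance<BaseLogger>(new SketchLogger());

            // Map the abstract type to the concrete one. When BaseData is resolved,
            // Unity calls SketchData's constructor and supplies the registered BaseLogger.
            container.RegisterType<BaseData, SketchData>();

            return container.Resolve<BaseData>();
        }

        public static void Main() => Console.WriteLine(Run().Logger.GetName());
    }
}
```

With this style the data provider no longer needs an `IUnityContainer` constructor parameter, which is option 1 from above with Unity doing the wiring.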

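Unity provides good support for just-in-time resolution (creating instances of concrete classes that were never registered), but our application doesn't use it; we will be registering pre-instantiated instances instead. Also, just-in-time resolution doesn't work for abstract classes, and all our providers are resolved through abstract base classes.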
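Here is a small sketch of that difference. `BaseLogger` and `ConsoleLogger` are hypothetical names. Resolving the concrete class works with nothing registered, while resolving the abstract base class fails because Unity has no mapping telling it which concrete class to build. The sketch catches the general `Exception` and prints its type name rather than assuming a specific exception class.

```csharp
using System;
using Unity;

namespace JustInTimeSketch
{
    // Hypothetical stand-ins for BaseLoggingProvider / TestLoggingProvider.
    public abstract class BaseLogger { public abstract string GetName(); }
    public class ConsoleLogger : BaseLogger { public override string GetName() => "Console logger"; }

    public static class Program
    {
        // Resolves ConsoleLogger even though it was never registered.
        public static string ResolveConcrete()
        {
            IUnityContainer container = new UnityContainer();
            return container.Resolve<ConsoleLogger>().GetName();
        }

        // Tries to resolve the abstract base class, which was never registered.
        // Returns true if Unity refuses to build it.
        public static bool AbstractResolveFails()
        {
            IUnityContainer container = new UnityContainer();
            try
            {
                container.Resolve<BaseLogger>();
                return false;
            }
            catch (Exception ex)
            {
                Console.WriteLine($"Resolve<BaseLogger>() failed with {ex.GetType().Name}");
                return true;
            }
        }

        public static void Main()
        {
            Console.WriteLine(ResolveConcrete());
            Console.WriteLine(AbstractResolveFails());
        }
    }
}
```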

Start a new C# .NET Core console application called UnityCore. Use NuGet to add references to Unity and System.Data.SqlClient. Here is the code that goes into Program.cs. The Unity-specific parts are IUnityContainer, RegisterInstance and Resolve, and you can see it's doing the same things as Microsoft.Extensions.DependencyInjection but with different words.

using System;
using System.Data.SqlClient;
using Unity;
 
namespace UnityCore
{
    class Program
    {
        static void Main(string[] args)
        {
            IUnityContainer container = new UnityContainer()
                .RegisterLoggingProvider<ProductionLoggingProvider>()
                .RegisterDataProvider<ProductionDataProvider>();
 
            Configuration configuration = new Configuration(container);
            bool IsProductionMode = configuration.GetSettings().isProductionMode;
 
            Console.WriteLine($"IsProductionMode is {IsProductionMode}");
            Console.WriteLine($"DataProvider is {configuration.GetDataProviderName()}");
            Console.WriteLine($"LoggingProvider is {configuration.GetLoggingProviderName()}");
            Console.WriteLine($"DataProvider's LoggingProvider is {configuration.GetDataProviderLoggingProviderName()}");
        }
    }
 
    public static class ServiceExtensions
    {
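        // Creates T with its parameterless constructor and registers it as the BaseLoggingProvider instance.
        public static IUnityContainer RegisterLoggingProvider<T>(this IUnityContainer container) where T : BaseLoggingProvider
        {
            container.RegisterInstance<BaseLoggingProvider>(Activator.CreateInstance(typeof(T)) as BaseLoggingProvider);
            return container;
        }
 
        // Creates T by passing the container to its constructor and registers it as the BaseDataProvider instance.
        // The constructor resolves BaseLoggingProvider immediately, so call RegisterLoggingProvider first.
        public static IUnityContainer RegisterDataProvider<T>(this IUnityContainer container) where T : BaseDataProvider
        {
            container.RegisterInstance<BaseDataProvider>(Activator.CreateInstance(typeof(T), container) as BaseDataProvider);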
            return container;
        }
    }
 
    public abstract class BaseLoggingProvider
    {
        public BaseLoggingProvider() { }
 
        abstract public string GetName();
        abstract public void Log(string s);
    }
 
    public class ProductionLoggingProvider : BaseLoggingProvider
    {
        public override string GetName()
        {
            return "Production logging provider";
        }
        public override void Log(string s)
        {
            // write to log file
        }
    }
 
    public class TestLoggingProvider : BaseLoggingProvider
    {
        public override string GetName()
        {
            return "Test logging provider";
        }
        public override void Log(string s)
        {
            Console.WriteLine(s);
        }
    }
 
    public class Settings
    {
        public bool isProductionMode;
    }
 
    public abstract class BaseDataProvider
    {
        internal BaseLoggingProvider loggingProvider = null;    // This would normally be protected
        public BaseDataProvider(IUnityContainer container)
        {
            this.loggingProvider = container.Resolve<BaseLoggingProvider>();
        }
        public abstract string GetName();
        public abstract void OpenConnection();
        public abstract Settings GetSettings();
        public abstract void CloseConnection();
        public abstract bool IsConnectionOpen();
    }
 
    public class ProductionDataProvider : BaseDataProvider
    {
        public ProductionDataProvider(IUnityContainer container) : base(container) { }
 
        private SqlConnection sqlConnection;
 
        public override void CloseConnection()
        {
            loggingProvider.Log("Closing connection");
            sqlConnection.Close();
            sqlConnection = null;
        }
 
        public override string GetName()
        {
            return "Production data provider";
        }
 
        public override Settings GetSettings()
        {
            loggingProvider.Log("Getting settings");
            if (IsConnectionOpen())
                return new Settings() { isProductionMode = true };
            else
                throw new Exception("Connection is not open");
        }
 
        public override bool IsConnectionOpen()
        {
            return (sqlConnection != null && sqlConnection.State == System.Data.ConnectionState.Open);
        }
 
        public override void OpenConnection()
        {
            loggingProvider.Log("Opening connection");
            sqlConnection = new SqlConnection("server=(local);database=master;trusted_connection=true");
            sqlConnection.Open();
        }
    }
 
    public class TestDataProvider : BaseDataProvider
    {
        public TestDataProvider(IUnityContainer container) : base(container) { }
 
        private bool isConnectionOpen = false;
 
        public override void CloseConnection()
        {
            loggingProvider.Log("Closing connection");
            isConnectionOpen = false;
        }
 
        public override string GetName()
        {
            return "Test data provider";
        }
 
        public override Settings GetSettings()
        {
            loggingProvider.Log("Getting settings");
            if (IsConnectionOpen())
                return new Settings() { isProductionMode = false };
            else
                throw new Exception("Connection is not open");
        }
 
        public override bool IsConnectionOpen()
        {
            return (isConnectionOpen);
        }
 
        public override void OpenConnection()
        {
            loggingProvider.Log("Opening connection");
            isConnectionOpen = true;
        }
    }
 
    public class Configuration
    {
        private BaseDataProvider dataProvider;
        private BaseLoggingProvider loggingProvider;
 
        public Configuration(IUnityContainer container)
        {
            this.dataProvider = container.Resolve<BaseDataProvider>();
            this.loggingProvider = container.Resolve<BaseLoggingProvider>();
        }
 
        public string GetDataProviderName()
        {
            return dataProvider.GetName();
        }
 
        public string GetLoggingProviderName()
        {
            return loggingProvider.GetName();
        }
 
        public string GetDataProviderLoggingProviderName()
        {
            return dataProvider.loggingProvider.GetName();
        }
 
        public Settings GetSettings()
        {
            Settings settings;
            dataProvider.OpenConnection();
            settings = dataProvider.GetSettings();
            dataProvider.CloseConnection();
            return settings;
        }
    }
}

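If you run the program you will see the expected output. Note that ProductionDataProvider opens a real connection to (local), so a local SQL Server instance must be running or OpenConnection will throw.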
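To show what "the same things with different words" looks like, here is a minimal side-by-side sketch. It assumes both the Unity and Microsoft.Extensions.DependencyInjection NuGet packages are referenced, and `Logger` is a hypothetical stand-in for the providers. Roughly, Unity's `RegisterInstance` corresponds to `AddSingleton` with an instance, and `Resolve` corresponds to `GetService`; Microsoft's container also needs an explicit `BuildServiceProvider` step.

```csharp
using System;
using Microsoft.Extensions.DependencyInjection;
using Unity;

namespace SideBySideSketch
{
    // Hypothetical dependency standing in for a logging provider.
    public class Logger { public string Name => "Sketch logger"; }

    public static class Program
    {
        public static string ViaUnity()
        {
            IUnityContainer container = new UnityContainer()
                .RegisterInstance<Logger>(new Logger());   // register a ready-made instance
            return container.Resolve<Logger>().Name;      // look it up again
        }

        public static string ViaMicrosoftDI()
        {
            IServiceProvider provider = new ServiceCollection()
                .AddSingleton<Logger>(new Logger())        // register a ready-made instance
                .BuildServiceProvider();                   // explicit build step
            return provider.GetService<Logger>().Name;    // look it up again
        }

        public static void Main()
        {
            Console.WriteLine(ViaUnity());
            Console.WriteLine(ViaMicrosoftDI());
        }
    }
}
```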


Add an NUnit test project called UnityCoreTest, give it a project reference to UnityCore and a NuGet reference to Unity, then rename Test1.cs to UnityTest.cs. Make the code look like this.

using NUnit.Framework;
using Unity;
using UnityCore;
 
namespace UnityCoreTest
{
    public class Tests
    {
        [Test]
        public void TestProvidersTest()
        {
            IUnityContainer container = new UnityContainer()
                .RegisterLoggingProvider<TestLoggingProvider>()
                .RegisterDataProvider<TestDataProvider>();
 
            Configuration configuration = new Configuration(container);
            bool IsProductionMode = configuration.GetSettings().isProductionMode;
 
            Assert.IsFalse(IsProductionMode);
            Assert.AreEqual("Test data provider", configuration.GetDataProviderName());
            Assert.AreEqual("Test logging provider", configuration.GetLoggingProviderName());
            Assert.AreEqual("Test logging provider", configuration.GetDataProviderLoggingProviderName());
        }
    }
}
Again, because we are using the RegisterDataProvider extension method, neither the production code nor the test code needs to know how the data provider is instantiated. Running the tests, we see they pass, although the Unity test is slightly slower than the equivalent Microsoft.Extensions.DependencyInjection test.
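The test compares provider names, but because the extension methods use RegisterInstance, the Configuration and the data provider should be holding the very same logging provider object. This minimal sketch (with a hypothetical `Logger` class) shows the lifetime difference behind that: an instance registered with `RegisterInstance` comes back on every `Resolve`, while a type registered with `RegisterType` and no lifetime manager is built fresh each time.

```csharp
using System;
using Unity;

namespace SharedInstanceSketch
{
    // Hypothetical dependency standing in for a logging provider.
    public class Logger { }

    public static class Program
    {
        public static bool SameInstance()
        {
            IUnityContainer container = new UnityContainer();
            container.RegisterInstance<Logger>(new Logger());
            // RegisterInstance hands back the same object on every Resolve.
            return ReferenceEquals(container.Resolve<Logger>(), container.Resolve<Logger>());
        }

        public static bool TransientInstance()
        {
            IUnityContainer container = new UnityContainer();
            container.RegisterType<Logger>();
            // RegisterType with no lifetime manager creates a new object on each Resolve.
            return ReferenceEquals(container.Resolve<Logger>(), container.Resolve<Logger>());
        }

        public static void Main()
        {
            Console.WriteLine($"RegisterInstance shares one object: {SameInstance()}");
            Console.WriteLine($"RegisterType shares one object: {TransientInstance()}");
        }
    }
}
```

In an NUnit test against the real classes, the same idea would be an `Assert.AreSame` on the two logging providers, if Configuration exposed them.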


While I was reading the tutorial I noticed a slew of BMW ads, probably because the page contains many references to BMW. It's scary that Google is reading the web page I'm looking at and trying to figure out what I might be in the mood to buy. Somewhere, there's a database that says I'm interested in buying a BMW.


Right now there's probably a piece of code thinking "He's onto us!"



