Refactoring the Switch Statement code smell

The Switch Statement code smell refers to using switch statements with a type code to get different behavior or data instead of using subclasses and polymorphism.

In general, it looks like this:

switch(typeCode)
   case type1:
      return data specific to type1
   case type2:
      return data specific to type2
   case type3:
      return data specific to type3

This switch(typeCode) structure is typically repeated across many methods. That makes the code difficult to extend and violates the Open-Closed Principle, which states that code should be open for extension but closed for modification.

Why does this make code difficult to extend?

Imagine your code currently supports three types of birds, and this switch(birdType) structure is used in several methods. Now you’re given a new requirement where you need to support a new type of bird – let’s say Hummingbird. You’ll have to add a case for Hummingbird in all of the methods that are using the switch(birdType) structure.

Not only is this tedious, requiring you to modify several different methods (thus violating the Open-Closed Principle), but it also exposes you to bugs. It’s easy to forget to update one of the methods, the compiler won’t catch the omission, and you’d only discover it when a runtime exception is thrown (potentially in production).
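To make the risk concrete, here is a minimal sketch of the original Bird class (trimmed to GetColors()) right after Hummingbird has been added to BirdType but before the switch has been updated. The enum and exception definitions are assumptions, since the article doesn't show them.

```csharp
using System;
using System.Collections.Generic;

// Assumed definitions: Hummingbird has just been added to BirdType.
public enum BirdType { Cardinal, Goldfinch, Chickadee, Hummingbird }
public enum BirdColor { Black, Red, Yellow, White, Tan }

public class InvalidBirdTypeException : Exception { }

public class Bird
{
    private readonly BirdType birdType;

    public Bird(BirdType type)
    {
        birdType = type;
    }

    // Not updated for Hummingbird. This still compiles, because the C#
    // compiler doesn't check a switch statement for missing enum members.
    public List<BirdColor> GetColors()
    {
        switch (birdType)
        {
            case BirdType.Cardinal:
                return new List<BirdColor>() { BirdColor.Black, BirdColor.Red };
            case BirdType.Goldfinch:
                return new List<BirdColor>() { BirdColor.Black, BirdColor.Yellow, BirdColor.White };
            case BirdType.Chickadee:
                return new List<BirdColor>() { BirdColor.Black, BirdColor.White, BirdColor.Tan };
        }
        // A Hummingbird ends up here, but only at runtime.
        throw new InvalidBirdTypeException();
    }
}
```

The code builds cleanly, yet new Bird(BirdType.Hummingbird).GetColors() throws InvalidBirdTypeException when it runs.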

In this article, I’ll show an example of how to refactor this code smell.

Note: The problem is not specific to switch statements. You could also have an if-else if chain that checks the type code. It’s called the Switch Statement code smell only because the switch(typeCode) structure is more common than the if (typeCode == type1) else if (typeCode == type2) structure.
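For comparison, here's a sketch of the same smell written as an if-else if chain. BirdType, BirdFood, and the exception are simplified, assumed definitions. The method dispatches on the type code exactly like a switch would, and has the same weaknesses.

```csharp
using System;
using System.Collections.Generic;

// Assumed, simplified supporting types.
public enum BirdType { Cardinal, Goldfinch, Chickadee }
public enum BirdFood { Insects, Seeds, Fruit }
public class InvalidBirdTypeException : Exception { }

public class Bird
{
    private readonly BirdType birdType;

    public Bird(BirdType type) { birdType = type; }

    // Same type-code dispatch as switch(birdType), written as an if-else if chain.
    public List<BirdFood> GetFoods()
    {
        if (birdType == BirdType.Cardinal)
            return new List<BirdFood>() { BirdFood.Insects, BirdFood.Seeds, BirdFood.Fruit };
        else if (birdType == BirdType.Goldfinch)
            return new List<BirdFood>() { BirdFood.Insects, BirdFood.Seeds };
        else if (birdType == BirdType.Chickadee)
            return new List<BirdFood>() { BirdFood.Insects, BirdFood.Fruit, BirdFood.Seeds };

        throw new InvalidBirdTypeException();
    }
}
```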

Code Smell: Switch Statement.

Definition: Using switch statements with a type code to get different behavior or data instead of using subclasses and polymorphism.

Solution:

  • Apply the Replace Type Code with Subclasses refactoring:
    • Add subclasses for each type represented by the type code.
    • Use a factory method to create the subclass objects based on the type.
    • Apply Push Down Method by moving the switch-statement-abusing methods to the subclasses.

Switch Statement code smell example

Here’s an example of the Switch Statement code smell. The Bird class is using a type code (BirdType) instead of polymorphism to get behavior and properties from different types of birds.

Bird class

using System.Collections.Generic;

public class Bird
{
	private readonly BirdType birdType;

	public Bird(BirdType type)
	{
		birdType = type;
	}
	public List<BirdColor> GetColors()
	{
		switch (birdType)
		{
			case BirdType.Cardinal:
				return new List<BirdColor>() { BirdColor.Black, BirdColor.Red };
			case BirdType.Goldfinch:
				return new List<BirdColor>() { BirdColor.Black, BirdColor.Yellow, BirdColor.White };
			case BirdType.Chickadee:
				return new List<BirdColor>() { BirdColor.Black, BirdColor.White, BirdColor.Tan };
		}
		throw new InvalidBirdTypeException();
	}
	public List<BirdFood> GetFoods()
	{
		switch (birdType)
		{
			case BirdType.Cardinal:
				return new List<BirdFood>() { BirdFood.Insects, BirdFood.Seeds, BirdFood.Fruit};
			case BirdType.Goldfinch:
				return new List<BirdFood>() { BirdFood.Insects, BirdFood.Seeds };
			case BirdType.Chickadee:
				return new List<BirdFood>() { BirdFood.Insects, BirdFood.Fruit, BirdFood.Seeds };
		}
		throw new InvalidBirdTypeException();
	}
	public BirdSizeRange GetSizeRange()
	{
		switch (birdType)
		{
			case BirdType.Cardinal:
				return new BirdSizeRange() { Lower=8, Upper=9 };
			case BirdType.Goldfinch:
				return new BirdSizeRange() { Lower=4.5, Upper=5.5 };
			case BirdType.Chickadee:
				return new BirdSizeRange() { Lower=4.75, Upper=5.75 };
		}
		throw new InvalidBirdTypeException();
	}
}
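The article doesn't show BirdType, BirdColor, BirdFood, BirdSizeRange, or InvalidBirdTypeException. A minimal sketch like the following, with members inferred from how Bird uses them, is enough to compile the examples. The exact definitions (including whether BirdSizeRange is a class or a struct, and what unit the sizes are in) are assumptions.

```csharp
using System;

// Hypothetical definitions inferred from how the Bird class uses them.
public enum BirdType
{
    Cardinal,
    Goldfinch,
    Chickadee
}

public enum BirdColor
{
    Black,
    Red,
    Yellow,
    White,
    Tan
}

public enum BirdFood
{
    Insects,
    Seeds,
    Fruit
}

// Lower and upper size bounds; the unit isn't stated in the article.
public class BirdSizeRange
{
    public double Lower { get; set; }
    public double Upper { get; set; }
}

public class InvalidBirdTypeException : Exception
{
    public InvalidBirdTypeException()
        : base("Unsupported BirdType value.") { }
}
```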

Unit tests

The following parameterized unit tests verify the behavior of the existing code (before refactoring):

using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass()]
public class BirdTests
{
	[DataRow(BirdType.Cardinal, new BirdColor[] { BirdColor.Red, BirdColor.Black })]
	[DataRow(BirdType.Goldfinch, new BirdColor[] { BirdColor.Yellow, BirdColor.Black, BirdColor.White})]
	[DataRow(BirdType.Chickadee, new BirdColor[] { BirdColor.Black, BirdColor.White, BirdColor.Tan})]
	[DataTestMethod]
	public void GetColorsTest(BirdType birdType, BirdColor[] expected)
	{
		//arrange
		var bird = new Bird(birdType);

		//act
		var actual = bird.GetColors();

		//assert
		CollectionAssert.AreEquivalent(expected, actual.ToArray());
	}
	[DataRow(BirdType.Cardinal, new BirdFood[] { BirdFood.Insects, BirdFood.Seeds, BirdFood.Fruit })]
	[DataRow(BirdType.Goldfinch, new BirdFood[] { BirdFood.Seeds, BirdFood.Insects })]
	[DataRow(BirdType.Chickadee, new BirdFood[] { BirdFood.Insects, BirdFood.Fruit, BirdFood.Seeds})]
	[DataTestMethod]
	public void GetFoodsTest(BirdType birdType, BirdFood[] expected)
	{
		//arrange
		var bird = new Bird(birdType);

		//act
		var actual = bird.GetFoods();

		//assert
		CollectionAssert.AreEquivalent(expected, actual.ToArray());
	}
	[DataRow(BirdType.Cardinal, 8.0, 9.0)]
	[DataRow(BirdType.Goldfinch, 4.5, 5.5)]
	[DataRow(BirdType.Chickadee, 4.75, 5.75)]
	[DataTestMethod]
	public void GetSizeRangeTest(BirdType birdType, double expectedSizeRangeLower, double expectedSizeRangeUpper)
	{
		//arrange
		var bird = new Bird(birdType);

		//act
		var actual = bird.GetSizeRange();

		//assert
		Assert.AreEqual(expectedSizeRangeLower, actual.Lower);
		Assert.AreEqual(expectedSizeRangeUpper, actual.Upper);
	}
}

Before we begin

Refactoring rule #1: Always make sure you have tests covering the code you’re about to refactor. Run the tests after each small step.

Add Factory Method for creating Bird subclass objects from BirdType

1 – Replace constructor with Factory Method

  • Delete the Bird(BirdType) constructor.
  • Add a static factory method, Create(BirdType).
  • Remove the readonly modifier from the private birdType field so the factory method can set it. Note: We need to keep this field around until the end because all of the methods use it, and we’ll be refactoring the methods one at a time.
private BirdType birdType;

public static Bird Create(BirdType birdType)
{
	return new Bird()
	{
		birdType = birdType
	};
}
  • This breaks the unit tests, because they all call the Bird(BirdType) constructor, which no longer exists. Update them to use the Bird.Create() factory method instead (for example, var bird = Bird.Create(birdType);).

2 – Create a Bird subclass for each type of bird specified by BirdType

public class Cardinal : Bird
{
}

public class Chickadee : Bird
{
}

public class Goldfinch : Bird
{
}

Now we have three subclasses inheriting from the Bird class. The class diagram looks like this:

[Figure: class diagram showing the Bird class with subclasses Cardinal, Goldfinch, and Chickadee]

3 – Update Factory Method to generate Bird subclass objects

  • Add a switch statement that creates the appropriate Bird subclass based on the birdType.
  • Add a default case that throws an exception. Note: you always need a default case, because C# allows any integer value to be cast to an enum type, so an invalid value can be passed in. For example, Bird.Create((BirdType)4) compiles without error, and the default case handles it.

Wait, didn’t we just say the switch(birdType) structure is a code smell, and now we’re adding it here? Yes, and this is the only exception to the rule. The only time this is not a code smell is when you’re using it for object creation. After all, something has to create the subclass objects.

public static Bird Create(BirdType birdType)
{
	Bird bird;
	switch (birdType)
	{
		case BirdType.Cardinal:
			bird = new Cardinal();
			break;
		case BirdType.Chickadee:
			bird = new Chickadee();
			break;
		case BirdType.Goldfinch:
			bird = new Goldfinch();
			break;
		default:
			throw new InvalidBirdTypeException();
	}

	bird.birdType = birdType;
	return bird;
}
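The default-case note can be checked with a reduced, compilable sketch of the step 3 factory. The supporting types are assumed, and the Bird class is trimmed to the factory method only.

```csharp
using System;

// Assumed supporting types (the article doesn't show them).
public enum BirdType { Cardinal, Goldfinch, Chickadee }
public class InvalidBirdTypeException : Exception { }

// Bird as of step 3, reduced to the factory method.
public class Bird
{
    private BirdType birdType;

    public static Bird Create(BirdType birdType)
    {
        Bird bird;
        switch (birdType)
        {
            case BirdType.Cardinal:
                bird = new Cardinal();
                break;
            case BirdType.Chickadee:
                bird = new Chickadee();
                break;
            case BirdType.Goldfinch:
                bird = new Goldfinch();
                break;
            default:
                // Reached by values with no named member, e.g. (BirdType)4.
                throw new InvalidBirdTypeException();
        }

        bird.birdType = birdType;
        return bird;
    }
}

public class Cardinal : Bird { }
public class Chickadee : Bird { }
public class Goldfinch : Bird { }
```

Bird.Create(BirdType.Chickadee) returns a Chickadee, while Bird.Create((BirdType)4) compiles and throws InvalidBirdTypeException from the default case.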

Push Down Method – GetColors()

We want to replace the switch statements with polymorphism. To do that, we’ll need to push down the methods in the Bird class to the subclasses. We’ll make these methods abstract in the Bird class. This will require all subclasses to override and implement the methods.

If you recall, one problem caused by the Switch Statement code smell is that it exposes us to potential runtime exceptions. When we want to add a new type, it’s easy to forget to update all of the switch statements with the new case, which results in a runtime exception.

With abstract methods, it’s impossible to make this mistake, because the compiler forces every subclass to override and implement them.

[Figure: Bird class diagram illustrating the Push Down Method refactoring technique]

1 – Add the abstract keyword to GetColors()

public abstract List<BirdColor> GetColors()

2 – Add the abstract keyword to the Bird class

Because we changed GetColors() to abstract, we’ll get this error: ‘Bird.GetColors()’ is abstract but it is contained in non-abstract class ‘Bird’.

Only abstract classes can have abstract methods, so we need to make the Bird class abstract.

 public abstract class Bird

Note: I prefer to use this ‘lean on the compiler’ technique when refactoring. You make a small change that causes compiler errors, then fix each error the compiler reports.

3 – Comment out the body of GetColors()

Because we changed GetColors() to abstract, we’ll get the error: ‘Bird.GetColors()’ cannot declare a body because it is marked abstract.

Abstract methods cannot have code in them. However, since we want to move the code from this method into the subclasses, we’ll simply comment it out for now instead of deleting it. We’ll also need to add a semicolon to the end of the method declaration.

public abstract List<BirdColor> GetColors();
/*
{
	switch (birdType)
	{
		case BirdType.Cardinal:
			return new List<BirdColor>() { BirdColor.Black, BirdColor.Red };
		case BirdType.Goldfinch:
			return new List<BirdColor>() { BirdColor.Black, BirdColor.Yellow, BirdColor.White };
		case BirdType.Chickadee:
			return new List<BirdColor>() { BirdColor.Black, BirdColor.White, BirdColor.Tan };
	}
	throw new InvalidBirdTypeException();
}*/

4 – Implement GetColors() in the Cardinal subclass

Because Bird.GetColors() is now abstract, we’ll get an error for each subclass, such as: ‘Cardinal’ does not implement inherited abstract member ‘Bird.GetColors()’.

Let’s start with implementing this method in the Cardinal subclass.

  • In the Cardinal subclass, override the GetColors() method.
  • From the commented-out Bird.GetColors() method, copy the code for BirdType.Cardinal into Cardinal.GetColors().
public class Cardinal : Bird
{
	public override List<BirdColor> GetColors()
	{
		return new List<BirdColor>() { BirdColor.Black, BirdColor.Red };
	}
}

5 – Implement GetColors() in the remaining subclasses

Just like in the previous step, override the GetColors() method in the remaining subclasses (Chickadee and Goldfinch) and copy over the relevant code from the commented-out method.

public class Chickadee : Bird
{
	public override List<BirdColor> GetColors()
	{
		return new List<BirdColor>() { BirdColor.Black, BirdColor.White, BirdColor.Tan };
	}
}

public class Goldfinch : Bird
{
	public override List<BirdColor> GetColors()
	{
		return new List<BirdColor>() { BirdColor.Black, BirdColor.Yellow, BirdColor.White };
	}
}

6 – Delete the commented-out code in the Bird class

We’ve moved the logic from the commented-out code into the subclasses, so it no longer serves any purpose and we can delete it.

public abstract List<BirdColor> GetColors();
/* Delete this
{
	switch (birdType)
	{
		case BirdType.Cardinal:
			return new List<BirdColor>() { BirdColor.Black, BirdColor.Red };
		case BirdType.Goldfinch:
			return new List<BirdColor>() { BirdColor.Black, BirdColor.Yellow, BirdColor.White };
		case BirdType.Chickadee:
			return new List<BirdColor>() { BirdColor.Black, BirdColor.White, BirdColor.Tan };
	}
	throw new InvalidBirdTypeException();
}
*/

7 – Run the unit tests

Verify the GetColors() unit tests are all passing.

[Figure: Test Detail Summary showing the GetColorsTest unit tests passing after the refactoring step]
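Refactoring one method at a time means the code passes through mixed states that must still compile and pass the tests. The sketch below (assumed supporting types, GetSizeRange() omitted for brevity) shows roughly what Bird looks like once GetColors() has been pushed down but GetFoods() hasn't. The birdType field is still needed at this point, because GetFoods() still switches on it.

```csharp
using System;
using System.Collections.Generic;

// Assumed supporting types (not shown in the article).
public enum BirdType { Cardinal, Goldfinch, Chickadee }
public enum BirdColor { Black, Red, Yellow, White, Tan }
public enum BirdFood { Insects, Seeds, Fruit }
public class InvalidBirdTypeException : Exception { }

public abstract class Bird
{
    // Still required: GetFoods() has not been pushed down yet.
    private BirdType birdType;

    // Already pushed down: each subclass supplies its own colors.
    public abstract List<BirdColor> GetColors();

    // Not yet pushed down: still dispatches on the type code.
    public List<BirdFood> GetFoods()
    {
        switch (birdType)
        {
            case BirdType.Cardinal:
                return new List<BirdFood>() { BirdFood.Insects, BirdFood.Seeds, BirdFood.Fruit };
            case BirdType.Goldfinch:
                return new List<BirdFood>() { BirdFood.Insects, BirdFood.Seeds };
            case BirdType.Chickadee:
                return new List<BirdFood>() { BirdFood.Insects, BirdFood.Fruit, BirdFood.Seeds };
        }
        throw new InvalidBirdTypeException();
    }

    public static Bird Create(BirdType birdType)
    {
        Bird bird;
        switch (birdType)
        {
            case BirdType.Cardinal: bird = new Cardinal(); break;
            case BirdType.Chickadee: bird = new Chickadee(); break;
            case BirdType.Goldfinch: bird = new Goldfinch(); break;
            default: throw new InvalidBirdTypeException();
        }
        bird.birdType = birdType;
        return bird;
    }
}

public class Cardinal : Bird
{
    public override List<BirdColor> GetColors() =>
        new List<BirdColor>() { BirdColor.Black, BirdColor.Red };
}

public class Chickadee : Bird
{
    public override List<BirdColor> GetColors() =>
        new List<BirdColor>() { BirdColor.Black, BirdColor.White, BirdColor.Tan };
}

public class Goldfinch : Bird
{
    public override List<BirdColor> GetColors() =>
        new List<BirdColor>() { BirdColor.Black, BirdColor.Yellow, BirdColor.White };
}
```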

Push Down Methods – GetFoods() and GetSizeRange()

In the previous step, we pushed down the GetColors() method to the three subclasses. Now let’s push down the other two methods: GetFoods() and GetSizeRange().

  • Make the methods abstract.
  • Override the methods in the subclasses.
  • Copy the relevant logic for each type to the methods in the subclasses.
  • Run the unit tests.

Here’s the Cardinal subclass with all three of the methods implemented. Please see the end of the article for a full listing of all the refactored code.

public class Cardinal : Bird
{
	public override List<BirdColor> GetColors()
	{
		return new List<BirdColor>() { BirdColor.Black, BirdColor.Red };
	}

	public override List<BirdFood> GetFoods()
	{
		return new List<BirdFood>() { BirdFood.Insects, BirdFood.Seeds, BirdFood.Fruit };
	}

	public override BirdSizeRange GetSizeRange()
	{
		return new BirdSizeRange() { Lower = 8, Upper = 9 };
	}
}

Clean up the Factory Method

When we first created the Bird.Create() factory method we had to keep the birdType field, because it was being referenced in all of the methods. These are now abstract and no longer reference this field, so we can delete it.

  • In the Bird class delete the birdType field.
  • In the Create() factory method, remove the reference to Bird.birdType.
  • Remove the local bird variable.
  • In each case in the switch statement, return the new object immediately.
  • Remove the break statements.
public static Bird Create(BirdType birdType)
{
	switch (birdType)
	{
		case BirdType.Cardinal:
			return new Cardinal();
		case BirdType.Chickadee:
			return new Chickadee();
		case BirdType.Goldfinch:
			return new Goldfinch();
		default:
			throw new InvalidBirdTypeException();
	}
}
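As an alternative that is not part of the original refactoring, C# 8.0 and later let you write the same factory as a switch expression, where the discard arm (_) takes the place of the default case. Here's a sketch with assumed supporting types and empty subclasses.

```csharp
using System;

// Assumed supporting types, reduced to what the factory needs.
public enum BirdType { Cardinal, Goldfinch, Chickadee }
public class InvalidBirdTypeException : Exception { }

public abstract class Bird
{
    // Each arm maps a BirdType to a new subclass object; the discard arm
    // handles any value without a named member, like the default case.
    public static Bird Create(BirdType birdType) => birdType switch
    {
        BirdType.Cardinal => new Cardinal(),
        BirdType.Chickadee => new Chickadee(),
        BirdType.Goldfinch => new Goldfinch(),
        _ => throw new InvalidBirdTypeException(),
    };
}

public class Cardinal : Bird { }
public class Chickadee : Bird { }
public class Goldfinch : Bird { }
```

Either form works. The switch expression is more compact, while the switch statement matches the step-by-step version above and works on older C# versions.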

Refactored code

We dealt with the Switch Statement code smell by applying the Replace Type Code with Subclasses refactoring. We added Bird subclasses for each bird type and used a factory method to create them. We got rid of the switch statements by pushing down the methods into the subclasses and relying on polymorphism to execute the correct behavior.

Our refactored class diagram now looks like this:

[Figure: Bird class and subclasses after applying the Replace Type Code with Subclasses refactoring]

In the end we are left with code that is easy to extend. If we wanted to add another bird type, we’d simply need to update the factory method and implement all the functionality specific to that bird in a new subclass.

Bird class

using System.Collections.Generic;

public abstract class Bird
{
	public abstract List<BirdColor> GetColors();
	public abstract List<BirdFood> GetFoods();
	public abstract BirdSizeRange GetSizeRange();

	public static Bird Create(BirdType birdType)
	{
		switch (birdType)
		{
			case BirdType.Cardinal:
				return new Cardinal();
			case BirdType.Chickadee:
				return new Chickadee();
			case BirdType.Goldfinch:
				return new Goldfinch();
			default:
				throw new InvalidBirdTypeException();
		}
	}
}

Cardinal class

public class Cardinal : Bird
{
	public override List<BirdColor> GetColors()
	{
		return new List<BirdColor>() { BirdColor.Black, BirdColor.Red };
	}

	public override List<BirdFood> GetFoods()
	{
		return new List<BirdFood>() { BirdFood.Insects, BirdFood.Seeds, BirdFood.Fruit };
	}

	public override BirdSizeRange GetSizeRange()
	{
		return new BirdSizeRange() { Lower = 8, Upper = 9 };
	}
}

Chickadee class

public class Chickadee : Bird
{
	public override List<BirdColor> GetColors()
	{
		return new List<BirdColor>() { BirdColor.Black, BirdColor.White, BirdColor.Tan };
	}

	public override List<BirdFood> GetFoods()
	{
		return new List<BirdFood>() { BirdFood.Insects, BirdFood.Fruit, BirdFood.Seeds };
	}

	public override BirdSizeRange GetSizeRange()
	{
		return new BirdSizeRange() { Lower = 4.75, Upper = 5.75 };
	}
}

Goldfinch class

public class Goldfinch : Bird
{
	public override List<BirdColor> GetColors()
	{
		return new List<BirdColor>() { BirdColor.Black, BirdColor.Yellow, BirdColor.White };
	}

	public override List<BirdFood> GetFoods()
	{
		return new List<BirdFood>() { BirdFood.Insects, BirdFood.Seeds };
	}

	public override BirdSizeRange GetSizeRange()
	{
		return new BirdSizeRange() { Lower = 4.5, Upper = 5.5 };
	}
}
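To illustrate the extensibility claim from the summary, here's a sketch of adding the Hummingbird type from the introduction. The only change to existing code is one new case in Create(); everything specific to hummingbirds lives in the new subclass. The Green and Nectar enum members and the Hummingbird values are hypothetical placeholders, not data from the article, and the existing subclasses are condensed to expression-bodied members.

```csharp
using System;
using System.Collections.Generic;

// Assumed supporting types. Hummingbird, Green, and Nectar are hypothetical
// additions for this example; the article's enums don't include them.
public enum BirdType { Cardinal, Goldfinch, Chickadee, Hummingbird }
public enum BirdColor { Black, Red, Yellow, White, Tan, Green }
public enum BirdFood { Insects, Seeds, Fruit, Nectar }
public class InvalidBirdTypeException : Exception { }

public class BirdSizeRange
{
    public double Lower { get; set; }
    public double Upper { get; set; }
}

public abstract class Bird
{
    public abstract List<BirdColor> GetColors();
    public abstract List<BirdFood> GetFoods();
    public abstract BirdSizeRange GetSizeRange();

    public static Bird Create(BirdType birdType)
    {
        switch (birdType)
        {
            case BirdType.Cardinal:
                return new Cardinal();
            case BirdType.Chickadee:
                return new Chickadee();
            case BirdType.Goldfinch:
                return new Goldfinch();
            case BirdType.Hummingbird: // the only edit to existing code
                return new Hummingbird();
            default:
                throw new InvalidBirdTypeException();
        }
    }
}

// Existing subclasses, unchanged apart from being condensed.
public class Cardinal : Bird
{
    public override List<BirdColor> GetColors() => new List<BirdColor>() { BirdColor.Black, BirdColor.Red };
    public override List<BirdFood> GetFoods() => new List<BirdFood>() { BirdFood.Insects, BirdFood.Seeds, BirdFood.Fruit };
    public override BirdSizeRange GetSizeRange() => new BirdSizeRange() { Lower = 8, Upper = 9 };
}

public class Chickadee : Bird
{
    public override List<BirdColor> GetColors() => new List<BirdColor>() { BirdColor.Black, BirdColor.White, BirdColor.Tan };
    public override List<BirdFood> GetFoods() => new List<BirdFood>() { BirdFood.Insects, BirdFood.Fruit, BirdFood.Seeds };
    public override BirdSizeRange GetSizeRange() => new BirdSizeRange() { Lower = 4.75, Upper = 5.75 };
}

public class Goldfinch : Bird
{
    public override List<BirdColor> GetColors() => new List<BirdColor>() { BirdColor.Black, BirdColor.Yellow, BirdColor.White };
    public override List<BirdFood> GetFoods() => new List<BirdFood>() { BirdFood.Insects, BirdFood.Seeds };
    public override BirdSizeRange GetSizeRange() => new BirdSizeRange() { Lower = 4.5, Upper = 5.5 };
}

// New subclass: all Hummingbird-specific data lives here.
// The values below are placeholders, not field-guide data.
public class Hummingbird : Bird
{
    public override List<BirdColor> GetColors() => new List<BirdColor>() { BirdColor.Green, BirdColor.White };
    public override List<BirdFood> GetFoods() => new List<BirdFood>() { BirdFood.Nectar, BirdFood.Insects };
    public override BirdSizeRange GetSizeRange() => new BirdSizeRange() { Lower = 3, Upper = 3.5 };
}
```

If Hummingbird omitted one of the overrides, the build would fail with an error like ‘Hummingbird’ does not implement inherited abstract member ‘Bird.GetSizeRange()’, instead of failing at runtime.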
