Friday, November 25, 2011

C# Custom Enumerators Made Simple with the Yield Keyword

An enumerator enables you to iterate over a collection in a foreach loop.  You can use foreach to iterate over all the standard .NET collection classes because they all implement the IEnumerable interface (regular or generic).  IEnumerable contains the GetEnumerator method, which returns an enumerator.  (Strictly speaking, foreach only requires a public GetEnumerator method, but implementing IEnumerable is the standard way to provide one.)
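To make the role of the enumerator concrete, here is a minimal sketch of roughly what a foreach loop does under the hood: call GetEnumerator, then call MoveNext and read Current until MoveNext returns false, and dispose the enumerator at the end.  The class and method names (ForeachExpansion, WalkManually) are hypothetical, and the compiler's real expansion differs in details (for example, foreach over an array is compiled to an index-based loop).

```csharp
using System;
using System.Collections.Generic;

class ForeachExpansion
{
    // Hypothetical helper: drives the enumerator by hand, which is
    // roughly what the compiler generates for a foreach loop.
    public static List<string> WalkManually(IEnumerable<string> source)
    {
        var seen = new List<string>();
        using (IEnumerator<string> e = source.GetEnumerator())
        {
            while (e.MoveNext())            // advance; false when no items remain
            {
                string item = e.Current;    // the value foreach assigns to its loop variable
                seen.Add(item);
            }
        }                                   // Dispose() runs even if the loop body throws
        return seen;
    }

    static void Main()
    {
        var names = new List<string> { "Base1", "Derived1" };

        // The two loops print the same items.
        foreach (string name in names)
            Console.WriteLine(name);

        foreach (string name in WalkManually(names))
            Console.WriteLine(name);
    }
}
```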
Occasionally you may need to create a custom enumerator, which was somewhat of a challenge until the yield keyword was introduced in C# 2.0.  Here is how Microsoft describes yield:
The yield keyword signals to the compiler that the method in which it appears is an iterator block.  The compiler generates a class to implement the behavior that is expressed in the iterator block.  In the iterator block, the yield keyword is used together with the return keyword to provide a value to the enumerator object.  This is the value that is returned, for example, in each loop of a foreach statement.
So rather than creating your own enumerator class and managing the enumeration state yourself, which is time-consuming and tricky, you can simply write the enumeration logic in the GetEnumerator method, and the compiler will automagically turn your code into a handy-dandy enumerator.
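For comparison, here is a minimal sketch of the same sequence (the integers 1 through a limit) written both ways.  The hand-written CountEnumerator has to track its position explicitly across MoveNext calls, and you would still need a separate IEnumerable class to hand it out to foreach; the yield version is a single method.  All names here (CountEnumerator, Counting, CountTo) are hypothetical, chosen only for illustration.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;

// Hand-written enumerator: the state must be tracked explicitly
// between calls to MoveNext.
class CountEnumerator : IEnumerator<int>
{
    private readonly int m_Limit;
    private int m_Current;   // 0 means "before the first element"

    public CountEnumerator(int limit) { m_Limit = limit; }

    public int Current => m_Current;
    object IEnumerator.Current => Current;

    public bool MoveNext()
    {
        if (m_Current >= m_Limit)
            return false;    // already at (or past) the last value
        m_Current++;
        return true;
    }

    public void Reset() { m_Current = 0; }
    public void Dispose() { }   // nothing to release
}

static class Counting
{
    // The same sequence with yield: the compiler generates the state machine.
    public static IEnumerable<int> CountTo(int limit)
    {
        for (int i = 1; i <= limit; i++)
            yield return i;
    }

    static void Main()
    {
        var e = new CountEnumerator(3);
        while (e.MoveNext())
            Console.WriteLine(e.Current);   // 1, 2, 3

        foreach (int i in CountTo(3))
            Console.WriteLine(i);           // 1, 2, 3
    }
}
```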


Custom Enumerator Example: Wrap a Collection

To demonstrate the power of yield, let’s create a simple custom enumerator.  In this example problem, we have a collection of Base objects, but we only want to work with Derived objects.
So imagine two classes, where the “Derived” class inherits from the “Base” class:
public class Base
{
    public Base( string name )
    {
        this.Name = name;
    }
    public string Name;
}
public class Derived : Base
{
    public Derived( string name )
        : base( name ) { }
}
We also define a collection of Base objects:
public class BaseColl : List<Base> { }



But for this sample problem, we want to work only with Derived objects in the Base collection.  To make it easy for developers to use, we’ll create a Derived collection that wraps a Base collection.  Notice how the DerivedColl constructor takes a reference to the BaseColl that it wraps:
public class DerivedColl : IEnumerable<Derived>
{
    public DerivedColl( BaseColl baseColl )
    {
        this.m_BaseColl = baseColl;
    }
    private BaseColl m_BaseColl;
}

Yield Keyword

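Missing from the DerivedColl code above is the GetEnumerator method where the “yield” magic occurs.  (IEnumerable<Derived> also requires a non-generic GetEnumerator, which appears in the full sample program below.)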
public IEnumerator<Derived> GetEnumerator()
{
    foreach (Base b in this.m_BaseColl)
    {
        Derived d = b as Derived;
        if (d != null)
            yield return d;
    }
}
The foreach loop above iterates over the Base objects in the wrapped BaseColl collection.  The as operator returns null when a Base object is not a Derived object, so when d != null, yield return hands d to the caller.  The effect is that the enumerator returned by this GetEnumerator method iterates over ONLY the Derived objects in the Base collection.
Quite handy!  If you’ve ever spent time writing a custom enumerator class, you will welcome the yield keyword.
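One property of iterator blocks worth knowing is that they run lazily: calling the method executes none of its body, and each MoveNext runs the body only up to the next yield return.  For DerivedColl this means no Derived objects are copied or cached; each foreach filters the BaseColl as it stands at that moment.  The sketch below makes the timing visible with a hypothetical Numbers iterator that writes to a log list.

```csharp
using System;
using System.Collections.Generic;

static class DeferredDemo
{
    // Hypothetical iterator that records each step in "log" so we can
    // observe when its body actually runs.
    public static IEnumerable<int> Numbers(List<string> log)
    {
        log.Add("start");
        yield return 1;
        log.Add("after 1");
        yield return 2;
        log.Add("end");
    }

    static void Main()
    {
        var log = new List<string>();
        IEnumerable<int> seq = Numbers(log);
        Console.WriteLine(log.Count);               // 0: no part of the body has run yet

        foreach (int n in seq)
            Console.WriteLine("got " + n);          // got 1, got 2

        Console.WriteLine(string.Join(", ", log));  // start, after 1, end
    }
}
```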

Sample Program

Here is a sample console program:
using System;
using System.Collections;
using System.Collections.Generic;
 
namespace CSharp411
{
    class Program
    {
        public class Base
        {
            public Base( string name )
            {
                this.Name = name;
            }
            public string Name;
        }
        public class Derived : Base
        {
            public Derived( string name )
                : base( name )
            {
            }
        }
 
        public class BaseColl : List<Base> { }
        public class DerivedColl : IEnumerable<Derived>
        {
            public DerivedColl( BaseColl baseColl )
            {
                this.m_BaseColl = baseColl;
            }
            private BaseColl m_BaseColl;
            public IEnumerator<Derived> GetEnumerator()
            {
                foreach (Base b in this.m_BaseColl)
                {
                    Derived d = b as Derived;
                    if (d != null)
                        yield return d;
                }
            }
            System.Collections.IEnumerator IEnumerable.GetEnumerator()
            {
                return this.GetEnumerator();
            }
        }
 
        static void Main( string[] args )
        {
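            BaseColl baseColl = new BaseColl();
            // The wrapper holds a reference to baseColl, so objects added
            // after this point are still seen when derivedColl is enumerated.
            DerivedColl derivedColl = new DerivedColl( baseColl );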
 
            Base b = new Base( "Base1" ); baseColl.Add( b );
            b = new Base( "Base2" ); baseColl.Add( b );
            Derived d = new Derived( "Derived1" ); baseColl.Add( d );
            d = new Derived( "Derived2" ); baseColl.Add( d );
            b = new Base( "Base3" ); baseColl.Add( b );
            b = new Base( "Base4" ); baseColl.Add( b );
            d = new Derived( "Derived3" ); baseColl.Add( d );
            d = new Derived( "Derived4" ); baseColl.Add( d );
 
            foreach (Derived derived in derivedColl)
            {
                Console.WriteLine( derived.Name );
            }
            Console.ReadLine();
        }
    }
}

Sample Output

The output is only the Derived objects:
Derived1
Derived2
Derived3
Derived4
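The filtering pattern in DerivedColl generalizes to any type.  Below is a minimal sketch of a generic iterator, OnlyOfType<T>, a hypothetical name used only for illustration.  If you are on .NET Framework 3.5 or later, the built-in Enumerable.OfType<TResult> in System.Linq provides equivalent filtering, so you may not need to write this yourself.  The argument check lives in a separate, non-iterator method on purpose: because iterator bodies run lazily, a null check inside the iterator itself would not fire until the first MoveNext.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;

static class Filters
{
    // Hypothetical helper: yields only the items of the source that are of type T.
    // Validates eagerly, then hands off to the lazy iterator below.
    public static IEnumerable<T> OnlyOfType<T>(IEnumerable source) where T : class
    {
        if (source == null)
            throw new ArgumentNullException(nameof(source));
        return OnlyOfTypeIterator<T>(source);
    }

    // The "class" constraint is needed because "as" only works with reference
    // (or nullable) types; null items are skipped, as in the DerivedColl example.
    private static IEnumerable<T> OnlyOfTypeIterator<T>(IEnumerable source) where T : class
    {
        foreach (object item in source)
        {
            T t = item as T;
            if (t != null)
                yield return t;
        }
    }

    static void Main()
    {
        var mixed = new List<object> { "a", 1, "b", null, 2.5 };
        foreach (string s in OnlyOfType<string>(mixed))
            Console.WriteLine(s);   // a, b
    }
}
```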
