Sample Asynchronous SslStream Client/Server Implementation

I was recently asked for sample code that uses System.Net.Security.SslStream through its asynchronous APIs. I searched the web and couldn't find anything significant, so I decided to write some, including asynchronous usage of the TcpListener and TcpClient objects (both in the System.Net.Sockets namespace). In this post, I will show both the client and the server side of the picture.

The person requesting this sample code seemed pretty anxious to get it, so I pumped this out pretty quickly, and as such, there are probably some bugs in it. Please let me know of any you find and I will fix them. My goal here is to introduce the concepts, not to provide production-ready code. There are some things that I plan on changing about this code, but this should be a good starting place for anyone unfamiliar with the usage. Watch for the “To Do Items” sections, where I call out some of the things that I would like to add at some point but that you may need to add to your own code in the meantime.

To start, here is some helper code that is shared between client and server code.

Common Code:

The SecureConnectionResultsCallback delegate notifies you either that a secure connection is ready for use or that an asynchronous error occurred. The SecureConnectionResults class is merely a holder for the results. You will need to provide an implementation of this delegate.

When your delegate function is called, check for an asynchronous exception by looking at the SecureConnectionResults.AsyncException property. If it is null, the SslStream object should be ready for use. Also note that the SecureConnectionResultsCallback is invoked on a thread pool thread, which is a limited resource, so you won’t want to hold onto that thread for long.
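One way to avoid tying up the thread pool thread is to start an asynchronous read from inside the callback and return right away. Here is a minimal sketch of that idea. It is not part of the classes in this post: NonBlockingReadSketch and BeginReadOnce are hypothetical names, and the method accepts any Stream, so the SslStream from SecureConnectionResults.SecureStream could be passed in.

```csharp
using System;
using System.IO;
using System.Text;

// Hypothetical helper (not part of the sample classes in this post).
public static class NonBlockingReadSketch
{
    // Starts a single asynchronous read and returns immediately, so the
    // calling thread pool thread is released while the data is in flight.
    // onCompleted receives the decoded text, or the exception on failure.
    public static void BeginReadOnce(Stream stream, Action<string, Exception> onCompleted)
    {
        byte[] buffer = new byte[4096];

        stream.BeginRead(buffer, 0, buffer.Length, delegate(IAsyncResult ar)
        {
            string text = null;
            Exception error = null;
            try
            {
                // A count of 0 means the remote side closed the connection.
                int count = stream.EndRead(ar);
                text = Encoding.UTF8.GetString(buffer, 0, count);
            }
            catch (Exception ex)
            {
                error = ex;
            }

            // Report the outcome exactly once, outside the try block.
            onCompleted(text, error);
        }, null);
    }
}
```

Inside your callback you would call something like NonBlockingReadSketch.BeginReadOnce(args.SecureStream, handler) and then return. A real protocol would probably keep issuing reads until it has a complete message.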

To Do Items:

· May also want to expose the underlying socket through the SecureConnectionResults object. This would allow you to change socket options or switch back to an unencrypted network connection if you so desire.
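As a rough idea of what that to-do item could look like, here is a sketch of a results class that also carries the socket. The class name SecureConnectionResultsWithSocket, the extra constructor parameter and the UnderlyingSocket property are all assumptions of mine for illustration; the actual class used in this post follows below.

```csharp
using System;
using System.Net.Security;
using System.Net.Sockets;

// Hypothetical variant of the SecureConnectionResults class;
// the Socket parameter and UnderlyingSocket property are assumptions.
public class SecureConnectionResultsWithSocket
{
    private readonly SslStream secureStream;
    private readonly Socket underlyingSocket;
    private readonly Exception asyncException;

    public SecureConnectionResultsWithSocket(SslStream sslStream, Socket socket)
    {
        this.secureStream = sslStream;
        this.underlyingSocket = socket;
    }

    public SecureConnectionResultsWithSocket(Exception exception)
    {
        this.asyncException = exception;
    }

    public Exception AsyncException { get { return asyncException; } }

    public SslStream SecureStream { get { return secureStream; } }

    // Null when the connection attempt failed.
    public Socket UnderlyingSocket { get { return underlyingSocket; } }
}
```

The socket itself is available from TcpClient.Client, so a callback could then do something like results.UnderlyingSocket.NoDelay = true. Switching back to an unencrypted connection would also require creating the SslStream with leaveStreamOpen set to true, so that closing the SslStream doesn't close the socket.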

using System;

using System.Net.Security;

public delegate void SecureConnectionResultsCallback(object sender, SecureConnectionResults args);

public class SecureConnectionResults

{

    private SslStream secureStream;

    private Exception asyncException;

    internal SecureConnectionResults(SslStream sslStream)

    {

        this.secureStream = sslStream;

    }

    internal SecureConnectionResults(Exception exception)

    {

        this.asyncException = exception;

    }

    public Exception AsyncException{get{return asyncException;}}

    public SslStream SecureStream{get{return secureStream;}}

}

 

Server Code:

Here is my implementation of a TcpListener wrapper that accepts a connection and then immediately initiates the SSL handshake (by calling SslStream.BeginAuthenticateAsServer from within the OnAcceptConnection callback). This code handles both IPv6 and IPv4 for you, but make sure to read the “To Do Items”. If you aren't interested in reading all of the code, the bulk of the work is done inside the StartListening, OnAcceptConnection and OnAuthenticateAsServer functions.

To Do Items:

· I haven’t added any substantial error handling to this code.

o For example, you might be able to bind the IPv4 socket to the port, but an error could occur when trying to bind the IPv6 socket to that same port (e.g. it is taken by another application on IPv6 but not on IPv4).

· When on Windows Vista or later, use a single socket to listen for both IPv4 and IPv6 – this avoids the binding problem previously discussed.

· Another example would be that a fatal socket error could happen in OnAcceptConnection or OnAuthenticateAsServer that should result in the entire object being invalidated. I will leave it up to you to add these details for now, but will try to update the code when I have time.

· Thread safety. I am not doing any locking inside the object and so you could see funny results.

o For example, you might shut down the object while a callback is in progress.
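For the single-socket item above, here is a minimal sketch of a dual-mode listener. It assumes .NET Framework 4.5 or later (where the Socket.DualMode property exists) and an operating system with IPv6 support. DualModeListenerSketch and CreateDualMode are hypothetical names, not part of SecureTcpServer.

```csharp
using System;
using System.Net;
using System.Net.Sockets;

// Hypothetical helper; assumes .NET Framework 4.5 or later, where
// Socket.DualMode is available, and an OS with IPv6 support.
public static class DualModeListenerSketch
{
    // Creates (but does not start) a listener whose single IPv6 socket
    // also accepts IPv4 connections.
    public static TcpListener CreateDualMode(int port)
    {
        TcpListener listener = new TcpListener(IPAddress.IPv6Any, port);

        // Must be set before Start() binds the socket.
        listener.Server.DualMode = true;

        return listener;
    }
}
```

With something like this, the listenerV4/listenerV6 pair in the code below could collapse into one listener and one bind, which sidesteps the case where only one of the two binds succeeds. IPv4 clients then show up with IPv4-mapped IPv6 addresses (::ffff:a.b.c.d).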

using System;

using System.Net;

using System.Net.Sockets;

using System.Net.Security;

using System.Security.Cryptography.X509Certificates;

using System.Security.Authentication;

public class SecureTcpServer : IDisposable

{

    X509Certificate serverCert;

    RemoteCertificateValidationCallback certValidationCallback;

    SecureConnectionResultsCallback connectionCallback;

    AsyncCallback onAcceptConnection;

    AsyncCallback onAuthenticateAsServer;

    bool started;

    int listenPort;

    TcpListener listenerV4;

    TcpListener listenerV6;

    int disposed;

    bool clientCertificateRequired;

    bool checkCertifcateRevocation;

    SslProtocols sslProtocols;

    public SecureTcpServer(int listenPort, X509Certificate serverCertificate,

        SecureConnectionResultsCallback callback)

        :this(listenPort,serverCertificate,callback,null)

    {

    }

    public SecureTcpServer(int listenPort, X509Certificate serverCertificate,

        SecureConnectionResultsCallback callback,

        RemoteCertificateValidationCallback certValidationCallback)

    {

        if (listenPort < 0 || listenPort > UInt16.MaxValue)

            throw new ArgumentOutOfRangeException("listenPort");

        if (serverCertificate == null)

            throw new ArgumentNullException("serverCertificate");

        if (callback == null)

            throw new ArgumentNullException("callback");

        onAcceptConnection = new AsyncCallback(OnAcceptConnection);

        onAuthenticateAsServer = new AsyncCallback(OnAuthenticateAsServer);

        this.serverCert = serverCertificate;

        this.certValidationCallback = certValidationCallback;

        this.connectionCallback = callback;

        this.listenPort = listenPort;

        this.disposed = 0;

        this.checkCertifcateRevocation = false;

        this.clientCertificateRequired = false;

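        //SslProtocols.Default allows only SSL 3.0 and TLS 1.0; use the SslProtocols property to choose other versions

        this.sslProtocols = SslProtocols.Default;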

    }

    ~SecureTcpServer()

    {

        Dispose();

    }

    public SslProtocols SslProtocols

    {

        get { return sslProtocols; }

        set { sslProtocols = value; }

    }

    public bool CheckCertifcateRevocation

    {

        get { return checkCertifcateRevocation; }

        set { checkCertifcateRevocation = value; }

    }

    public bool ClientCertificateRequired

    {

        get { return clientCertificateRequired; }

        set { clientCertificateRequired = value; }

    }

    public void StartListening()

    {

        if(started)

            throw new InvalidOperationException("Already started...");

        IPEndPoint localIP;

        if (Socket.OSSupportsIPv4 && listenerV4 == null)

        {

            localIP = new IPEndPoint(IPAddress.Any, listenPort);

            Console.WriteLine("SecureTcpServer: Started listening on {0}", localIP);

            listenerV4 = new TcpListener(localIP);

        }

        if(Socket.OSSupportsIPv6 && listenerV6 == null)

        {

            localIP = new IPEndPoint(IPAddress.IPv6Any, listenPort);

            Console.WriteLine("SecureTcpServer: Started listening on {0}", localIP);

            listenerV6 = new TcpListener(localIP);

        }

        if(listenerV4 != null)

        {

            listenerV4.Start();

            listenerV4.BeginAcceptTcpClient(onAcceptConnection,listenerV4);

        }

        if(listenerV6 != null)

        {

            listenerV6.Start();

            listenerV6.BeginAcceptTcpClient(onAcceptConnection,listenerV6);

        }

        started = true;

    }

    public void StopListening()

    {

        if (!started)

            return;

        started = false;

        if (listenerV4 != null)

            listenerV4.Stop();

        if (listenerV6 != null)

            listenerV6.Stop();

    }

    void OnAcceptConnection(IAsyncResult result)

    {

        TcpListener listener = result.AsyncState as TcpListener;

        TcpClient client = null;

        SslStream sslStream = null;

        try

        {

            if (this.started)

            {

                //start accepting the next connection...

                listener.BeginAcceptTcpClient(this.onAcceptConnection, listener);

            }

            else

            {

                //someone called Stop() - don't call EndAcceptTcpClient because

                //it will throw an ObjectDisposedException

                return;

            }

            //complete the last operation...

            client = listener.EndAcceptTcpClient(result);

            bool leaveStreamOpen = false;//close the socket when done

            if (this.certValidationCallback != null)

                sslStream = new SslStream(client.GetStream(), leaveStreamOpen, this.certValidationCallback);

            else

                sslStream = new SslStream(client.GetStream(), leaveStreamOpen);

            sslStream.BeginAuthenticateAsServer(this.serverCert,

                this.clientCertificateRequired,

                this.sslProtocols,

                this.checkCertifcateRevocation,//checkCertifcateRevocation

                this.onAuthenticateAsServer,

                sslStream);

        }

        catch (Exception ex)

        {

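            if (sslStream != null)

            {

                sslStream.Dispose();

                sslStream = null;

            }

            else if (client != null)

            {

                //the SslStream was never created, so close the accepted socket directly

                client.Close();

            }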

            this.connectionCallback(this, new SecureConnectionResults(ex));

        }

    }

    void OnAuthenticateAsServer(IAsyncResult result)

    {

        SslStream sslStream = null;

        try

        {

            sslStream = result.AsyncState as SslStream;

            sslStream.EndAuthenticateAsServer(result);

            this.connectionCallback(this, new SecureConnectionResults(sslStream));

        }

        catch (Exception ex)

        {

            if (sslStream != null)

            {

                sslStream.Dispose();

                sslStream = null;

            }

            this.connectionCallback(this, new SecureConnectionResults(ex));

        }

    }

    public void Dispose()

    {

        if (System.Threading.Interlocked.Increment(ref disposed) == 1)

        {

            if (this.listenerV4 != null)

                listenerV4.Stop();

            if (this.listenerV6 != null)

                listenerV6.Stop();

            listenerV4 = null;

            listenerV6 = null;

            GC.SuppressFinalize(this);

        }

    }

}

Client Code:

The client code is similar to the server code in that it hides the details of connecting and SSL negotiation from you. The remoteHostName parameter is used when validating the server’s certificate. If the name you pass in doesn’t match the name on the server’s certificate, you will get a certificate error (System.Net.Security.SslPolicyErrors.RemoteCertificateNameMismatch). Again, if you only want to look at the most important parts of the code, focus on the StartConnecting, OnConnected and OnAuthenticateAsClient functions.
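To make the name-mismatch case concrete, here is a sketch of a validation callback that reports which policy errors occurred and rejects the connection if there were any. The class name StrictCertificateValidationSketch and the reject-everything policy are assumptions of mine; your own rules may legitimately be more lenient for test environments.

```csharp
using System;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;

// Hypothetical callback; the policy it applies is only an assumption
// about what a stricter check might look like.
public static class StrictCertificateValidationSketch
{
    // Matches the RemoteCertificateValidationCallback delegate signature.
    public static bool Validate(object sender, X509Certificate certificate,
        X509Chain chain, SslPolicyErrors sslPolicyErrors)
    {
        if (sslPolicyErrors == SslPolicyErrors.None)
            return true;

        if ((sslPolicyErrors & SslPolicyErrors.RemoteCertificateNameMismatch) != 0)
            Console.WriteLine("The certificate name does not match the host name that was passed in.");

        if ((sslPolicyErrors & SslPolicyErrors.RemoteCertificateNotAvailable) != 0)
            Console.WriteLine("The remote side did not present a certificate.");

        if ((sslPolicyErrors & SslPolicyErrors.RemoteCertificateChainErrors) != 0)
            Console.WriteLine("The certificate chain could not be validated.");

        // Returning false makes the handshake fail with an AuthenticationException.
        return false;
    }
}
```

It can be passed wherever a RemoteCertificateValidationCallback is expected, for example new RemoteCertificateValidationCallback(StrictCertificateValidationSketch.Validate). With this in place, a wrong remoteHostName surfaces as an exception in SecureConnectionResults.AsyncException rather than being silently accepted.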

To Do Items:

· Change StartConnecting so that it doesn’t require a single IPEndPoint. It would be nice if the code did an asynchronous DNS lookup for you based on the host name and then automatically connected to any of the resolved addresses (regardless of IPv4 or IPv6).

· I haven’t added any substantial error handling to this code.

o An example would be that a fatal error could happen that should result in the entire object being invalidated. I will leave it up to you to add these details for now, but will try to update the code when I have time.

· Thread safety. I am not doing any locking inside the object and so you could see funny results.
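For the DNS item above, here is a sketch of how an asynchronous lookup followed by a connect attempt against each resolved address might look. ResolveAndConnectSketch and its methods are hypothetical names and are not part of SecureTcpClient; the sketch only covers the TCP connect, not the SSL handshake.

```csharp
using System;
using System.Net;
using System.Net.Sockets;

// Hypothetical helper (not part of SecureTcpClient as posted).
public static class ResolveAndConnectSketch
{
    // Resolves the host name asynchronously, then tries each address in turn.
    // onCompleted receives a connected TcpClient, or the last error seen.
    public static void BeginConnect(string host, int port,
        Action<TcpClient, Exception> onCompleted)
    {
        Dns.BeginGetHostAddresses(host, delegate(IAsyncResult ar)
        {
            IPAddress[] addresses;
            try
            {
                addresses = Dns.EndGetHostAddresses(ar);
            }
            catch (Exception ex)
            {
                onCompleted(null, ex);
                return;
            }
            TryNext(addresses, 0, port, null, onCompleted);
        }, null);
    }

    static void TryNext(IPAddress[] addresses, int index, int port,
        Exception lastError, Action<TcpClient, Exception> onCompleted)
    {
        if (index >= addresses.Length)
        {
            // Every address failed (or none were returned).
            onCompleted(null, lastError ??
                new SocketException((int)SocketError.HostNotFound));
            return;
        }

        TcpClient client = null;
        try
        {
            // Match the socket's address family to the address (IPv4 or IPv6).
            client = new TcpClient(addresses[index].AddressFamily);
            client.BeginConnect(addresses[index], port, delegate(IAsyncResult ar)
            {
                try
                {
                    client.EndConnect(ar);
                }
                catch (Exception ex)
                {
                    // This address failed; clean up and move on to the next one.
                    client.Close();
                    TryNext(addresses, index + 1, port, ex, onCompleted);
                    return;
                }
                onCompleted(client, null);
            }, null);
        }
        catch (Exception ex)
        {
            if (client != null)
                client.Close();
            TryNext(addresses, index + 1, port, ex, onCompleted);
        }
    }
}
```

StartConnecting could then take just a host name and port, and create the SslStream from the TcpClient that comes back, using the same host name for certificate validation.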

using System;

using System.Net;

using System.Net.Sockets;

using System.Net.Security;

using System.Security.Cryptography.X509Certificates;

using System.Security.Authentication;

public class SecureTcpClient : IDisposable

{

    X509CertificateCollection clientCertificates;

    RemoteCertificateValidationCallback certValidationCallback;

    SecureConnectionResultsCallback connectionCallback;

    bool checkCertificateRevocation;

    AsyncCallback onConnected;

    AsyncCallback onAuthenticateAsClient;

    TcpClient client;

    IPEndPoint remoteEndPoint;

    string remoteHostName;

    SslProtocols protocols;

    int disposed;

    public SecureTcpClient(SecureConnectionResultsCallback callback)

        : this(callback,null,SslProtocols.Default)

    {

    }

    public SecureTcpClient(SecureConnectionResultsCallback callback,

        RemoteCertificateValidationCallback certValidationCallback)

        : this(callback, certValidationCallback, SslProtocols.Default)

    {

    }

    public SecureTcpClient(SecureConnectionResultsCallback callback,

        RemoteCertificateValidationCallback certValidationCallback,

        SslProtocols sslProtocols)

    {

        if (callback == null)

            throw new ArgumentNullException("callback");

        onConnected = new AsyncCallback(OnConnected);

        onAuthenticateAsClient = new AsyncCallback(OnAuthenticateAsClient);

        this.certValidationCallback = certValidationCallback;

        this.connectionCallback = callback;

        protocols = sslProtocols;

        this.disposed = 0;

    }

    ~SecureTcpClient()

    {

        Dispose();

    }

    public bool CheckCertificateRevocation

    {

        get { return checkCertificateRevocation; }

        set {checkCertificateRevocation = value;}

    }

    public void StartConnecting(string remoteHostName, IPEndPoint remoteEndPoint)

    {

        StartConnecting(remoteHostName,remoteEndPoint,null);

    }

    public void StartConnecting(string remoteHostName, IPEndPoint remoteEndPoint,

        X509CertificateCollection clientCertificates)

    {

        if (string.IsNullOrEmpty(remoteHostName))

            throw new ArgumentException("Value cannot be null or empty","remoteHostName");

        if (remoteEndPoint == null)

            throw new ArgumentNullException("remoteEndPoint");

        Console.WriteLine("Client connecting to: {0}", remoteEndPoint);

       

        this.clientCertificates = clientCertificates;

        this.remoteHostName = remoteHostName;

        this.remoteEndPoint = remoteEndPoint;

        if (client != null)

            client.Close();

       

        client = new TcpClient(remoteEndPoint.AddressFamily);

       

        client.BeginConnect(remoteEndPoint.Address,

            remoteEndPoint.Port,

            this.onConnected,null);

    }

    public void Close()

    {

        Dispose();

    }

    void OnConnected(IAsyncResult result)

    {

        SslStream sslStream = null;

        try

        {

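            //complete the connect operation; this throws if the connection attempt failed

            client.EndConnect(result);

            bool leaveStreamOpen = false;//close the socket when done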

            if (this.certValidationCallback != null)

                sslStream = new SslStream(client.GetStream(), leaveStreamOpen, this.certValidationCallback);

            else

                sslStream = new SslStream(client.GetStream(), leaveStreamOpen);

         

            sslStream.BeginAuthenticateAsClient(this.remoteHostName,

                    this.clientCertificates,

                    this.protocols,

                    this.checkCertificateRevocation,

                    this.onAuthenticateAsClient,

                    sslStream);

        }

        catch (Exception ex)

        {

            if (sslStream != null)

            {

                sslStream.Dispose();

                sslStream = null;

            }

            this.connectionCallback(this,new SecureConnectionResults(ex));

        }

    }

    void OnAuthenticateAsClient(IAsyncResult result)

    {

        SslStream sslStream = null;

        try

        {

            sslStream = result.AsyncState as SslStream;

            sslStream.EndAuthenticateAsClient(result);

            this.connectionCallback(this,new SecureConnectionResults( sslStream));

        }

        catch (Exception ex)

        {

            if (sslStream != null)

            {

                sslStream.Dispose();

                sslStream = null;

            }

            this.connectionCallback(this, new SecureConnectionResults(ex));

        }

    }

    public void Dispose()

    {

        if (System.Threading.Interlocked.Increment(ref disposed) == 1)

        {

            if (client != null)

            {

                client.Close();

                client = null;

            }

            GC.SuppressFinalize(this);

        }

    }

}

 

Sample Program:

 

Here is a very simple program that demonstrates the use of these objects.

 

using System;

using System.IO;

using System.Net;

using System.Threading;

using System.Net.Sockets;

using System.Security.Cryptography;

using System.Security.Cryptography.X509Certificates;

using System.Net.Security;

class Program

{

    static void Main(string[] args)

    {

        SecureTcpServer server = null;

        SecureTcpClient client = null;

        try

        {

            int port = 8889;

            RemoteCertificateValidationCallback certValidationCallback = null;

            certValidationCallback = new RemoteCertificateValidationCallback(IgnoreCertificateErrorsCallback);

            string certPath = System.Reflection.Assembly.GetEntryAssembly().Location;

            certPath = Path.GetDirectoryName(certPath);

            certPath = Path.Combine(certPath, "serverCert.cer");

            Console.WriteLine("Loading Server Cert From: " + certPath);

            X509Certificate serverCert = X509Certificate.CreateFromCertFile(certPath);

            server = new SecureTcpServer(port, serverCert,

                new SecureConnectionResultsCallback(OnServerConnectionAvailable));

           

            server.StartListening();

            client = new SecureTcpClient(new SecureConnectionResultsCallback(OnClientConnectionAvailable),

                certValidationCallback);

            client.StartConnecting("localhost", new IPEndPoint(IPAddress.Loopback, port));

        }

        catch (Exception ex)

        {

            Console.WriteLine(ex);

        }

        //sleep to avoid printing this text until after the callbacks have been invoked.

        Thread.Sleep(4000);

        Console.WriteLine("Press any key to continue...");

        Console.ReadKey();

        if (server != null)

            server.Dispose();

        if (client != null)

            client.Dispose();

       

    }

    static void OnServerConnectionAvailable(object sender, SecureConnectionResults args)

    {

        if (args.AsyncException != null)

        {

            Console.WriteLine(args.AsyncException);

            return;

        }

        SslStream stream = args.SecureStream;

        Console.WriteLine("Server Connection secured: " + stream.IsAuthenticated);

 

       

       

        StreamWriter writer = new StreamWriter(stream);

        writer.AutoFlush = true;

        writer.WriteLine("Hello from server!");

        StreamReader reader = new StreamReader(stream);

        string line = reader.ReadLine();

        Console.WriteLine("Server Received: '{0}'", line == null ? "<NULL>" : line);

        writer.Close();

        reader.Close();

        stream.Close();

    }

    static void OnClientConnectionAvailable(object sender, SecureConnectionResults args)

    {

        if (args.AsyncException != null)

        {

            Console.WriteLine(args.AsyncException);

            return;

        }

        SslStream stream = args.SecureStream;

        Console.WriteLine("Client Connection secured: " + stream.IsAuthenticated);

        StreamWriter writer = new StreamWriter(stream);

        writer.AutoFlush = true;

        writer.WriteLine("Hello from client!");

       

        StreamReader reader = new StreamReader(stream);

        string line = reader.ReadLine();

        Console.WriteLine("Client Received: '{0}'", line == null ? "<NULL>" : line);

        writer.Close();

        reader.Close();

        stream.Close();

    }

    static bool IgnoreCertificateErrorsCallback(object sender,

        X509Certificate certificate,

        X509Chain chain,

        SslPolicyErrors sslPolicyErrors)

    {

        if (sslPolicyErrors != SslPolicyErrors.None)

        {

            Console.WriteLine("IgnoreCertificateErrorsCallback: {0}", sslPolicyErrors);

            //you should implement different logic here...

            if ((sslPolicyErrors & SslPolicyErrors.RemoteCertificateChainErrors) != 0)

            {

                foreach (X509ChainStatus chainStatus in chain.ChainStatus)

                {

                    Console.WriteLine("\t" + chainStatus.Status);

                }

            }

        }

        //returning true tells the SslStream object you don't care about any errors.

        return true;

    }

}

 

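Note that if you don’t have a certificate for testing this code, you can create one using the MakeCert.exe tool, which ships with the Visual Studio SDK. Because the sample program loads only the public serverCert.cer file, the matching private key typically has to be available in a certificate store on the machine (for example, by having MakeCert place the certificate in a store); otherwise the server side of the handshake will fail.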
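If you would rather create a throwaway test certificate in code, here is a sketch that uses the CertificateRequest class. It assumes .NET Framework 4.7.2 or later (or .NET Core 2.0 or later), which is newer than what the rest of this post requires, and TestCertificateSketch is a hypothetical name. I have seen reports that on Windows SslStream may not accept a certificate whose key exists only in memory, in which case exporting it to a PFX and loading it back is a common workaround; treat that as something to verify in your own environment.

```csharp
using System;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;

// Hypothetical helper; assumes .NET Framework 4.7.2+ or .NET Core 2.0+,
// where CertificateRequest is available. Intended for testing only.
public static class TestCertificateSketch
{
    // Returns a self-signed certificate (with private key) for the given host name.
    public static X509Certificate2 CreateSelfSigned(string hostName)
    {
        using (RSA key = RSA.Create(2048))
        {
            CertificateRequest request = new CertificateRequest(
                "CN=" + hostName, key,
                HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);

            // Valid from yesterday (to tolerate clock skew) for one year.
            DateTimeOffset now = DateTimeOffset.UtcNow;
            return request.CreateSelfSigned(now.AddDays(-1), now.AddYears(1));
        }
    }
}
```

Since X509Certificate2 derives from X509Certificate, the result of TestCertificateSketch.CreateSelfSigned("localhost") could be passed to the SecureTcpServer constructor in place of the certificate loaded from serverCert.cer.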