From c1115e26c1940c03489be52c8e1368f40b8d4f07 Mon Sep 17 00:00:00 2001 From: Oliver Booth Date: Thu, 19 May 2022 10:37:37 +0100 Subject: [PATCH] Add AES encryption Temporarily disable compression. GZip being weird --- TcpDotNet.ClientIntegrationTest/Program.cs | 23 +++- .../TcpDotNet.ClientIntegrationTest.csproj | 4 + TcpDotNet.ListenerIntegrationTest/Program.cs | 10 +- TcpDotNet/BaseClientNode.cs | 112 +++++++++++++----- TcpDotNet/ClientState.cs | 37 ++++++ TcpDotNet/CryptographyUtils.cs | 19 +++ TcpDotNet/DisconnectReason.cs | 7 +- TcpDotNet/EventData/DisconnectedEventArgs.cs | 22 ++++ TcpDotNet/Node.cs | 15 ++- TcpDotNet/Protocol/HandshakeResponse.cs | 10 ++ .../PacketHandlers/DisconnectPacketHandler.cs | 20 ++++ .../EncryptionResponsePacketHandler.cs | 41 +++++++ .../HandshakeRequestPacketHandler.cs | 41 +++++++ .../Packets/ClientBound/DisconnectPacket.cs | 37 ++++++ .../ClientBound/EncryptionRequestPacket.cs | 66 +++++++++++ .../ClientBound/HandshakeResponsePacket.cs | 53 +++++++++ .../ClientBound/SessionExchangePacket.cs | 39 ++++++ .../ServerBound/EncryptionResponsePacket.cs | 66 +++++++++++ .../ServerBound/HandshakeRequestPacket.cs | 2 +- TcpDotNet/ProtocolClient.cs | 69 ++++++++++- TcpDotNet/ProtocolListener.Client.cs | 6 + TcpDotNet/ProtocolListener.cs | 32 ++++- TcpDotNet/TcpDotNet.csproj | 5 + 23 files changed, 687 insertions(+), 49 deletions(-) create mode 100644 TcpDotNet/ClientState.cs create mode 100644 TcpDotNet/CryptographyUtils.cs create mode 100644 TcpDotNet/EventData/DisconnectedEventArgs.cs create mode 100644 TcpDotNet/Protocol/HandshakeResponse.cs create mode 100644 TcpDotNet/Protocol/PacketHandlers/DisconnectPacketHandler.cs create mode 100644 TcpDotNet/Protocol/PacketHandlers/EncryptionResponsePacketHandler.cs create mode 100644 TcpDotNet/Protocol/PacketHandlers/HandshakeRequestPacketHandler.cs create mode 100644 TcpDotNet/Protocol/Packets/ClientBound/DisconnectPacket.cs create mode 100644 TcpDotNet/Protocol/Packets/ClientBound/EncryptionRequestPacket.cs create mode 100644 TcpDotNet/Protocol/Packets/ClientBound/HandshakeResponsePacket.cs create mode 100644 TcpDotNet/Protocol/Packets/ClientBound/SessionExchangePacket.cs create mode 100644 TcpDotNet/Protocol/Packets/ServerBound/EncryptionResponsePacket.cs diff --git a/TcpDotNet.ClientIntegrationTest/Program.cs b/TcpDotNet.ClientIntegrationTest/Program.cs index 57231ee..550434f 100644 --- a/TcpDotNet.ClientIntegrationTest/Program.cs +++ b/TcpDotNet.ClientIntegrationTest/Program.cs @@ -5,14 +5,31 @@ using TcpDotNet.Protocol.Packets.ClientBound; using TcpDotNet.Protocol.Packets.ServerBound; using var client = new ProtocolClient(); +client.Disconnected += (_, e) => Console.WriteLine($"Disconnected: {e.DisconnectReason}"); + +new Thread(() => +{ + ClientState oldState = client.State; + while (true) + { + if (oldState != client.State) + { + Console.WriteLine($"State changed to {client.State}"); + oldState = client.State; + } + } +}).Start(); + client.RegisterPacketHandler(PacketHandler.Empty); await client.ConnectAsync(IPAddress.IPv6Loopback, 1234); -Console.WriteLine($"Connected to {client.RemoteEndPoint}"); -var ping = new PingPacket(); +Console.WriteLine($"Connected to {client.RemoteEndPoint}. My session is {client.SessionId}"); +var cancellationTokenSource = new CancellationTokenSource(); +cancellationTokenSource.CancelAfter(5000); // if no pong is received in 5 seconds, cancel +var ping = new PingPacket(); Console.WriteLine($"Sending ping packet with payload: {BitConverter.ToString(ping.Payload)}"); -var pong = await client.SendAndReceive(ping); +var pong = await client.SendAndReceive(ping, cancellationTokenSource.Token); Console.WriteLine($"Received pong packet with payload: {BitConverter.ToString(pong.Payload)}"); Console.WriteLine(pong.Payload.SequenceEqual(ping.Payload) ? "Payload matches!" : "Payload does not match!"); diff --git a/TcpDotNet.ClientIntegrationTest/TcpDotNet.ClientIntegrationTest.csproj b/TcpDotNet.ClientIntegrationTest/TcpDotNet.ClientIntegrationTest.csproj index fa69423..2e13429 100644 --- a/TcpDotNet.ClientIntegrationTest/TcpDotNet.ClientIntegrationTest.csproj +++ b/TcpDotNet.ClientIntegrationTest/TcpDotNet.ClientIntegrationTest.csproj @@ -11,4 +11,8 @@ + + + + diff --git a/TcpDotNet.ListenerIntegrationTest/Program.cs b/TcpDotNet.ListenerIntegrationTest/Program.cs index 82f83d7..2f5a1eb 100644 --- a/TcpDotNet.ListenerIntegrationTest/Program.cs +++ b/TcpDotNet.ListenerIntegrationTest/Program.cs @@ -4,14 +4,8 @@ using TcpDotNet.Protocol.Packets.ClientBound; using TcpDotNet.Protocol.Packets.ServerBound; var listener = new ProtocolListener(); -listener.Started += (_, _) => - Console.WriteLine($"Listener started on {listener.LocalEndPoint}"); - -listener.ClientConnected += (_, e) => - Console.WriteLine($"Client connected from {e.Client.RemoteEndPoint} with session {e.Client.SessionId}"); - -listener.ClientDisconnected += (_, e) => - Console.WriteLine($"Client {e.Client.SessionId} disconnected ({e.DisconnectReason})"); +listener.ClientConnected += (_, e) => Console.WriteLine($"Client connected from {e.Client.RemoteEndPoint} with session {e.Client.SessionId}"); +listener.ClientDisconnected += (_, e) => Console.WriteLine($"Client {e.Client.SessionId} disconnected ({e.DisconnectReason})"); listener.RegisterPacketHandler(new PingPacketHandler()); listener.Start(1234); diff --git a/TcpDotNet/BaseClientNode.cs b/TcpDotNet/BaseClientNode.cs index 4e856dd..8f61364 100644 --- a/TcpDotNet/BaseClientNode.cs +++ b/TcpDotNet/BaseClientNode.cs @@ -1,10 +1,11 @@ using System.Collections.Concurrent; -using System.IO.Compression; using System.Net; using System.Net.Sockets; using System.Reflection; -using System.Security.Cryptography; +using Chilkat; using TcpDotNet.Protocol; +using Stream = System.IO.Stream; +using Task = System.Threading.Tasks.Task; namespace TcpDotNet; @@ -15,6 +16,13 @@ public abstract class BaseClientNode : Node { private readonly ConcurrentDictionary>> _packetCompletionSources = new(); + /// + /// Initializes a new instance of the class. + /// + protected BaseClientNode() + { + } + /// /// Gets a value indicating whether the client is connected. /// @@ -35,6 +43,11 @@ public abstract class BaseClientNode : Node /// The session ID. public Guid SessionId { get; internal set; } + /// + /// Gets the current state of the client. + /// + public ClientState State { get; protected internal set; } + /// /// Gets or sets a value indicating whether GZip compression is enabled. /// @@ -48,10 +61,26 @@ public abstract class BaseClientNode : Node internal bool UseEncryption { get; set; } = false; /// - /// Gets the AES implementation used by this client. + /// Gets or sets the AES implementation used by this client. /// /// The AES implementation. - internal Aes Aes { get; } = Aes.Create(); + internal Crypt2 Aes { get; set; } = CryptographyUtils.GenerateAes(Array.Empty()); + + /// + public override void Close() + { + IsConnected = false; + State = ClientState.Disconnected; + base.Close(); + } + + /// + public override void Dispose() + { + IsConnected = false; + State = ClientState.Disconnected; + base.Dispose(); + } /// /// Reads the next packet from the client's stream. @@ -65,10 +94,11 @@ public abstract class BaseClientNode : Node int length; try { - length = networkReader.ReadInt32(); + length = await Task.Run(() => networkReader.ReadInt32(), cancellationToken); } catch (EndOfStreamException) { + State = ClientState.Disconnected; throw new DisconnectedException(); } @@ -77,11 +107,18 @@ public abstract class BaseClientNode : Node buffer.Write(networkReader.ReadBytes(length)); buffer.Position = 0; - if (UseCompression) targetStream = new GZipStream(targetStream, CompressionMode.Decompress); - if (UseEncryption) targetStream = new CryptoStream(targetStream, Aes.CreateDecryptor(), CryptoStreamMode.Read); + // if (UseCompression) targetStream = new GZipStream(targetStream, CompressionLevel.Optimal); + if (UseEncryption) + { + var data = new byte[targetStream.Length]; + _ = await targetStream.ReadAsync(data, 0, data.Length, cancellationToken); + buffer.SetLength(0); + buffer.Write(Aes.DecryptBytes(data)); + buffer.Position = 0; + } using var bufferReader = new ProtocolReader(targetStream); - int packetHeader = bufferReader.ReadInt32(); + int packetHeader = await Task.Run(() => bufferReader.ReadInt32(), cancellationToken); if (!RegisteredPackets.TryGetValue(packetHeader, out Type? packetType)) { @@ -127,31 +164,41 @@ public abstract class BaseClientNode : Node var buffer = new MemoryStream(); Stream targetStream = buffer; - if (UseEncryption) targetStream = new CryptoStream(targetStream, Aes.CreateEncryptor(), CryptoStreamMode.Write); - if (UseCompression) targetStream = new GZipStream(targetStream, CompressionMode.Compress); + // if (UseCompression) targetStream = new GZipStream(targetStream, CompressionMode.Compress); await using var bufferWriter = new ProtocolWriter(targetStream); bufferWriter.Write(packet.Id); await packet.SerializeAsync(bufferWriter); - - switch (targetStream) - { - case CryptoStream cryptoStream: - cryptoStream.FlushFinalBlock(); - break; - case GZipStream {BaseStream: CryptoStream baseCryptoStream}: - baseCryptoStream.FlushFinalBlock(); - break; - } - await targetStream.FlushAsync(cancellationToken); buffer.Position = 0; await using var networkStream = new NetworkStream(BaseSocket); await using var networkWriter = new ProtocolWriter(networkStream); - networkWriter.Write((int) buffer.Length); - await buffer.CopyToAsync(networkStream, cancellationToken); - await networkStream.FlushAsync(cancellationToken); + + if (UseEncryption) + { + byte[] data = buffer.ToArray(); + buffer.SetLength(0); + buffer.Write(Aes.EncryptBytes(data)); + buffer.Position = 0; + } + + var length = (int) buffer.Length; + networkWriter.Write(length); + + try + { + await buffer.CopyToAsync(networkStream, cancellationToken); + await networkStream.FlushAsync(cancellationToken); + } + catch (IOException) + { + State = ClientState.Disconnected; + IsConnected = false; + + if (this is ProtocolClient client) + client.OnDisconnect(DisconnectReason.EndOfStream); + } } /// @@ -232,13 +279,22 @@ public abstract class BaseClientNode : Node { while (!cancellationToken.IsCancellationRequested) { - Packet? packet = await ReadNextPacketAsync(cancellationToken); - if (packet is TPacket typedPacket) + try { - completionSource.TrySetResult(typedPacket); - return; + Packet? packet = await ReadNextPacketAsync(cancellationToken); + if (packet is TPacket typedPacket) + { + completionSource.TrySetResult(typedPacket); + return; + } + } + catch (TaskCanceledException) + { + completionSource.SetCanceled(); } } + + completionSource.SetCanceled(); }, cancellationToken); var packet = (TPacket) await Task.Run(() => completionSource.Task, cancellationToken); diff --git a/TcpDotNet/ClientState.cs b/TcpDotNet/ClientState.cs new file mode 100644 index 0000000..01532ba --- /dev/null +++ b/TcpDotNet/ClientState.cs @@ -0,0 +1,37 @@ +namespace TcpDotNet; + +/// +/// An enumeration of states for a to be in. +/// +public enum ClientState +{ + /// + /// The client is not connected to a remote server. + /// + None, + + /// + /// The client has disconnected from a previously-connected server. + /// + Disconnected, + + /// + /// The client is establishing a connection to a remote server. + /// + Connecting, + + /// + /// The client is handshaking. + /// + Handshaking, + + /// + /// The client is exchanging encryption keys. + /// + Encrypting, + + /// + /// The client is connected. + /// + Connected +} diff --git a/TcpDotNet/CryptographyUtils.cs b/TcpDotNet/CryptographyUtils.cs new file mode 100644 index 0000000..2907bb9 --- /dev/null +++ b/TcpDotNet/CryptographyUtils.cs @@ -0,0 +1,19 @@ +using Chilkat; + +namespace TcpDotNet; + +internal static class CryptographyUtils +{ + public static Crypt2 GenerateAes(byte[] key) + { + return new Crypt2 + { + CryptAlgorithm = "aes", + CipherMode = "cfb", + KeyLength = 128, + PaddingScheme = 0, + SecretKey = key[..], + IV = key[..] + }; + } +} diff --git a/TcpDotNet/DisconnectReason.cs b/TcpDotNet/DisconnectReason.cs index 9d78ccc..499e02c 100644 --- a/TcpDotNet/DisconnectReason.cs +++ b/TcpDotNet/DisconnectReason.cs @@ -13,5 +13,10 @@ public enum DisconnectReason /// /// The client reached an unexpected end of stream. /// - EndOfStream + EndOfStream, + + /// + /// The client sent an invalid encryption payload. + /// + InvalidEncryptionKey } diff --git a/TcpDotNet/EventData/DisconnectedEventArgs.cs b/TcpDotNet/EventData/DisconnectedEventArgs.cs new file mode 100644 index 0000000..760f984 --- /dev/null +++ b/TcpDotNet/EventData/DisconnectedEventArgs.cs @@ -0,0 +1,22 @@ +namespace TcpDotNet.EventData; + +/// +/// Provides event data for the event. +/// +public sealed class DisconnectedEventArgs : EventArgs +{ + /// + /// Initializes a new instance of the class. + /// + /// The reason for the disconnect. + public DisconnectedEventArgs(DisconnectReason disconnectReason) + { + DisconnectReason = disconnectReason; + } + + /// + /// Gets the reason for the disconnect. + /// + /// The disconnect reason. + public DisconnectReason DisconnectReason { get; } +} diff --git a/TcpDotNet/Node.cs b/TcpDotNet/Node.cs index 9e4d25f..bdcf759 100644 --- a/TcpDotNet/Node.cs +++ b/TcpDotNet/Node.cs @@ -11,6 +11,11 @@ namespace TcpDotNet; /// public abstract class Node : IDisposable { + /// + /// The protocol version of this node. + /// + public const int ProtocolVersion = 1; + private readonly ConcurrentDictionary _registeredPackets = new(); private readonly ConcurrentDictionary> _registeredPacketHandlers = new(); @@ -35,8 +40,16 @@ public abstract class Node : IDisposable new ReadOnlyDictionary>( _registeredPacketHandlers.ToDictionary(p => p.Key, p => (IReadOnlyCollection) p.Value.AsReadOnly())); + /// + /// Closes the base socket connection and releases all associated resources. + /// + public virtual void Close() + { + BaseSocket.Close(); + } + /// - public void Dispose() + public virtual void Dispose() { BaseSocket.Dispose(); } diff --git a/TcpDotNet/Protocol/HandshakeResponse.cs b/TcpDotNet/Protocol/HandshakeResponse.cs new file mode 100644 index 0000000..a0a5a30 --- /dev/null +++ b/TcpDotNet/Protocol/HandshakeResponse.cs @@ -0,0 +1,10 @@ +namespace TcpDotNet.Protocol; + +/// +/// An enumeration of handshake responses. +/// +public enum HandshakeResponse : byte +{ + Success = 0x00, + UnsupportedProtocolVersion = 0x01 +} diff --git a/TcpDotNet/Protocol/PacketHandlers/DisconnectPacketHandler.cs b/TcpDotNet/Protocol/PacketHandlers/DisconnectPacketHandler.cs new file mode 100644 index 0000000..0107ba9 --- /dev/null +++ b/TcpDotNet/Protocol/PacketHandlers/DisconnectPacketHandler.cs @@ -0,0 +1,20 @@ +using TcpDotNet.Protocol.Packets.ClientBound; + +namespace TcpDotNet.Protocol.PacketHandlers; + +internal sealed class DisconnectPacketHandler : PacketHandler +{ + /// + public override Task HandleAsync( + BaseClientNode recipient, + DisconnectPacket packet, + CancellationToken cancellationToken = default + ) + { + if (recipient is ProtocolClient client) + client.OnDisconnect(packet.Reason); + + recipient.Close(); + return Task.CompletedTask; + } +} diff --git a/TcpDotNet/Protocol/PacketHandlers/EncryptionResponsePacketHandler.cs b/TcpDotNet/Protocol/PacketHandlers/EncryptionResponsePacketHandler.cs new file mode 100644 index 0000000..b80c1fe --- /dev/null +++ b/TcpDotNet/Protocol/PacketHandlers/EncryptionResponsePacketHandler.cs @@ -0,0 +1,41 @@ +using System.Security.Cryptography; +using TcpDotNet.Protocol.Packets.ClientBound; +using TcpDotNet.Protocol.Packets.ServerBound; + +namespace TcpDotNet.Protocol.PacketHandlers; + +/// +/// Represents a handler for a . +/// +internal sealed class EncryptionResponsePacketHandler : PacketHandler +{ + /// + public override async Task HandleAsync( + BaseClientNode recipient, + EncryptionResponsePacket packet, + CancellationToken cancellationToken = default + ) + { + if (recipient is not ProtocolListener.Client client) + return; + + RSACryptoServiceProvider rsa = client.ParentListener.Rsa; + byte[] payload = rsa.Decrypt(packet.Payload, true); + if (!payload.SequenceEqual(client.AesVerificationPayload)) + { + client.ParentListener.OnClientDisconnect(client, DisconnectReason.InvalidEncryptionKey); + return; + } + + byte[] key = rsa.Decrypt(packet.SharedSecret, true); + client.AesVerificationPayload = Array.Empty(); + client.Aes = CryptographyUtils.GenerateAes(key); + client.State = ClientState.Connected; + client.ParentListener.OnClientConnect(client); + + var sessionPacket = new SessionExchangePacket(client.SessionId); + await client.SendPacketAsync(sessionPacket, cancellationToken); + + client.UseEncryption = true; + } +} diff --git a/TcpDotNet/Protocol/PacketHandlers/HandshakeRequestPacketHandler.cs b/TcpDotNet/Protocol/PacketHandlers/HandshakeRequestPacketHandler.cs new file mode 100644 index 0000000..2df88d0 --- /dev/null +++ b/TcpDotNet/Protocol/PacketHandlers/HandshakeRequestPacketHandler.cs @@ -0,0 +1,41 @@ +using TcpDotNet.Protocol.Packets.ClientBound; +using TcpDotNet.Protocol.Packets.ServerBound; + +namespace TcpDotNet.Protocol.PacketHandlers; + +/// +/// Represents a handler for a . +/// +internal sealed class HandshakeRequestPacketHandler : PacketHandler +{ + /// + public override async Task HandleAsync( + BaseClientNode recipient, + HandshakeRequestPacket packet, + CancellationToken cancellationToken = default + ) + { + if (recipient is not ProtocolListener.Client client) + return; + + HandshakeResponsePacket response; + + if (packet.ProtocolVersion != Node.ProtocolVersion) + { + response = new HandshakeResponsePacket(packet.ProtocolVersion, HandshakeResponse.UnsupportedProtocolVersion); + await client.SendPacketAsync(response, cancellationToken); + client.Close(); + return; + } + + response = new HandshakeResponsePacket(packet.ProtocolVersion, HandshakeResponse.Success); + await client.SendPacketAsync(response, cancellationToken); + + client.State = ClientState.Encrypting; + await Task.Delay(1000, cancellationToken); + + var encryptionRequest = new EncryptionRequestPacket(client.ParentListener.Rsa.ExportCspBlob(false)); + client.AesVerificationPayload = encryptionRequest.Payload; + await client.SendPacketAsync(encryptionRequest, cancellationToken); + } +} diff --git a/TcpDotNet/Protocol/Packets/ClientBound/DisconnectPacket.cs b/TcpDotNet/Protocol/Packets/ClientBound/DisconnectPacket.cs new file mode 100644 index 0000000..8984529 --- /dev/null +++ b/TcpDotNet/Protocol/Packets/ClientBound/DisconnectPacket.cs @@ -0,0 +1,37 @@ +namespace TcpDotNet.Protocol.Packets.ClientBound; + +[Packet(0x7FFFFFFF)] +internal sealed class DisconnectPacket : Packet +{ + /// + /// Initializes a new instance of the class. + /// + /// The reason for the disconnect. + public DisconnectPacket(DisconnectReason reason) + { + Reason = reason; + } + + internal DisconnectPacket() + { + } + + /// + /// Gets the reason for the disconnect. + /// + public DisconnectReason Reason { get; private set; } + + /// + protected internal override Task DeserializeAsync(ProtocolReader reader) + { + Reason = (DisconnectReason) reader.ReadByte(); + return Task.CompletedTask; + } + + /// + protected internal override Task SerializeAsync(ProtocolWriter writer) + { + writer.Write((byte) Reason); + return Task.CompletedTask; + } +} diff --git a/TcpDotNet/Protocol/Packets/ClientBound/EncryptionRequestPacket.cs b/TcpDotNet/Protocol/Packets/ClientBound/EncryptionRequestPacket.cs new file mode 100644 index 0000000..4721c41 --- /dev/null +++ b/TcpDotNet/Protocol/Packets/ClientBound/EncryptionRequestPacket.cs @@ -0,0 +1,66 @@ +using System.Security.Cryptography; + +namespace TcpDotNet.Protocol.Packets.ClientBound; + +/// +/// Represents a packet which requests encryption from the client. +/// +[Packet(0xE2)] +internal sealed class EncryptionRequestPacket : Packet +{ + /// + /// Initializes a new instance of the class. + /// + /// The public key. + public EncryptionRequestPacket(byte[] publicKey) : this() + { + // ReSharper disable once NullCoalescingConditionIsAlwaysNotNullAccordingToAPIContract + publicKey ??= Array.Empty(); + PublicKey = publicKey[..]; + } + + internal EncryptionRequestPacket() + { + PublicKey = Array.Empty(); + + using var rng = new RNGCryptoServiceProvider(); + Payload = new byte[64]; + rng.GetBytes(Payload); + } + + /// + /// Gets the payload. + /// + /// The payload. + public byte[] Payload { get; private set; } + + /// + /// Gets the public key. + /// + /// The public key. + public byte[] PublicKey { get; private set; } + + /// + protected internal override Task DeserializeAsync(ProtocolReader reader) + { + int length = reader.ReadInt32(); + PublicKey = reader.ReadBytes(length); + + length = reader.ReadInt32(); + Payload = reader.ReadBytes(length); + + return Task.CompletedTask; + } + + /// + protected internal override Task SerializeAsync(ProtocolWriter writer) + { + writer.Write(PublicKey.Length); + writer.Write(PublicKey); + + writer.Write(Payload.Length); + writer.Write(Payload); + + return Task.CompletedTask; + } +} diff --git a/TcpDotNet/Protocol/Packets/ClientBound/HandshakeResponsePacket.cs b/TcpDotNet/Protocol/Packets/ClientBound/HandshakeResponsePacket.cs new file mode 100644 index 0000000..587d93f --- /dev/null +++ b/TcpDotNet/Protocol/Packets/ClientBound/HandshakeResponsePacket.cs @@ -0,0 +1,53 @@ +using TcpDotNet.Protocol.Packets.ServerBound; + +namespace TcpDotNet.Protocol.Packets.ClientBound; + +/// +/// Represents a packet which responds to a . +/// +[Packet(0xE1)] +internal sealed class HandshakeResponsePacket : Packet +{ + /// + /// Initializes a new instance of the class. + /// + /// The requested protocol version. + /// The handshake response. + public HandshakeResponsePacket(int protocolVersion, HandshakeResponse handshakeResponse) + { + ProtocolVersion = protocolVersion; + HandshakeResponse = handshakeResponse; + } + + internal HandshakeResponsePacket() + { + } + + /// + /// Gets the handshake response. + /// + /// The handshake response. + public HandshakeResponse HandshakeResponse { get; private set; } + + /// + /// Gets the requested protocol version. + /// + /// The protocol version. + public int ProtocolVersion { get; private set; } + + /// + protected internal override Task DeserializeAsync(ProtocolReader reader) + { + HandshakeResponse = (HandshakeResponse) reader.ReadByte(); + ProtocolVersion = reader.ReadInt32(); + return Task.CompletedTask; + } + + /// + protected internal override Task SerializeAsync(ProtocolWriter writer) + { + writer.Write((byte) HandshakeResponse); + writer.Write(ProtocolVersion); + return Task.CompletedTask; + } +} diff --git a/TcpDotNet/Protocol/Packets/ClientBound/SessionExchangePacket.cs b/TcpDotNet/Protocol/Packets/ClientBound/SessionExchangePacket.cs new file mode 100644 index 0000000..b555c9f --- /dev/null +++ b/TcpDotNet/Protocol/Packets/ClientBound/SessionExchangePacket.cs @@ -0,0 +1,39 @@ +namespace TcpDotNet.Protocol.Packets.ClientBound; + +[Packet(0xE4)] +internal sealed class SessionExchangePacket : Packet +{ + /// + /// Initializes a new instance of the class. + /// + /// The session. + public SessionExchangePacket(Guid session) + { + Session = session; + } + + internal SessionExchangePacket() + { + Session = Guid.Empty; + } + + /// + /// Gets the session. + /// + /// The session. + public Guid Session { get; private set; } + + /// + protected internal override Task DeserializeAsync(ProtocolReader reader) + { + Session = reader.ReadGuid(); + return Task.CompletedTask; + } + + /// + protected internal override Task SerializeAsync(ProtocolWriter writer) + { + writer.Write(Session); + return Task.CompletedTask; + } +} diff --git a/TcpDotNet/Protocol/Packets/ServerBound/EncryptionResponsePacket.cs b/TcpDotNet/Protocol/Packets/ServerBound/EncryptionResponsePacket.cs new file mode 100644 index 0000000..0bf33c8 --- /dev/null +++ b/TcpDotNet/Protocol/Packets/ServerBound/EncryptionResponsePacket.cs @@ -0,0 +1,66 @@ +using TcpDotNet.Protocol.Packets.ClientBound; + +namespace TcpDotNet.Protocol.Packets.ServerBound; + +/// +/// Represents a packet which responds to a . +/// +[Packet(0xE3)] +internal sealed class EncryptionResponsePacket : Packet +{ + /// + /// Initializes a new instance of the class. + /// + /// The payload. + /// The RSA-encrypted symmetric shared secret. + public EncryptionResponsePacket(byte[] payload, byte[] key) : this() + { + // ReSharper disable ConditionalAccessQualifierIsNonNullableAccordingToAPIContract + Payload = payload?[..] ?? Array.Empty(); + SharedSecret = key?[..] ?? Array.Empty(); + // ReSharper enable ConditionalAccessQualifierIsNonNullableAccordingToAPIContract + } + + internal EncryptionResponsePacket() + { + SharedSecret = Array.Empty(); + Payload = Array.Empty(); + } + + /// + /// Gets the symmetric shared secret. + /// + /// The RSA-encrypted symmetric shared secret. + public byte[] SharedSecret { get; private set; } + + /// + /// Gets the payload. + /// + /// The payload. + public byte[] Payload { get; private set; } + + + /// + protected internal override Task DeserializeAsync(ProtocolReader reader) + { + int length = reader.ReadInt32(); + SharedSecret = reader.ReadBytes(length); + + length = reader.ReadInt32(); + Payload = reader.ReadBytes(length); + + return Task.CompletedTask; + } + + /// + protected internal override Task SerializeAsync(ProtocolWriter writer) + { + writer.Write(SharedSecret.Length); + writer.Write(SharedSecret); + + writer.Write(Payload.Length); + writer.Write(Payload); + + return Task.CompletedTask; + } +} diff --git a/TcpDotNet/Protocol/Packets/ServerBound/HandshakeRequestPacket.cs b/TcpDotNet/Protocol/Packets/ServerBound/HandshakeRequestPacket.cs index 1297cba..dd5394b 100644 --- a/TcpDotNet/Protocol/Packets/ServerBound/HandshakeRequestPacket.cs +++ b/TcpDotNet/Protocol/Packets/ServerBound/HandshakeRequestPacket.cs @@ -3,7 +3,7 @@ /// /// Represents a packet which requests a handshake with a . /// -[Packet(0x00000001)] +[Packet(0xE0)] internal sealed class HandshakeRequestPacket : Packet { /// diff --git a/TcpDotNet/ProtocolClient.cs b/TcpDotNet/ProtocolClient.cs index 3e1f909..12120f0 100644 --- a/TcpDotNet/ProtocolClient.cs +++ b/TcpDotNet/ProtocolClient.cs @@ -1,5 +1,13 @@ using System.Net; using System.Net.Sockets; +using System.Security.Cryptography; +using TcpDotNet.EventData; +using TcpDotNet.Protocol; +using TcpDotNet.Protocol.PacketHandlers; +using TcpDotNet.Protocol.Packets.ClientBound; +using TcpDotNet.Protocol.Packets.ServerBound; +using Socket = System.Net.Sockets.Socket; +using Task = System.Threading.Tasks.Task; namespace TcpDotNet; @@ -13,10 +21,17 @@ public sealed class ProtocolClient : BaseClientNode /// public ProtocolClient() { - Aes.GenerateKey(); - Aes.GenerateIV(); + RegisterPacketHandler(PacketHandler.Empty); + RegisterPacketHandler(PacketHandler.Empty); + RegisterPacketHandler(PacketHandler.Empty); + RegisterPacketHandler(new DisconnectPacketHandler()); } + /// + /// Occurs when the client has been disconnected. + /// + public event EventHandler? Disconnected; + /// /// Establishes a connection to a remote host. /// @@ -62,8 +77,56 @@ public sealed class ProtocolClient : BaseClientNode public async Task ConnectAsync(EndPoint remoteEP, CancellationToken cancellationToken = default) { if (remoteEP is null) throw new ArgumentNullException(nameof(remoteEP)); + + State = ClientState.Connecting; BaseSocket = new Socket(remoteEP.AddressFamily, SocketType.Stream, ProtocolType.Tcp); - await Task.Run(() => BaseSocket.ConnectAsync(remoteEP), cancellationToken); + try + { + await Task.Run(() => BaseSocket.ConnectAsync(remoteEP), cancellationToken); + } + catch + { + State = ClientState.None; + throw; + } + IsConnected = true; + + State = ClientState.Handshaking; + var handshakeRequest = new HandshakeRequestPacket(ProtocolVersion); + var handshakeResponse = + await SendAndReceive(handshakeRequest, cancellationToken); + + if (handshakeResponse.HandshakeResponse != HandshakeResponse.Success) + { + Close(); + IsConnected = false; + throw new InvalidOperationException("Handshake failed. " + + $"Server responded with {handshakeResponse.HandshakeResponse:D}"); + } + + State = ClientState.Encrypting; + var encryptionRequest = await WaitForPacketAsync(cancellationToken); + using var rsa = new RSACryptoServiceProvider(2048); + rsa.ImportCspBlob(encryptionRequest.PublicKey); + byte[] encryptedPayload = rsa.Encrypt(encryptionRequest.Payload, true); + + var key = new byte[128]; + using var rng = new RNGCryptoServiceProvider(); + rng.GetBytes(key); + + Aes = CryptographyUtils.GenerateAes(key); + var encryptionResponse = new EncryptionResponsePacket(encryptedPayload, rsa.Encrypt(key, true)); + var sessionPacket = await SendAndReceive(encryptionResponse, cancellationToken); + + SessionId = sessionPacket.Session; + UseEncryption = true; + State = ClientState.Connected; + } + + internal void OnDisconnect(DisconnectReason reason) + { + Disconnected?.Invoke(this, new DisconnectedEventArgs(reason)); + Close(); } } diff --git a/TcpDotNet/ProtocolListener.Client.cs b/TcpDotNet/ProtocolListener.Client.cs index 33d82fb..153f5a3 100644 --- a/TcpDotNet/ProtocolListener.Client.cs +++ b/TcpDotNet/ProtocolListener.Client.cs @@ -24,6 +24,11 @@ public sealed partial class ProtocolListener /// The parent listener. public ProtocolListener ParentListener { get; } + /// + /// Gets or sets the client's verification payload. + /// + internal byte[] AesVerificationPayload { get; set; } = Array.Empty(); + internal void Start() { foreach (Type packetType in ParentListener.RegisteredPackets.Values) @@ -33,6 +38,7 @@ public sealed partial class ProtocolListener foreach (PacketHandler handler in handlers) RegisterPacketHandler(packetType, handler); + State = ClientState.Handshaking; Task.Run(ReadLoopAsync); } diff --git a/TcpDotNet/ProtocolListener.cs b/TcpDotNet/ProtocolListener.cs index b922865..335d313 100644 --- a/TcpDotNet/ProtocolListener.cs +++ b/TcpDotNet/ProtocolListener.cs @@ -1,7 +1,10 @@ using System.Net; using System.Net.Sockets; +using System.Security.Cryptography; using TcpDotNet.EventData; using TcpDotNet.Protocol; +using TcpDotNet.Protocol.PacketHandlers; +using TcpDotNet.Protocol.Packets.ClientBound; namespace TcpDotNet; @@ -12,6 +15,15 @@ public sealed partial class ProtocolListener : Node { private readonly List _clients = new(); + /// + /// Initializes a new instance of the class. + /// + public ProtocolListener() + { + RegisterPacketHandler(new HandshakeRequestPacketHandler()); + RegisterPacketHandler(new EncryptionResponsePacketHandler()); + } + /// /// Occurs when a client connects to the listener. /// @@ -62,6 +74,12 @@ public sealed partial class ProtocolListener : Node /// The that is using for communications. public EndPoint LocalEndPoint => BaseSocket.LocalEndPoint; + /// + /// Gets the RSA provider for this listener. + /// + /// The RSA provider. + internal RSACryptoServiceProvider Rsa { get; } = new(2048); + /// /// Starts the listener on the specified port, using as the bind address, or /// if is . @@ -98,8 +116,17 @@ public sealed partial class ProtocolListener : Node Task.Run(AcceptLoop); } - private void OnClientDisconnect(Client client, DisconnectReason disconnectReason) + internal void OnClientConnect(Client client) { + lock (_clients) _clients.Add(client); + ClientConnected?.Invoke(this, new ClientConnectedEventArgs(client)); + } + + internal void OnClientDisconnect(Client client, DisconnectReason disconnectReason) + { + var disconnectPacket = new DisconnectPacket(disconnectReason); + client.SendPacketAsync(disconnectPacket).GetAwaiter().GetResult(); + client.Close(); lock (_clients) _clients.Remove(client); ClientDisconnected?.Invoke(this, new ClientDisconnectedEventArgs(client, disconnectReason)); } @@ -116,10 +143,7 @@ public sealed partial class ProtocolListener : Node Socket socket = await BaseSocket.AcceptAsync(); var client = new Client(this, socket); - lock (_clients) _clients.Add(client); - client.Start(); - ClientConnected?.Invoke(this, new ClientConnectedEventArgs(client)); } } } diff --git a/TcpDotNet/TcpDotNet.csproj b/TcpDotNet/TcpDotNet.csproj index bf8bc39..1c1f391 100644 --- a/TcpDotNet/TcpDotNet.csproj +++ b/TcpDotNet/TcpDotNet.csproj @@ -7,4 +7,9 @@ 10 + + + + +