1
0
mirror of https://github.com/oliverbooth/TcpDotNet synced 2024-10-18 06:16:10 +00:00

Add AES encryption

Temporarily disable compression. GZip being weird
This commit is contained in:
Oliver Booth 2022-05-19 10:37:37 +01:00
parent a87e186c58
commit c1115e26c1
No known key found for this signature in database
GPG Key ID: 32A00B35503AF634
23 changed files with 687 additions and 49 deletions

View File

@ -5,14 +5,31 @@ using TcpDotNet.Protocol.Packets.ClientBound;
using TcpDotNet.Protocol.Packets.ServerBound;
using var client = new ProtocolClient();
client.Disconnected += (_, e) => Console.WriteLine($"Disconnected: {e.DisconnectReason}");
new Thread(() =>
{
ClientState oldState = client.State;
while (true)
{
if (oldState != client.State)
{
Console.WriteLine($"State changed to {client.State}");
oldState = client.State;
}
}
}).Start();
client.RegisterPacketHandler(PacketHandler<PongPacket>.Empty);
await client.ConnectAsync(IPAddress.IPv6Loopback, 1234);
Console.WriteLine($"Connected to {client.RemoteEndPoint}");
var ping = new PingPacket();
Console.WriteLine($"Connected to {client.RemoteEndPoint}. My session is {client.SessionId}");
var cancellationTokenSource = new CancellationTokenSource();
cancellationTokenSource.CancelAfter(5000); // if no pong is received in 5 seconds, cancel
var ping = new PingPacket();
Console.WriteLine($"Sending ping packet with payload: {BitConverter.ToString(ping.Payload)}");
var pong = await client.SendAndReceive<PingPacket, PongPacket>(ping);
var pong = await client.SendAndReceive<PingPacket, PongPacket>(ping, cancellationTokenSource.Token);
Console.WriteLine($"Received pong packet with payload: {BitConverter.ToString(pong.Payload)}");
Console.WriteLine(pong.Payload.SequenceEqual(ping.Payload) ? "Payload matches!" : "Payload does not match!");

View File

@ -11,4 +11,8 @@
<ProjectReference Include="..\TcpDotNet\TcpDotNet.csproj" />
</ItemGroup>
<ItemGroup>
<Folder Include="PacketHandlers" />
</ItemGroup>
</Project>

View File

@ -4,14 +4,8 @@ using TcpDotNet.Protocol.Packets.ClientBound;
using TcpDotNet.Protocol.Packets.ServerBound;
var listener = new ProtocolListener();
listener.Started += (_, _) =>
Console.WriteLine($"Listener started on {listener.LocalEndPoint}");
listener.ClientConnected += (_, e) =>
Console.WriteLine($"Client connected from {e.Client.RemoteEndPoint} with session {e.Client.SessionId}");
listener.ClientDisconnected += (_, e) =>
Console.WriteLine($"Client {e.Client.SessionId} disconnected ({e.DisconnectReason})");
listener.ClientConnected += (_, e) => Console.WriteLine($"Client connected from {e.Client.RemoteEndPoint} with session {e.Client.SessionId}");
listener.ClientDisconnected += (_, e) => Console.WriteLine($"Client {e.Client.SessionId} disconnected ({e.DisconnectReason})");
listener.RegisterPacketHandler(new PingPacketHandler());
listener.Start(1234);

View File

@ -1,10 +1,11 @@
using System.Collections.Concurrent;
using System.IO.Compression;
using System.Net;
using System.Net.Sockets;
using System.Reflection;
using System.Security.Cryptography;
using Chilkat;
using TcpDotNet.Protocol;
using Stream = System.IO.Stream;
using Task = System.Threading.Tasks.Task;
namespace TcpDotNet;
@ -15,6 +16,13 @@ public abstract class BaseClientNode : Node
{
private readonly ConcurrentDictionary<int, List<TaskCompletionSource<Packet>>> _packetCompletionSources = new();
/// <summary>
/// Initializes a new instance of the <see cref="BaseClientNode" /> class.
/// </summary>
protected BaseClientNode()
{
}
/// <summary>
/// Gets a value indicating whether the client is connected.
/// </summary>
@ -35,6 +43,11 @@ public abstract class BaseClientNode : Node
/// <value>The session ID.</value>
public Guid SessionId { get; internal set; }
/// <summary>
/// Gets the current state of the client.
/// </summary>
public ClientState State { get; protected internal set; }
/// <summary>
/// Gets or sets a value indicating whether GZip compression is enabled.
/// </summary>
@ -48,10 +61,26 @@ public abstract class BaseClientNode : Node
internal bool UseEncryption { get; set; } = false;
/// <summary>
/// Gets the AES implementation used by this client.
/// Gets or sets the AES implementation used by this client.
/// </summary>
/// <value>The AES implementation.</value>
internal Aes Aes { get; } = Aes.Create();
internal Crypt2 Aes { get; set; } = CryptographyUtils.GenerateAes(Array.Empty<byte>());
/// <inheritdoc />
public override void Close()
{
IsConnected = false;
State = ClientState.Disconnected;
base.Close();
}
/// <inheritdoc />
public override void Dispose()
{
IsConnected = false;
State = ClientState.Disconnected;
base.Dispose();
}
/// <summary>
/// Reads the next packet from the client's stream.
@ -65,10 +94,11 @@ public abstract class BaseClientNode : Node
int length;
try
{
length = networkReader.ReadInt32();
length = await Task.Run(() => networkReader.ReadInt32(), cancellationToken);
}
catch (EndOfStreamException)
{
State = ClientState.Disconnected;
throw new DisconnectedException();
}
@ -77,11 +107,18 @@ public abstract class BaseClientNode : Node
buffer.Write(networkReader.ReadBytes(length));
buffer.Position = 0;
if (UseCompression) targetStream = new GZipStream(targetStream, CompressionMode.Decompress);
if (UseEncryption) targetStream = new CryptoStream(targetStream, Aes.CreateDecryptor(), CryptoStreamMode.Read);
// if (UseCompression) targetStream = new GZipStream(targetStream, CompressionLevel.Optimal);
if (UseEncryption)
{
var data = new byte[targetStream.Length];
_ = await targetStream.ReadAsync(data, 0, data.Length, cancellationToken);
buffer.SetLength(0);
buffer.Write(Aes.DecryptBytes(data));
buffer.Position = 0;
}
using var bufferReader = new ProtocolReader(targetStream);
int packetHeader = bufferReader.ReadInt32();
int packetHeader = await Task.Run(() => bufferReader.ReadInt32(), cancellationToken);
if (!RegisteredPackets.TryGetValue(packetHeader, out Type? packetType))
{
@ -127,31 +164,41 @@ public abstract class BaseClientNode : Node
var buffer = new MemoryStream();
Stream targetStream = buffer;
if (UseEncryption) targetStream = new CryptoStream(targetStream, Aes.CreateEncryptor(), CryptoStreamMode.Write);
if (UseCompression) targetStream = new GZipStream(targetStream, CompressionMode.Compress);
// if (UseCompression) targetStream = new GZipStream(targetStream, CompressionMode.Compress);
await using var bufferWriter = new ProtocolWriter(targetStream);
bufferWriter.Write(packet.Id);
await packet.SerializeAsync(bufferWriter);
switch (targetStream)
{
case CryptoStream cryptoStream:
cryptoStream.FlushFinalBlock();
break;
case GZipStream {BaseStream: CryptoStream baseCryptoStream}:
baseCryptoStream.FlushFinalBlock();
break;
}
await targetStream.FlushAsync(cancellationToken);
buffer.Position = 0;
await using var networkStream = new NetworkStream(BaseSocket);
await using var networkWriter = new ProtocolWriter(networkStream);
networkWriter.Write((int) buffer.Length);
await buffer.CopyToAsync(networkStream, cancellationToken);
await networkStream.FlushAsync(cancellationToken);
if (UseEncryption)
{
byte[] data = buffer.ToArray();
buffer.SetLength(0);
buffer.Write(Aes.EncryptBytes(data));
buffer.Position = 0;
}
var length = (int) buffer.Length;
networkWriter.Write(length);
try
{
await buffer.CopyToAsync(networkStream, cancellationToken);
await networkStream.FlushAsync(cancellationToken);
}
catch (IOException)
{
State = ClientState.Disconnected;
IsConnected = false;
if (this is ProtocolClient client)
client.OnDisconnect(DisconnectReason.EndOfStream);
}
}
/// <summary>
@ -232,13 +279,22 @@ public abstract class BaseClientNode : Node
{
while (!cancellationToken.IsCancellationRequested)
{
Packet? packet = await ReadNextPacketAsync(cancellationToken);
if (packet is TPacket typedPacket)
try
{
completionSource.TrySetResult(typedPacket);
return;
Packet? packet = await ReadNextPacketAsync(cancellationToken);
if (packet is TPacket typedPacket)
{
completionSource.TrySetResult(typedPacket);
return;
}
}
catch (TaskCanceledException)
{
completionSource.SetCanceled();
}
}
completionSource.SetCanceled();
}, cancellationToken);
var packet = (TPacket) await Task.Run(() => completionSource.Task, cancellationToken);

37
TcpDotNet/ClientState.cs Normal file
View File

@ -0,0 +1,37 @@
namespace TcpDotNet;
/// <summary>
/// An enumeration of states for a <see cref="BaseClientNode" /> to be in.
/// </summary>
public enum ClientState
{
/// <summary>
/// The client is not connected to a remote server.
/// </summary>
None,
/// <summary>
/// The client has disconnected from a previously-connected server.
/// </summary>
Disconnected,
/// <summary>
/// The client is establishing a connection to a remote server.
/// </summary>
Connecting,
/// <summary>
/// The client is handshaking.
/// </summary>
Handshaking,
/// <summary>
/// The client is exchanging encryption keys.
/// </summary>
Encrypting,
/// <summary>
/// The client is connected.
/// </summary>
Connected
}

View File

@ -0,0 +1,19 @@
using Chilkat;
namespace TcpDotNet;
internal static class CryptographyUtils
{
public static Crypt2 GenerateAes(byte[] key)
{
return new Crypt2
{
CryptAlgorithm = "aes",
CipherMode = "cfb",
KeyLength = 128,
PaddingScheme = 0,
SecretKey = key[..],
IV = key[..]
};
}
}

View File

@ -13,5 +13,10 @@ public enum DisconnectReason
/// <summary>
/// The client reached an unexpected end of stream.
/// </summary>
EndOfStream
EndOfStream,
/// <summary>
/// The client sent an invalid encryption payload.
/// </summary>
InvalidEncryptionKey
}

View File

@ -0,0 +1,22 @@
namespace TcpDotNet.EventData;
/// <summary>
/// Provides event data for the <see cref="ProtocolClient.Disconnected" /> event.
/// </summary>
public sealed class DisconnectedEventArgs : EventArgs
{
/// <summary>
/// Initializes a new instance of the <see cref="DisconnectedEventArgs" /> class.
/// </summary>
/// <param name="disconnectReason">The reason for the disconnect.</param>
public DisconnectedEventArgs(DisconnectReason disconnectReason)
{
DisconnectReason = disconnectReason;
}
/// <summary>
/// Gets the reason for the disconnect.
/// </summary>
/// <value>The disconnect reason.</value>
public DisconnectReason DisconnectReason { get; }
}

View File

@ -11,6 +11,11 @@ namespace TcpDotNet;
/// </summary>
public abstract class Node : IDisposable
{
/// <summary>
/// The protocol version of this node.
/// </summary>
public const int ProtocolVersion = 1;
private readonly ConcurrentDictionary<int, Type> _registeredPackets = new();
private readonly ConcurrentDictionary<Type, List<PacketHandler>> _registeredPacketHandlers = new();
@ -35,8 +40,16 @@ public abstract class Node : IDisposable
new ReadOnlyDictionary<Type, IReadOnlyCollection<PacketHandler>>(
_registeredPacketHandlers.ToDictionary(p => p.Key, p => (IReadOnlyCollection<PacketHandler>) p.Value.AsReadOnly()));
/// <summary>
/// Closes the base socket connection and releases all associated resources.
/// </summary>
public virtual void Close()
{
BaseSocket.Close();
}
/// <inheritdoc />
public void Dispose()
public virtual void Dispose()
{
BaseSocket.Dispose();
}

View File

@ -0,0 +1,10 @@
namespace TcpDotNet.Protocol;
/// <summary>
/// An enumeration of handshake responses.
/// </summary>
public enum HandshakeResponse : byte
{
Success = 0x00,
UnsupportedProtocolVersion = 0x01
}

View File

@ -0,0 +1,20 @@
using TcpDotNet.Protocol.Packets.ClientBound;
namespace TcpDotNet.Protocol.PacketHandlers;
internal sealed class DisconnectPacketHandler : PacketHandler<DisconnectPacket>
{
/// <inheritdoc />
public override Task HandleAsync(
BaseClientNode recipient,
DisconnectPacket packet,
CancellationToken cancellationToken = default
)
{
if (recipient is ProtocolClient client)
client.OnDisconnect(packet.Reason);
recipient.Close();
return Task.CompletedTask;
}
}

View File

@ -0,0 +1,41 @@
using System.Security.Cryptography;
using TcpDotNet.Protocol.Packets.ClientBound;
using TcpDotNet.Protocol.Packets.ServerBound;
namespace TcpDotNet.Protocol.PacketHandlers;
/// <summary>
/// Represents a handler for a <see cref="EncryptionResponsePacket" />.
/// </summary>
internal sealed class EncryptionResponsePacketHandler : PacketHandler<EncryptionResponsePacket>
{
/// <inheritdoc />
public override async Task HandleAsync(
BaseClientNode recipient,
EncryptionResponsePacket packet,
CancellationToken cancellationToken = default
)
{
if (recipient is not ProtocolListener.Client client)
return;
RSACryptoServiceProvider rsa = client.ParentListener.Rsa;
byte[] payload = rsa.Decrypt(packet.Payload, true);
if (!payload.SequenceEqual(client.AesVerificationPayload))
{
client.ParentListener.OnClientDisconnect(client, DisconnectReason.InvalidEncryptionKey);
return;
}
byte[] key = rsa.Decrypt(packet.SharedSecret, true);
client.AesVerificationPayload = Array.Empty<byte>();
client.Aes = CryptographyUtils.GenerateAes(key);
client.State = ClientState.Connected;
client.ParentListener.OnClientConnect(client);
var sessionPacket = new SessionExchangePacket(client.SessionId);
await client.SendPacketAsync(sessionPacket, cancellationToken);
client.UseEncryption = true;
}
}

View File

@ -0,0 +1,41 @@
using TcpDotNet.Protocol.Packets.ClientBound;
using TcpDotNet.Protocol.Packets.ServerBound;
namespace TcpDotNet.Protocol.PacketHandlers;
/// <summary>
/// Represents a handler for a <see cref="HandshakeRequestPacket" />.
/// </summary>
internal sealed class HandshakeRequestPacketHandler : PacketHandler<HandshakeRequestPacket>
{
/// <inheritdoc />
public override async Task HandleAsync(
BaseClientNode recipient,
HandshakeRequestPacket packet,
CancellationToken cancellationToken = default
)
{
if (recipient is not ProtocolListener.Client client)
return;
HandshakeResponsePacket response;
if (packet.ProtocolVersion != Node.ProtocolVersion)
{
response = new HandshakeResponsePacket(packet.ProtocolVersion, HandshakeResponse.UnsupportedProtocolVersion);
await client.SendPacketAsync(response, cancellationToken);
client.Close();
return;
}
response = new HandshakeResponsePacket(packet.ProtocolVersion, HandshakeResponse.Success);
await client.SendPacketAsync(response, cancellationToken);
client.State = ClientState.Encrypting;
await Task.Delay(1000, cancellationToken);
var encryptionRequest = new EncryptionRequestPacket(client.ParentListener.Rsa.ExportCspBlob(false));
client.AesVerificationPayload = encryptionRequest.Payload;
await client.SendPacketAsync(encryptionRequest, cancellationToken);
}
}

View File

@ -0,0 +1,37 @@
namespace TcpDotNet.Protocol.Packets.ClientBound;
[Packet(0x7FFFFFFF)]
internal sealed class DisconnectPacket : Packet
{
/// <summary>
/// Initializes a new instance of the <see cref="DisconnectPacket" /> class.
/// </summary>
/// <param name="reason">The reason for the disconnect.</param>
public DisconnectPacket(DisconnectReason reason)
{
Reason = reason;
}
internal DisconnectPacket()
{
}
/// <summary>
/// Gets the reason for the disconnect.
/// </summary>
public DisconnectReason Reason { get; private set; }
/// <inheritdoc />
protected internal override Task DeserializeAsync(ProtocolReader reader)
{
Reason = (DisconnectReason) reader.ReadByte();
return Task.CompletedTask;
}
/// <inheritdoc />
protected internal override Task SerializeAsync(ProtocolWriter writer)
{
writer.Write((byte) Reason);
return Task.CompletedTask;
}
}

View File

@ -0,0 +1,66 @@
using System.Security.Cryptography;
namespace TcpDotNet.Protocol.Packets.ClientBound;
/// <summary>
/// Represents a packet which requests encryption from the client.
/// </summary>
[Packet(0xE2)]
internal sealed class EncryptionRequestPacket : Packet
{
/// <summary>
/// Initializes a new instance of the <see cref="EncryptionRequestPacket" /> class.
/// </summary>
/// <param name="publicKey">The public key.</param>
public EncryptionRequestPacket(byte[] publicKey) : this()
{
// ReSharper disable once NullCoalescingConditionIsAlwaysNotNullAccordingToAPIContract
publicKey ??= Array.Empty<byte>();
PublicKey = publicKey[..];
}
internal EncryptionRequestPacket()
{
PublicKey = Array.Empty<byte>();
using var rng = new RNGCryptoServiceProvider();
Payload = new byte[64];
rng.GetBytes(Payload);
}
/// <summary>
/// Gets the payload.
/// </summary>
/// <value>The payload.</value>
public byte[] Payload { get; private set; }
/// <summary>
/// Gets the public key.
/// </summary>
/// <value>The public key.</value>
public byte[] PublicKey { get; private set; }
/// <inheritdoc />
protected internal override Task DeserializeAsync(ProtocolReader reader)
{
int length = reader.ReadInt32();
PublicKey = reader.ReadBytes(length);
length = reader.ReadInt32();
Payload = reader.ReadBytes(length);
return Task.CompletedTask;
}
/// <inheritdoc />
protected internal override Task SerializeAsync(ProtocolWriter writer)
{
writer.Write(PublicKey.Length);
writer.Write(PublicKey);
writer.Write(Payload.Length);
writer.Write(Payload);
return Task.CompletedTask;
}
}

View File

@ -0,0 +1,53 @@
using TcpDotNet.Protocol.Packets.ServerBound;
namespace TcpDotNet.Protocol.Packets.ClientBound;
/// <summary>
/// Represents a packet which responds to a <see cref="HandshakeRequestPacket" />.
/// </summary>
[Packet(0xE1)]
internal sealed class HandshakeResponsePacket : Packet
{
/// <summary>
/// Initializes a new instance of the <see cref="HandshakeResponsePacket" /> class.
/// </summary>
/// <param name="protocolVersion">The requested protocol version.</param>
/// <param name="handshakeResponse">The handshake response.</param>
public HandshakeResponsePacket(int protocolVersion, HandshakeResponse handshakeResponse)
{
ProtocolVersion = protocolVersion;
HandshakeResponse = handshakeResponse;
}
internal HandshakeResponsePacket()
{
}
/// <summary>
/// Gets the handshake response.
/// </summary>
/// <value>The handshake response.</value>
public HandshakeResponse HandshakeResponse { get; private set; }
/// <summary>
/// Gets the requested protocol version.
/// </summary>
/// <value>The protocol version.</value>
public int ProtocolVersion { get; private set; }
/// <inheritdoc />
protected internal override Task DeserializeAsync(ProtocolReader reader)
{
HandshakeResponse = (HandshakeResponse) reader.ReadByte();
ProtocolVersion = reader.ReadInt32();
return Task.CompletedTask;
}
/// <inheritdoc />
protected internal override Task SerializeAsync(ProtocolWriter writer)
{
writer.Write((byte) HandshakeResponse);
writer.Write(ProtocolVersion);
return Task.CompletedTask;
}
}

View File

@ -0,0 +1,39 @@
namespace TcpDotNet.Protocol.Packets.ClientBound;
[Packet(0xE4)]
internal sealed class SessionExchangePacket : Packet
{
/// <summary>
/// Initializes a new instance of the <see cref="SessionExchangePacket" /> class.
/// </summary>
/// <param name="session">The session.</param>
public SessionExchangePacket(Guid session)
{
Session = session;
}
internal SessionExchangePacket()
{
Session = Guid.Empty;
}
/// <summary>
/// Gets the session.
/// </summary>
/// <value>The session.</value>
public Guid Session { get; private set; }
/// <inheritdoc />
protected internal override Task DeserializeAsync(ProtocolReader reader)
{
Session = reader.ReadGuid();
return Task.CompletedTask;
}
/// <inheritdoc />
protected internal override Task SerializeAsync(ProtocolWriter writer)
{
writer.Write(Session);
return Task.CompletedTask;
}
}

View File

@ -0,0 +1,66 @@
using TcpDotNet.Protocol.Packets.ClientBound;
namespace TcpDotNet.Protocol.Packets.ServerBound;
/// <summary>
/// Represents a packet which responds to a <see cref="EncryptionRequestPacket" />.
/// </summary>
[Packet(0xE3)]
internal sealed class EncryptionResponsePacket : Packet
{
/// <summary>
/// Initializes a new instance of the <see cref="EncryptionResponsePacket" /> class.
/// </summary>
/// <param name="payload">The payload.</param>
/// <param name="key">The RSA-encrypted symmetric shared secret.</param>
public EncryptionResponsePacket(byte[] payload, byte[] key) : this()
{
// ReSharper disable ConditionalAccessQualifierIsNonNullableAccordingToAPIContract
Payload = payload?[..] ?? Array.Empty<byte>();
SharedSecret = key?[..] ?? Array.Empty<byte>();
// ReSharper enable ConditionalAccessQualifierIsNonNullableAccordingToAPIContract
}
internal EncryptionResponsePacket()
{
SharedSecret = Array.Empty<byte>();
Payload = Array.Empty<byte>();
}
/// <summary>
/// Gets the symmetric shared secret.
/// </summary>
/// <value>The RSA-encrypted symmetric shared secret.</value>
public byte[] SharedSecret { get; private set; }
/// <summary>
/// Gets the payload.
/// </summary>
/// <value>The payload.</value>
public byte[] Payload { get; private set; }
/// <inheritdoc />
protected internal override Task DeserializeAsync(ProtocolReader reader)
{
int length = reader.ReadInt32();
SharedSecret = reader.ReadBytes(length);
length = reader.ReadInt32();
Payload = reader.ReadBytes(length);
return Task.CompletedTask;
}
/// <inheritdoc />
protected internal override Task SerializeAsync(ProtocolWriter writer)
{
writer.Write(SharedSecret.Length);
writer.Write(SharedSecret);
writer.Write(Payload.Length);
writer.Write(Payload);
return Task.CompletedTask;
}
}

View File

@ -3,7 +3,7 @@
/// <summary>
/// Represents a packet which requests a handshake with a <see cref="ProtocolListener" />.
/// </summary>
[Packet(0x00000001)]
[Packet(0xE0)]
internal sealed class HandshakeRequestPacket : Packet
{
/// <summary>

View File

@ -1,5 +1,13 @@
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography;
using TcpDotNet.EventData;
using TcpDotNet.Protocol;
using TcpDotNet.Protocol.PacketHandlers;
using TcpDotNet.Protocol.Packets.ClientBound;
using TcpDotNet.Protocol.Packets.ServerBound;
using Socket = System.Net.Sockets.Socket;
using Task = System.Threading.Tasks.Task;
namespace TcpDotNet;
@ -13,10 +21,17 @@ public sealed class ProtocolClient : BaseClientNode
/// </summary>
public ProtocolClient()
{
Aes.GenerateKey();
Aes.GenerateIV();
RegisterPacketHandler(PacketHandler<HandshakeResponsePacket>.Empty);
RegisterPacketHandler(PacketHandler<EncryptionRequestPacket>.Empty);
RegisterPacketHandler(PacketHandler<SessionExchangePacket>.Empty);
RegisterPacketHandler(new DisconnectPacketHandler());
}
/// <summary>
/// Occurs when the client has been disconnected.
/// </summary>
public event EventHandler<DisconnectedEventArgs>? Disconnected;
/// <summary>
/// Establishes a connection to a remote host.
/// </summary>
@ -62,8 +77,56 @@ public sealed class ProtocolClient : BaseClientNode
public async Task ConnectAsync(EndPoint remoteEP, CancellationToken cancellationToken = default)
{
if (remoteEP is null) throw new ArgumentNullException(nameof(remoteEP));
State = ClientState.Connecting;
BaseSocket = new Socket(remoteEP.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
await Task.Run(() => BaseSocket.ConnectAsync(remoteEP), cancellationToken);
try
{
await Task.Run(() => BaseSocket.ConnectAsync(remoteEP), cancellationToken);
}
catch
{
State = ClientState.None;
throw;
}
IsConnected = true;
State = ClientState.Handshaking;
var handshakeRequest = new HandshakeRequestPacket(ProtocolVersion);
var handshakeResponse =
await SendAndReceive<HandshakeRequestPacket, HandshakeResponsePacket>(handshakeRequest, cancellationToken);
if (handshakeResponse.HandshakeResponse != HandshakeResponse.Success)
{
Close();
IsConnected = false;
throw new InvalidOperationException("Handshake failed. " +
$"Server responded with {handshakeResponse.HandshakeResponse:D}");
}
State = ClientState.Encrypting;
var encryptionRequest = await WaitForPacketAsync<EncryptionRequestPacket>(cancellationToken);
using var rsa = new RSACryptoServiceProvider(2048);
rsa.ImportCspBlob(encryptionRequest.PublicKey);
byte[] encryptedPayload = rsa.Encrypt(encryptionRequest.Payload, true);
var key = new byte[128];
using var rng = new RNGCryptoServiceProvider();
rng.GetBytes(key);
Aes = CryptographyUtils.GenerateAes(key);
var encryptionResponse = new EncryptionResponsePacket(encryptedPayload, rsa.Encrypt(key, true));
var sessionPacket = await SendAndReceive<EncryptionResponsePacket, SessionExchangePacket>(encryptionResponse, cancellationToken);
SessionId = sessionPacket.Session;
UseEncryption = true;
State = ClientState.Connected;
}
internal void OnDisconnect(DisconnectReason reason)
{
Disconnected?.Invoke(this, new DisconnectedEventArgs(reason));
Close();
}
}

View File

@ -24,6 +24,11 @@ public sealed partial class ProtocolListener
/// <value>The parent listener.</value>
public ProtocolListener ParentListener { get; }
/// <summary>
/// Gets or sets the client's verification payload.
/// </summary>
internal byte[] AesVerificationPayload { get; set; } = Array.Empty<byte>();
internal void Start()
{
foreach (Type packetType in ParentListener.RegisteredPackets.Values)
@ -33,6 +38,7 @@ public sealed partial class ProtocolListener
foreach (PacketHandler handler in handlers)
RegisterPacketHandler(packetType, handler);
State = ClientState.Handshaking;
Task.Run(ReadLoopAsync);
}

View File

@ -1,7 +1,10 @@
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography;
using TcpDotNet.EventData;
using TcpDotNet.Protocol;
using TcpDotNet.Protocol.PacketHandlers;
using TcpDotNet.Protocol.Packets.ClientBound;
namespace TcpDotNet;
@ -12,6 +15,15 @@ public sealed partial class ProtocolListener : Node
{
private readonly List<Client> _clients = new();
/// <summary>
/// Initializes a new instance of the <see cref="ProtocolListener" /> class.
/// </summary>
public ProtocolListener()
{
RegisterPacketHandler(new HandshakeRequestPacketHandler());
RegisterPacketHandler(new EncryptionResponsePacketHandler());
}
/// <summary>
/// Occurs when a client connects to the listener.
/// </summary>
@ -62,6 +74,12 @@ public sealed partial class ProtocolListener : Node
/// <value>The <see cref="EndPoint" /> that <see cref="Node.BaseSocket" /> is using for communications.</value>
public EndPoint LocalEndPoint => BaseSocket.LocalEndPoint;
/// <summary>
/// Gets the RSA provider for this listener.
/// </summary>
/// <value>The RSA provider.</value>
internal RSACryptoServiceProvider Rsa { get; } = new(2048);
/// <summary>
/// Starts the listener on the specified port, using <see cref="IPAddress.Any" /> as the bind address, or
/// <see cref="IPAddress.IPv6Any" /> if <see cref="Socket.OSSupportsIPv6" /> is <see langword="true" />.
@ -98,8 +116,17 @@ public sealed partial class ProtocolListener : Node
Task.Run(AcceptLoop);
}
private void OnClientDisconnect(Client client, DisconnectReason disconnectReason)
internal void OnClientConnect(Client client)
{
lock (_clients) _clients.Add(client);
ClientConnected?.Invoke(this, new ClientConnectedEventArgs(client));
}
internal void OnClientDisconnect(Client client, DisconnectReason disconnectReason)
{
var disconnectPacket = new DisconnectPacket(disconnectReason);
client.SendPacketAsync(disconnectPacket).GetAwaiter().GetResult();
client.Close();
lock (_clients) _clients.Remove(client);
ClientDisconnected?.Invoke(this, new ClientDisconnectedEventArgs(client, disconnectReason));
}
@ -116,10 +143,7 @@ public sealed partial class ProtocolListener : Node
Socket socket = await BaseSocket.AcceptAsync();
var client = new Client(this, socket);
lock (_clients) _clients.Add(client);
client.Start();
ClientConnected?.Invoke(this, new ClientConnectedEventArgs(client));
}
}
}

View File

@ -7,4 +7,9 @@
<LangVersion>10</LangVersion>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="ChilkatDnCore" Version="9.5.0.90" />
<PackageReference Include="X10D" Version="3.2.0-nightly.108" />
</ItemGroup>
</Project>