diff --git a/TcpDotNet.ClientIntegrationTest/Program.cs b/TcpDotNet.ClientIntegrationTest/Program.cs index 333aaed..ada9729 100644 --- a/TcpDotNet.ClientIntegrationTest/Program.cs +++ b/TcpDotNet.ClientIntegrationTest/Program.cs @@ -17,7 +17,7 @@ Console.WriteLine($"Connected to {client.RemoteEndPoint}. My session is {client. var ping = new PingPacket(); Console.WriteLine($"Sending ping packet with payload: {BitConverter.ToString(ping.Payload)}"); -var pong = await client.SendAndReceive(ping); +var pong = await client.SendAndReceiveAsync(ping); Console.WriteLine($"Received pong packet with payload: {BitConverter.ToString(pong.Payload)}"); Console.WriteLine(pong.Payload.SequenceEqual(ping.Payload) ? "Payload matches!" : "Payload does not match!"); diff --git a/TcpDotNet.ListenerIntegrationTest/Program.cs b/TcpDotNet.ListenerIntegrationTest/Program.cs index 2ca7367..8f88848 100644 --- a/TcpDotNet.ListenerIntegrationTest/Program.cs +++ b/TcpDotNet.ListenerIntegrationTest/Program.cs @@ -19,7 +19,7 @@ internal sealed class PingPacketHandler : PacketHandler public override async Task HandleAsync(BaseClientNode recipient, PingPacket packet, CancellationToken cancellationToken = default) { Console.WriteLine($"Client {recipient.SessionId} sent ping with payload {BitConverter.ToString(packet.Payload)}"); - var pong = new PongPacket(packet.Payload); + var pong = new PongPacket(packet.CallbackId, packet.Payload); await recipient.SendPacketAsync(pong, cancellationToken); } } diff --git a/TcpDotNet/BaseClientNode.cs b/TcpDotNet/BaseClientNode.cs index 8f61364..7809dad 100644 --- a/TcpDotNet/BaseClientNode.cs +++ b/TcpDotNet/BaseClientNode.cs @@ -2,6 +2,7 @@ using System.Net; using System.Net.Sockets; using System.Reflection; +using System.Runtime.Serialization; using Chilkat; using TcpDotNet.Protocol; using Stream = System.IO.Stream; @@ -14,6 +15,7 @@ namespace TcpDotNet; /// public abstract class BaseClientNode : Node { + private readonly ObjectIDGenerator _callbackIdGenerator = new(); private readonly ConcurrentDictionary>> _packetCompletionSources = new(); /// @@ -212,7 +214,8 @@ public abstract class BaseClientNode : Node /// /// This method will consume all incoming packets, raising their associated handlers if such packets are recognised. /// - public async Task SendAndReceive(TSend packetToSend, CancellationToken cancellationToken = default) + public async Task SendAndReceiveAsync(TSend packetToSend, + CancellationToken cancellationToken = default) where TSend : Packet where TReceive : Packet { @@ -220,6 +223,10 @@ public abstract class BaseClientNode : Node if (attribute is null) throw new ArgumentException($"The packet type {typeof(TReceive).Name} is not a valid packet."); + var requestPacket = packetToSend as RequestPacket; + if (requestPacket is not null) + requestPacket.CallbackId = _callbackIdGenerator.GetId(this, out _); + var completionSource = new TaskCompletionSource(); if (!_packetCompletionSources.TryGetValue(attribute.Id, out List>? completionSources)) { @@ -234,7 +241,18 @@ public abstract class BaseClientNode : Node } await SendPacketAsync(packetToSend, cancellationToken); - return await WaitForPacketAsync(completionSource, cancellationToken); + TReceive response; + do + { + response = await WaitForPacketAsync(completionSource, cancellationToken); + if (requestPacket is null) + break; + + if (response is ResponsePacket responsePacket && responsePacket.CallbackId == requestPacket.CallbackId) + break; + } while (true); + + return response; } /// diff --git a/TcpDotNet/Protocol/Packets/ClientBound/PongPacket.cs b/TcpDotNet/Protocol/Packets/ClientBound/PongPacket.cs index e1c4119..302c46d 100644 --- a/TcpDotNet/Protocol/Packets/ClientBound/PongPacket.cs +++ b/TcpDotNet/Protocol/Packets/ClientBound/PongPacket.cs @@ -1,17 +1,21 @@ namespace TcpDotNet.Protocol.Packets.ClientBound; +/// +/// Represents a packet which performs a heartbeat response. +/// [Packet(0x7FFFFFF1)] -public sealed class PongPacket : Packet +public sealed class PongPacket : ResponsePacket { /// /// Initializes a new instance of the class. /// - public PongPacket(byte[] payload) + public PongPacket(long callbackId, byte[] payload) + : base(callbackId) { Payload = payload[..]; } - internal PongPacket() + internal PongPacket() : base(0) { Payload = Array.Empty(); } @@ -23,18 +27,18 @@ public sealed class PongPacket : Packet public byte[] Payload { get; private set; } /// - protected internal override Task DeserializeAsync(ProtocolReader reader) + protected internal override async Task DeserializeAsync(ProtocolReader reader) { + await base.DeserializeAsync(reader); int length = reader.ReadInt32(); Payload = reader.ReadBytes(length); - return Task.CompletedTask; } /// - protected internal override Task SerializeAsync(ProtocolWriter writer) + protected internal override async Task SerializeAsync(ProtocolWriter writer) { + await base.SerializeAsync(writer); writer.Write(Payload.Length); writer.Write(Payload); - return Task.CompletedTask; } } diff --git a/TcpDotNet/Protocol/Packets/ServerBound/PingPacket.cs b/TcpDotNet/Protocol/Packets/ServerBound/PingPacket.cs index 23e90e8..0a80a60 100644 --- a/TcpDotNet/Protocol/Packets/ServerBound/PingPacket.cs +++ b/TcpDotNet/Protocol/Packets/ServerBound/PingPacket.cs @@ -2,8 +2,11 @@ namespace TcpDotNet.Protocol.Packets.ServerBound; +/// +/// Represents a packet which performs a heartbeat request. +/// [Packet(0x7FFFFFF0)] -public sealed class PingPacket : Packet +public sealed class PingPacket : RequestPacket { /// /// Initializes a new instance of the class. @@ -22,18 +25,18 @@ public sealed class PingPacket : Packet public byte[] Payload { get; private set; } /// - protected internal override Task DeserializeAsync(ProtocolReader reader) + protected internal override async Task DeserializeAsync(ProtocolReader reader) { + await base.DeserializeAsync(reader); int length = reader.ReadInt32(); Payload = reader.ReadBytes(length); - return Task.CompletedTask; } /// - protected internal override Task SerializeAsync(ProtocolWriter writer) + protected internal override async Task SerializeAsync(ProtocolWriter writer) { + await base.SerializeAsync(writer); writer.Write(Payload.Length); writer.Write(Payload); - return Task.CompletedTask; } } diff --git a/TcpDotNet/Protocol/RequestPacket.cs b/TcpDotNet/Protocol/RequestPacket.cs new file mode 100644 index 0000000..6b17b4c --- /dev/null +++ b/TcpDotNet/Protocol/RequestPacket.cs @@ -0,0 +1,27 @@ +namespace TcpDotNet.Protocol; + +/// +/// Represents a request packet, which forms a request/response packet pair. +/// +public abstract class RequestPacket : Packet +{ + /// + /// Gets the request identifier. + /// + /// The request identifier. + public long CallbackId { get; internal set; } + + /// + protected internal override Task DeserializeAsync(ProtocolReader reader) + { + CallbackId = reader.ReadInt64(); + return Task.CompletedTask; + } + + /// + protected internal override Task SerializeAsync(ProtocolWriter writer) + { + writer.Write(CallbackId); + return Task.CompletedTask; + } +} diff --git a/TcpDotNet/Protocol/ResponsePacket.cs b/TcpDotNet/Protocol/ResponsePacket.cs new file mode 100644 index 0000000..6d675b7 --- /dev/null +++ b/TcpDotNet/Protocol/ResponsePacket.cs @@ -0,0 +1,36 @@ +namespace TcpDotNet.Protocol; + +/// +/// Represents a response packet, which forms a request/response packet pair. +/// +public abstract class ResponsePacket : Packet +{ + /// + /// Initializes a new instance of the class. + /// + /// The callback ID. + protected ResponsePacket(long callbackId) + { + CallbackId = callbackId; + } + + /// + /// Gets the response identifier. + /// + /// The response identifier. + public long CallbackId { get; private set; } + + /// + protected internal override Task DeserializeAsync(ProtocolReader reader) + { + CallbackId = reader.ReadInt64(); + return Task.CompletedTask; + } + + /// + protected internal override Task SerializeAsync(ProtocolWriter writer) + { + writer.Write(CallbackId); + return Task.CompletedTask; + } +} diff --git a/TcpDotNet/ProtocolClient.cs b/TcpDotNet/ProtocolClient.cs index 12120f0..f5a62bf 100644 --- a/TcpDotNet/ProtocolClient.cs +++ b/TcpDotNet/ProtocolClient.cs @@ -95,7 +95,7 @@ public sealed class ProtocolClient : BaseClientNode State = ClientState.Handshaking; var handshakeRequest = new HandshakeRequestPacket(ProtocolVersion); var handshakeResponse = - await SendAndReceive(handshakeRequest, cancellationToken); + await SendAndReceiveAsync(handshakeRequest, cancellationToken); if (handshakeResponse.HandshakeResponse != HandshakeResponse.Success) { @@ -117,7 +117,7 @@ public sealed class ProtocolClient : BaseClientNode Aes = CryptographyUtils.GenerateAes(key); var encryptionResponse = new EncryptionResponsePacket(encryptedPayload, rsa.Encrypt(key, true)); - var sessionPacket = await SendAndReceive(encryptionResponse, cancellationToken); + var sessionPacket = await SendAndReceiveAsync(encryptionResponse, cancellationToken); SessionId = sessionPacket.Session; UseEncryption = true;