diff --git a/TcpDotNet/ClientNode.cs b/TcpDotNet/ClientNode.cs index 7a931b4..bc8ebd9 100644 --- a/TcpDotNet/ClientNode.cs +++ b/TcpDotNet/ClientNode.cs @@ -2,7 +2,6 @@ using System.Collections.Concurrent; using System.Net; using System.Net.Sockets; using System.Reflection; -using System.Runtime.Serialization; using Chilkat; using TcpDotNet.Protocol; using Stream = System.IO.Stream; @@ -16,6 +15,7 @@ namespace TcpDotNet; public abstract class ClientNode : Node { private readonly ConcurrentDictionary>> _packetCompletionSources = new(); + private readonly ConcurrentDictionary> _callbackCompletionSources = new(); private EndPoint? _remoteEP; /// @@ -231,38 +231,16 @@ public abstract class ClientNode : Node /// is . public async Task SendAndReceiveAsync(RequestPacket packet, CancellationToken cancellationToken = default) - where TReceive : Packet + where TReceive : ResponsePacket { if (packet is null) throw new ArgumentNullException(nameof(packet)); - var attribute = typeof(TReceive).GetCustomAttribute(); - if (attribute is null) - throw new ArgumentException($"The packet type {typeof(TReceive).Name} is not a valid packet."); - - var completionSource = new TaskCompletionSource(); - if (!_packetCompletionSources.TryGetValue(attribute.Id, - out List>? completionSources)) - { - completionSources = new List>(); - _packetCompletionSources.TryAdd(attribute.Id, completionSources); - } - - lock (completionSources) - { - if (!completionSources.Contains(completionSource)) - completionSources.Add(completionSource); - } + var completionSource = new TaskCompletionSource(); + if (!_callbackCompletionSources.TryAdd(packet.CallbackId, completionSource)) + throw new InvalidOperationException("Duplicate packet sent"); await SendPacketAsync(packet, cancellationToken); - TReceive response; - do - { - response = await WaitForPacketAsync(completionSource, cancellationToken); - if (response is ResponsePacket responsePacket && responsePacket.CallbackId == packet.CallbackId) - break; - } while (true); - - return response; + return (TReceive)await completionSource.Task; } /// @@ -281,10 +259,8 @@ public abstract class ClientNode : Node return WaitForPacketAsync(completionSource, cancellationToken); } - private async Task WaitForPacketAsync( - TaskCompletionSource completionSource, - CancellationToken cancellationToken = default - ) + private async Task WaitForPacketAsync(TaskCompletionSource completionSource, + CancellationToken cancellationToken = default) where TPacket : Packet { var attribute = typeof(TPacket).GetCustomAttribute(); diff --git a/TcpDotNet/Protocol/Packets/ClientBound/HandshakeResponsePacket.cs b/TcpDotNet/Protocol/Packets/ClientBound/HandshakeResponsePacket.cs index 2e3ecd2..b0eb8e2 100644 --- a/TcpDotNet/Protocol/Packets/ClientBound/HandshakeResponsePacket.cs +++ b/TcpDotNet/Protocol/Packets/ClientBound/HandshakeResponsePacket.cs @@ -6,14 +6,16 @@ namespace TcpDotNet.Protocol.Packets.ClientBound; /// Represents a packet which responds to a . /// [Packet(0x7FFFFFE1)] -internal sealed class HandshakeResponsePacket : Packet +internal sealed class HandshakeResponsePacket : ResponsePacket { /// /// Initializes a new instance of the class. /// + /// The callback ID. /// The requested protocol version. /// The handshake response. - public HandshakeResponsePacket(int protocolVersion, HandshakeResponse handshakeResponse) + public HandshakeResponsePacket(long callbackId, int protocolVersion, HandshakeResponse handshakeResponse) + : base(callbackId) { ProtocolVersion = protocolVersion; HandshakeResponse = handshakeResponse;