diff --git a/TcpDotNet/ClientNode.cs b/TcpDotNet/ClientNode.cs index 6ff6b55..7a931b4 100644 --- a/TcpDotNet/ClientNode.cs +++ b/TcpDotNet/ClientNode.cs @@ -15,7 +15,6 @@ namespace TcpDotNet; /// public abstract class ClientNode : Node { - private readonly ObjectIDGenerator _callbackIdGenerator = new(); private readonly ConcurrentDictionary>> _packetCompletionSources = new(); private EndPoint? _remoteEP; @@ -219,25 +218,27 @@ public abstract class ClientNode : Node /// /// Sends a packet, and waits for a specific packet to be received. /// - /// The packet to send. - /// A cancellation token that can be used to cancel the asynchronous operation. + /// The packet to send. + /// + /// A cancellation token that can be used to cancel the asynchronous operation. + /// /// The type of the packet to return. /// The received packet. /// - /// This method will consume all incoming packets, raising their associated handlers if such packets are recognised. + /// This method will consume all incoming packets, raising their associated handlers if such packets are + /// recognised. /// - public async Task SendAndReceiveAsync(Packet packetToSend, + /// is . + public async Task SendAndReceiveAsync(RequestPacket packet, CancellationToken cancellationToken = default) where TReceive : Packet { + if (packet is null) throw new ArgumentNullException(nameof(packet)); + var attribute = typeof(TReceive).GetCustomAttribute(); if (attribute is null) throw new ArgumentException($"The packet type {typeof(TReceive).Name} is not a valid packet."); - var requestPacket = packetToSend as RequestPacket; - if (requestPacket is not null) - requestPacket.CallbackId = _callbackIdGenerator.GetId(packetToSend, out _); - var completionSource = new TaskCompletionSource(); if (!_packetCompletionSources.TryGetValue(attribute.Id, out List>? completionSources)) @@ -252,15 +253,12 @@ public abstract class ClientNode : Node completionSources.Add(completionSource); } - await SendPacketAsync(packetToSend, cancellationToken); + await SendPacketAsync(packet, cancellationToken); TReceive response; do { response = await WaitForPacketAsync(completionSource, cancellationToken); - if (requestPacket is null) - break; - - if (response is ResponsePacket responsePacket && responsePacket.CallbackId == requestPacket.CallbackId) + if (response is ResponsePacket responsePacket && responsePacket.CallbackId == packet.CallbackId) break; } while (true); diff --git a/TcpDotNet/Protocol/RequestPacket.cs b/TcpDotNet/Protocol/RequestPacket.cs index f1ebe51..21037d9 100644 --- a/TcpDotNet/Protocol/RequestPacket.cs +++ b/TcpDotNet/Protocol/RequestPacket.cs @@ -1,3 +1,6 @@ +using System.Buffers.Binary; +using System.Security.Cryptography; + namespace TcpDotNet.Protocol; /// @@ -5,11 +8,21 @@ namespace TcpDotNet.Protocol; /// public abstract class RequestPacket : Packet { + /// + /// Initializes a new instance of the class. + /// + protected RequestPacket() + { + Span buffer = stackalloc byte[8]; + RandomNumberGenerator.Fill(buffer); + CallbackId = BinaryPrimitives.ReadInt64BigEndian(buffer); + } + /// /// Gets the request identifier. /// /// The request identifier. - public long CallbackId { get; internal set; } + public long CallbackId { get; private set; } /// protected internal override void Deserialize(ProtocolReader reader)