1
0
mirror of https://github.com/oliverbooth/TcpDotNet synced 2024-11-14 03:55:41 +00:00

refactor: generate callback ID from RequestPacket, not client

This commit is contained in:
Oliver Booth 2024-02-12 18:03:46 +00:00
parent 18b0ff3f25
commit a3e540ac83
Signed by: oliverbooth
GPG Key ID: E60B570D1B7557B5
2 changed files with 26 additions and 15 deletions

View File

@ -15,7 +15,6 @@ namespace TcpDotNet;
/// </summary> /// </summary>
public abstract class ClientNode : Node public abstract class ClientNode : Node
{ {
private readonly ObjectIDGenerator _callbackIdGenerator = new();
private readonly ConcurrentDictionary<int, List<TaskCompletionSource<Packet>>> _packetCompletionSources = new(); private readonly ConcurrentDictionary<int, List<TaskCompletionSource<Packet>>> _packetCompletionSources = new();
private EndPoint? _remoteEP; private EndPoint? _remoteEP;
@ -219,25 +218,27 @@ public abstract class ClientNode : Node
/// <summary> /// <summary>
/// Sends a packet, and waits for a specific packet to be received. /// Sends a packet, and waits for a specific packet to be received.
/// </summary> /// </summary>
/// <param name="packetToSend">The packet to send.</param> /// <param name="packet">The packet to send.</param>
/// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param> /// <param name="cancellationToken">
/// A cancellation token that can be used to cancel the asynchronous operation.
/// </param>
/// <typeparam name="TReceive">The type of the packet to return.</typeparam> /// <typeparam name="TReceive">The type of the packet to return.</typeparam>
/// <returns>The received packet.</returns> /// <returns>The received packet.</returns>
/// <remarks> /// <remarks>
/// This method will consume all incoming packets, raising their associated handlers if such packets are recognised. /// This method will consume all incoming packets, raising their associated handlers if such packets are
/// recognised.
/// </remarks> /// </remarks>
public async Task<TReceive> SendAndReceiveAsync<TReceive>(Packet packetToSend, /// <exception cref="ArgumentNullException"><paramref name="packet" /> is <see langword="null" />.</exception>
public async Task<TReceive> SendAndReceiveAsync<TReceive>(RequestPacket packet,
CancellationToken cancellationToken = default) CancellationToken cancellationToken = default)
where TReceive : Packet where TReceive : Packet
{ {
if (packet is null) throw new ArgumentNullException(nameof(packet));
var attribute = typeof(TReceive).GetCustomAttribute<PacketAttribute>(); var attribute = typeof(TReceive).GetCustomAttribute<PacketAttribute>();
if (attribute is null) if (attribute is null)
throw new ArgumentException($"The packet type {typeof(TReceive).Name} is not a valid packet."); throw new ArgumentException($"The packet type {typeof(TReceive).Name} is not a valid packet.");
var requestPacket = packetToSend as RequestPacket;
if (requestPacket is not null)
requestPacket.CallbackId = _callbackIdGenerator.GetId(packetToSend, out _);
var completionSource = new TaskCompletionSource<Packet>(); var completionSource = new TaskCompletionSource<Packet>();
if (!_packetCompletionSources.TryGetValue(attribute.Id, if (!_packetCompletionSources.TryGetValue(attribute.Id,
out List<TaskCompletionSource<Packet>>? completionSources)) out List<TaskCompletionSource<Packet>>? completionSources))
@ -252,15 +253,12 @@ public abstract class ClientNode : Node
completionSources.Add(completionSource); completionSources.Add(completionSource);
} }
await SendPacketAsync(packetToSend, cancellationToken); await SendPacketAsync(packet, cancellationToken);
TReceive response; TReceive response;
do do
{ {
response = await WaitForPacketAsync<TReceive>(completionSource, cancellationToken); response = await WaitForPacketAsync<TReceive>(completionSource, cancellationToken);
if (requestPacket is null) if (response is ResponsePacket responsePacket && responsePacket.CallbackId == packet.CallbackId)
break;
if (response is ResponsePacket responsePacket && responsePacket.CallbackId == requestPacket.CallbackId)
break; break;
} while (true); } while (true);

View File

@ -1,3 +1,6 @@
using System.Buffers.Binary;
using System.Security.Cryptography;
namespace TcpDotNet.Protocol; namespace TcpDotNet.Protocol;
/// <summary> /// <summary>
@ -5,11 +8,21 @@ namespace TcpDotNet.Protocol;
/// </summary> /// </summary>
public abstract class RequestPacket : Packet public abstract class RequestPacket : Packet
{ {
/// <summary>
/// Initializes a new instance of the <see cref="RequestPacket" /> class.
/// </summary>
protected RequestPacket()
{
Span<byte> buffer = stackalloc byte[8];
RandomNumberGenerator.Fill(buffer);
CallbackId = BinaryPrimitives.ReadInt64BigEndian(buffer);
}
/// <summary> /// <summary>
/// Gets the request identifier. /// Gets the request identifier.
/// </summary> /// </summary>
/// <value>The request identifier.</value> /// <value>The request identifier.</value>
public long CallbackId { get; internal set; } public long CallbackId { get; private set; }
/// <inheritdoc /> /// <inheritdoc />
protected internal override void Deserialize(ProtocolReader reader) protected internal override void Deserialize(ProtocolReader reader)