refactor!: use completion source for expected response

This commit changes the constraint on TReceive making it a breaking change!
This commit is contained in:
Oliver Booth 2024-02-12 19:26:33 +00:00
parent 0a9348fd66
commit 362c51cd09
Signed by: oliverbooth
GPG Key ID: E60B570D1B7557B5
2 changed files with 12 additions and 34 deletions

View File

@ -2,7 +2,6 @@ using System.Collections.Concurrent;
using System.Net;
using System.Net.Sockets;
using System.Reflection;
using System.Runtime.Serialization;
using Chilkat;
using TcpDotNet.Protocol;
using Stream = System.IO.Stream;
@ -16,6 +15,7 @@ namespace TcpDotNet;
public abstract class ClientNode : Node
{
private readonly ConcurrentDictionary<int, List<TaskCompletionSource<Packet>>> _packetCompletionSources = new();
private readonly ConcurrentDictionary<long, TaskCompletionSource<ResponsePacket>> _callbackCompletionSources = new();
private EndPoint? _remoteEP;
/// <summary>
@ -231,38 +231,16 @@ public abstract class ClientNode : Node
/// <exception cref="ArgumentNullException"><paramref name="packet" /> is <see langword="null" />.</exception>
public async Task<TReceive> SendAndReceiveAsync<TReceive>(RequestPacket packet,
CancellationToken cancellationToken = default)
where TReceive : Packet
where TReceive : ResponsePacket
{
if (packet is null) throw new ArgumentNullException(nameof(packet));
var attribute = typeof(TReceive).GetCustomAttribute<PacketAttribute>();
if (attribute is null)
throw new ArgumentException($"The packet type {typeof(TReceive).Name} is not a valid packet.");
var completionSource = new TaskCompletionSource<Packet>();
if (!_packetCompletionSources.TryGetValue(attribute.Id,
out List<TaskCompletionSource<Packet>>? completionSources))
{
completionSources = new List<TaskCompletionSource<Packet>>();
_packetCompletionSources.TryAdd(attribute.Id, completionSources);
}
lock (completionSources)
{
if (!completionSources.Contains(completionSource))
completionSources.Add(completionSource);
}
var completionSource = new TaskCompletionSource<ResponsePacket>();
if (!_callbackCompletionSources.TryAdd(packet.CallbackId, completionSource))
throw new InvalidOperationException("Duplicate packet sent");
await SendPacketAsync(packet, cancellationToken);
TReceive response;
do
{
response = await WaitForPacketAsync<TReceive>(completionSource, cancellationToken);
if (response is ResponsePacket responsePacket && responsePacket.CallbackId == packet.CallbackId)
break;
} while (true);
return response;
return (TReceive)await completionSource.Task;
}
/// <summary>
@ -281,10 +259,8 @@ public abstract class ClientNode : Node
return WaitForPacketAsync<TPacket>(completionSource, cancellationToken);
}
private async Task<TPacket> WaitForPacketAsync<TPacket>(
TaskCompletionSource<Packet> completionSource,
CancellationToken cancellationToken = default
)
private async Task<TPacket> WaitForPacketAsync<TPacket>(TaskCompletionSource<Packet> completionSource,
CancellationToken cancellationToken = default)
where TPacket : Packet
{
var attribute = typeof(TPacket).GetCustomAttribute<PacketAttribute>();

View File

@ -6,14 +6,16 @@ namespace TcpDotNet.Protocol.Packets.ClientBound;
/// Represents a packet which responds to a <see cref="HandshakeRequestPacket" />.
/// </summary>
[Packet(0x7FFFFFE1)]
internal sealed class HandshakeResponsePacket : Packet
internal sealed class HandshakeResponsePacket : ResponsePacket
{
/// <summary>
/// Initializes a new instance of the <see cref="HandshakeResponsePacket" /> class.
/// </summary>
/// <param name="callbackId">The callback ID.</param>
/// <param name="protocolVersion">The requested protocol version.</param>
/// <param name="handshakeResponse">The handshake response.</param>
public HandshakeResponsePacket(int protocolVersion, HandshakeResponse handshakeResponse)
public HandshakeResponsePacket(long callbackId, int protocolVersion, HandshakeResponse handshakeResponse)
: base(callbackId)
{
ProtocolVersion = protocolVersion;
HandshakeResponse = handshakeResponse;