From 96ca5d778b369b420880e4e1bc660f5ed4fb537c Mon Sep 17 00:00:00 2001 From: Oliver Booth Date: Wed, 18 May 2022 17:51:12 +0100 Subject: [PATCH] Add support for specific packet reading --- TcpDotNet.ClientIntegrationTest/Program.cs | 18 +--- TcpDotNet/BaseClientNode.cs | 114 ++++++++++++++++++++- 2 files changed, 118 insertions(+), 14 deletions(-) diff --git a/TcpDotNet.ClientIntegrationTest/Program.cs b/TcpDotNet.ClientIntegrationTest/Program.cs index 98d2eed..57231ee 100644 --- a/TcpDotNet.ClientIntegrationTest/Program.cs +++ b/TcpDotNet.ClientIntegrationTest/Program.cs @@ -9,18 +9,10 @@ client.RegisterPacketHandler(PacketHandler.Empty); await client.ConnectAsync(IPAddress.IPv6Loopback, 1234); Console.WriteLine($"Connected to {client.RemoteEndPoint}"); -Console.WriteLine("Sending ping packet..."); var ping = new PingPacket(); -await client.SendPacketAsync(ping); -Console.WriteLine("Waiting for response..."); -Packet? response = await client.ReadNextPacketAsync(); -if (response is PongPacket pong) -{ - Console.WriteLine("Received pong packet"); - Console.WriteLine(pong.Payload.SequenceEqual(ping.Payload) ? "Payload matches" : "Payload does not match"); -} -else -{ - Console.WriteLine("Received unknown packet"); -} +Console.WriteLine($"Sending ping packet with payload: {BitConverter.ToString(ping.Payload)}"); +var pong = await client.SendAndReceive(ping); + +Console.WriteLine($"Received pong packet with payload: {BitConverter.ToString(pong.Payload)}"); +Console.WriteLine(pong.Payload.SequenceEqual(ping.Payload) ? "Payload matches!" : "Payload does not match!"); diff --git a/TcpDotNet/BaseClientNode.cs b/TcpDotNet/BaseClientNode.cs index 2140300..4e856dd 100644 --- a/TcpDotNet/BaseClientNode.cs +++ b/TcpDotNet/BaseClientNode.cs @@ -1,4 +1,5 @@ -using System.IO.Compression; +using System.Collections.Concurrent; +using System.IO.Compression; using System.Net; using System.Net.Sockets; using System.Reflection; @@ -12,6 +13,8 @@ namespace TcpDotNet; /// public abstract class BaseClientNode : Node { + private readonly ConcurrentDictionary>> _packetCompletionSources = new(); + /// /// Gets a value indicating whether the client is connected. /// @@ -100,6 +103,15 @@ public abstract class BaseClientNode : Node if (RegisteredPacketHandlers.TryGetValue(packetType, out IReadOnlyCollection? handlers)) await Task.WhenAll(handlers.Select(h => h.HandleAsync(this, packet, cancellationToken))); + if (_packetCompletionSources.TryGetValue(packet.Id, out List>? completionSources)) + { + lock (completionSources) + { + foreach (TaskCompletionSource completionSource in completionSources) + completionSource.SetResult(packet); + } + } + return packet; } @@ -141,4 +153,104 @@ public abstract class BaseClientNode : Node await buffer.CopyToAsync(networkStream, cancellationToken); await networkStream.FlushAsync(cancellationToken); } + + /// + /// Sends a packet, and waits for a specific packet to be received. + /// + /// The packet to send. + /// A cancellation token that can be used to cancel the asynchronous operation. + /// The type of the packet to send. + /// The type of the packet to return. + /// The received packet. + /// + /// This method will consume all incoming packets, raising their associated handlers if such packets are recognised. + /// + public async Task SendAndReceive(TSend packetToSend, CancellationToken cancellationToken = default) + where TSend : Packet + where TReceive : Packet + { + var attribute = typeof(TReceive).GetCustomAttribute(); + if (attribute is null) + throw new ArgumentException($"The packet type {typeof(TReceive).Name} is not a valid packet."); + + var completionSource = new TaskCompletionSource(); + if (!_packetCompletionSources.TryGetValue(attribute.Id, out List>? completionSources)) + { + completionSources = new List>(); + _packetCompletionSources.TryAdd(attribute.Id, completionSources); + } + + lock (completionSources) + { + if (!completionSources.Contains(completionSource)) + completionSources.Add(completionSource); + } + + await SendPacketAsync(packetToSend, cancellationToken); + return await WaitForPacketAsync(completionSource, cancellationToken); + } + + /// + /// Waits for a specific packet to be received. + /// + /// A cancellation token that can be used to cancel the asynchronous operation. + /// The type of the packet to return. + /// The received packet. + /// + /// This method will consume all incoming packets, raising their associated handlers if such packets are recognised. + /// + public Task WaitForPacketAsync(CancellationToken cancellationToken = default) + where TPacket : Packet + { + var completionSource = new TaskCompletionSource(); + return WaitForPacketAsync(completionSource, cancellationToken); + } + + private async Task WaitForPacketAsync( + TaskCompletionSource completionSource, + CancellationToken cancellationToken = default + ) + where TPacket : Packet + { + var attribute = typeof(TPacket).GetCustomAttribute(); + if (attribute is null) + throw new ArgumentException($"The packet type {typeof(TPacket).Name} is not a valid packet."); + + if (!_packetCompletionSources.TryGetValue(attribute.Id, out List>? completionSources)) + { + completionSources = new List>(); + _packetCompletionSources.TryAdd(attribute.Id, completionSources); + } + + lock (completionSources) + { + if (!completionSources.Contains(completionSource)) + completionSources.Add(completionSource); + } + + _ = Task.Run(async () => + { + while (!cancellationToken.IsCancellationRequested) + { + Packet? packet = await ReadNextPacketAsync(cancellationToken); + if (packet is TPacket typedPacket) + { + completionSource.TrySetResult(typedPacket); + return; + } + } + }, cancellationToken); + + var packet = (TPacket) await Task.Run(() => completionSource.Task, cancellationToken); + if (_packetCompletionSources.TryGetValue(attribute.Id, out completionSources)) + { + lock (completionSources) + { + completionSources.Remove(completionSource); + if (completionSources.Count == 0) _packetCompletionSources.TryRemove(attribute.Id, out _); + } + } + + return packet; + } }