Add support for specific packet reading

This commit is contained in:
Oliver Booth 2022-05-18 17:51:12 +01:00
parent 4d317298b8
commit 96ca5d778b
No known key found for this signature in database
GPG Key ID: 32A00B35503AF634
2 changed files with 118 additions and 14 deletions

View File

@ -9,18 +9,10 @@ client.RegisterPacketHandler(PacketHandler<PongPacket>.Empty);
await client.ConnectAsync(IPAddress.IPv6Loopback, 1234);
Console.WriteLine($"Connected to {client.RemoteEndPoint}");
Console.WriteLine("Sending ping packet...");
var ping = new PingPacket();
await client.SendPacketAsync(ping);
Console.WriteLine("Waiting for response...");
Packet? response = await client.ReadNextPacketAsync();
if (response is PongPacket pong)
{
Console.WriteLine("Received pong packet");
Console.WriteLine(pong.Payload.SequenceEqual(ping.Payload) ? "Payload matches" : "Payload does not match");
}
else
{
Console.WriteLine("Received unknown packet");
}
Console.WriteLine($"Sending ping packet with payload: {BitConverter.ToString(ping.Payload)}");
var pong = await client.SendAndReceive<PingPacket, PongPacket>(ping);
Console.WriteLine($"Received pong packet with payload: {BitConverter.ToString(pong.Payload)}");
Console.WriteLine(pong.Payload.SequenceEqual(ping.Payload) ? "Payload matches!" : "Payload does not match!");

View File

@ -1,4 +1,5 @@
using System.IO.Compression;
using System.Collections.Concurrent;
using System.IO.Compression;
using System.Net;
using System.Net.Sockets;
using System.Reflection;
@ -12,6 +13,8 @@ namespace TcpDotNet;
/// </summary>
public abstract class BaseClientNode : Node
{
private readonly ConcurrentDictionary<int, List<TaskCompletionSource<Packet>>> _packetCompletionSources = new();
/// <summary>
/// Gets a value indicating whether the client is connected.
/// </summary>
@ -100,6 +103,15 @@ public abstract class BaseClientNode : Node
if (RegisteredPacketHandlers.TryGetValue(packetType, out IReadOnlyCollection<PacketHandler>? handlers))
await Task.WhenAll(handlers.Select(h => h.HandleAsync(this, packet, cancellationToken)));
if (_packetCompletionSources.TryGetValue(packet.Id, out List<TaskCompletionSource<Packet>>? completionSources))
{
lock (completionSources)
{
foreach (TaskCompletionSource<Packet> completionSource in completionSources)
completionSource.SetResult(packet);
}
}
return packet;
}
@ -141,4 +153,104 @@ public abstract class BaseClientNode : Node
await buffer.CopyToAsync(networkStream, cancellationToken);
await networkStream.FlushAsync(cancellationToken);
}
/// <summary>
/// Sends a packet, and waits for a specific packet to be received.
/// </summary>
/// <param name="packetToSend">The packet to send.</param>
/// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param>
/// <typeparam name="TSend">The type of the packet to send.</typeparam>
/// <typeparam name="TReceive">The type of the packet to return.</typeparam>
/// <returns>The received packet.</returns>
/// <remarks>
/// This method will consume all incoming packets, raising their associated handlers if such packets are recognised.
/// </remarks>
public async Task<TReceive> SendAndReceive<TSend, TReceive>(TSend packetToSend, CancellationToken cancellationToken = default)
where TSend : Packet
where TReceive : Packet
{
var attribute = typeof(TReceive).GetCustomAttribute<PacketAttribute>();
if (attribute is null)
throw new ArgumentException($"The packet type {typeof(TReceive).Name} is not a valid packet.");
var completionSource = new TaskCompletionSource<Packet>();
if (!_packetCompletionSources.TryGetValue(attribute.Id, out List<TaskCompletionSource<Packet>>? completionSources))
{
completionSources = new List<TaskCompletionSource<Packet>>();
_packetCompletionSources.TryAdd(attribute.Id, completionSources);
}
lock (completionSources)
{
if (!completionSources.Contains(completionSource))
completionSources.Add(completionSource);
}
await SendPacketAsync(packetToSend, cancellationToken);
return await WaitForPacketAsync<TReceive>(completionSource, cancellationToken);
}
/// <summary>
/// Waits for a specific packet to be received.
/// </summary>
/// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param>
/// <typeparam name="TPacket">The type of the packet to return.</typeparam>
/// <returns>The received packet.</returns>
/// <remarks>
/// This method will consume all incoming packets, raising their associated handlers if such packets are recognised.
/// </remarks>
public Task<TPacket> WaitForPacketAsync<TPacket>(CancellationToken cancellationToken = default)
where TPacket : Packet
{
var completionSource = new TaskCompletionSource<Packet>();
return WaitForPacketAsync<TPacket>(completionSource, cancellationToken);
}
private async Task<TPacket> WaitForPacketAsync<TPacket>(
TaskCompletionSource<Packet> completionSource,
CancellationToken cancellationToken = default
)
where TPacket : Packet
{
var attribute = typeof(TPacket).GetCustomAttribute<PacketAttribute>();
if (attribute is null)
throw new ArgumentException($"The packet type {typeof(TPacket).Name} is not a valid packet.");
if (!_packetCompletionSources.TryGetValue(attribute.Id, out List<TaskCompletionSource<Packet>>? completionSources))
{
completionSources = new List<TaskCompletionSource<Packet>>();
_packetCompletionSources.TryAdd(attribute.Id, completionSources);
}
lock (completionSources)
{
if (!completionSources.Contains(completionSource))
completionSources.Add(completionSource);
}
_ = Task.Run(async () =>
{
while (!cancellationToken.IsCancellationRequested)
{
Packet? packet = await ReadNextPacketAsync(cancellationToken);
if (packet is TPacket typedPacket)
{
completionSource.TrySetResult(typedPacket);
return;
}
}
}, cancellationToken);
var packet = (TPacket) await Task.Run(() => completionSource.Task, cancellationToken);
if (_packetCompletionSources.TryGetValue(attribute.Id, out completionSources))
{
lock (completionSources)
{
completionSources.Remove(completionSource);
if (completionSources.Count == 0) _packetCompletionSources.TryRemove(attribute.Id, out _);
}
}
return packet;
}
}