Add request/response packets

Introduces the concept of a "callback ID". In the event that more than one of the same packet type are received, a callback ID differentiates them.
This commit is contained in:
Oliver Booth 2022-07-07 13:17:13 +01:00
parent 869e18d615
commit 19869dddcb
No known key found for this signature in database
GPG Key ID: 32A00B35503AF634
8 changed files with 106 additions and 18 deletions

View File

@ -17,7 +17,7 @@ Console.WriteLine($"Connected to {client.RemoteEndPoint}. My session is {client.
var ping = new PingPacket();
Console.WriteLine($"Sending ping packet with payload: {BitConverter.ToString(ping.Payload)}");
var pong = await client.SendAndReceive<PingPacket, PongPacket>(ping);
var pong = await client.SendAndReceiveAsync<PingPacket, PongPacket>(ping);
Console.WriteLine($"Received pong packet with payload: {BitConverter.ToString(pong.Payload)}");
Console.WriteLine(pong.Payload.SequenceEqual(ping.Payload) ? "Payload matches!" : "Payload does not match!");

View File

@ -19,7 +19,7 @@ internal sealed class PingPacketHandler : PacketHandler<PingPacket>
public override async Task HandleAsync(BaseClientNode recipient, PingPacket packet, CancellationToken cancellationToken = default)
{
Console.WriteLine($"Client {recipient.SessionId} sent ping with payload {BitConverter.ToString(packet.Payload)}");
var pong = new PongPacket(packet.Payload);
var pong = new PongPacket(packet.CallbackId, packet.Payload);
await recipient.SendPacketAsync(pong, cancellationToken);
}
}

View File

@ -2,6 +2,7 @@
using System.Net;
using System.Net.Sockets;
using System.Reflection;
using System.Runtime.Serialization;
using Chilkat;
using TcpDotNet.Protocol;
using Stream = System.IO.Stream;
@ -14,6 +15,7 @@ namespace TcpDotNet;
/// </summary>
public abstract class BaseClientNode : Node
{
private readonly ObjectIDGenerator _callbackIdGenerator = new();
private readonly ConcurrentDictionary<int, List<TaskCompletionSource<Packet>>> _packetCompletionSources = new();
/// <summary>
@ -212,7 +214,8 @@ public abstract class BaseClientNode : Node
/// <remarks>
/// This method will consume all incoming packets, raising their associated handlers if such packets are recognised.
/// </remarks>
public async Task<TReceive> SendAndReceive<TSend, TReceive>(TSend packetToSend, CancellationToken cancellationToken = default)
public async Task<TReceive> SendAndReceiveAsync<TSend, TReceive>(TSend packetToSend,
CancellationToken cancellationToken = default)
where TSend : Packet
where TReceive : Packet
{
@ -220,6 +223,10 @@ public abstract class BaseClientNode : Node
if (attribute is null)
throw new ArgumentException($"The packet type {typeof(TReceive).Name} is not a valid packet.");
var requestPacket = packetToSend as RequestPacket;
if (requestPacket is not null)
requestPacket.CallbackId = _callbackIdGenerator.GetId(this, out _);
var completionSource = new TaskCompletionSource<Packet>();
if (!_packetCompletionSources.TryGetValue(attribute.Id, out List<TaskCompletionSource<Packet>>? completionSources))
{
@ -234,7 +241,18 @@ public abstract class BaseClientNode : Node
}
await SendPacketAsync(packetToSend, cancellationToken);
return await WaitForPacketAsync<TReceive>(completionSource, cancellationToken);
TReceive response;
do
{
response = await WaitForPacketAsync<TReceive>(completionSource, cancellationToken);
if (requestPacket is null)
break;
if (response is ResponsePacket responsePacket && responsePacket.CallbackId == requestPacket.CallbackId)
break;
} while (true);
return response;
}
/// <summary>

View File

@ -1,17 +1,21 @@
namespace TcpDotNet.Protocol.Packets.ClientBound;
/// <summary>
/// Represents a packet which performs a heartbeat response.
/// </summary>
[Packet(0x7FFFFFF1)]
public sealed class PongPacket : Packet
public sealed class PongPacket : ResponsePacket
{
/// <summary>
/// Initializes a new instance of the <see cref="PongPacket" /> class.
/// </summary>
public PongPacket(byte[] payload)
public PongPacket(long callbackId, byte[] payload)
: base(callbackId)
{
Payload = payload[..];
}
internal PongPacket()
internal PongPacket() : base(0)
{
Payload = Array.Empty<byte>();
}
@ -23,18 +27,18 @@ public sealed class PongPacket : Packet
public byte[] Payload { get; private set; }
/// <inheritdoc />
protected internal override Task DeserializeAsync(ProtocolReader reader)
protected internal override async Task DeserializeAsync(ProtocolReader reader)
{
await base.DeserializeAsync(reader);
int length = reader.ReadInt32();
Payload = reader.ReadBytes(length);
return Task.CompletedTask;
}
/// <inheritdoc />
protected internal override Task SerializeAsync(ProtocolWriter writer)
protected internal override async Task SerializeAsync(ProtocolWriter writer)
{
await base.SerializeAsync(writer);
writer.Write(Payload.Length);
writer.Write(Payload);
return Task.CompletedTask;
}
}

View File

@ -2,8 +2,11 @@
namespace TcpDotNet.Protocol.Packets.ServerBound;
/// <summary>
/// Represents a packet which performs a heartbeat request.
/// </summary>
[Packet(0x7FFFFFF0)]
public sealed class PingPacket : Packet
public sealed class PingPacket : RequestPacket
{
/// <summary>
/// Initializes a new instance of the <see cref="PingPacket" /> class.
@ -22,18 +25,18 @@ public sealed class PingPacket : Packet
public byte[] Payload { get; private set; }
/// <inheritdoc />
protected internal override Task DeserializeAsync(ProtocolReader reader)
protected internal override async Task DeserializeAsync(ProtocolReader reader)
{
await base.DeserializeAsync(reader);
int length = reader.ReadInt32();
Payload = reader.ReadBytes(length);
return Task.CompletedTask;
}
/// <inheritdoc />
protected internal override Task SerializeAsync(ProtocolWriter writer)
protected internal override async Task SerializeAsync(ProtocolWriter writer)
{
await base.SerializeAsync(writer);
writer.Write(Payload.Length);
writer.Write(Payload);
return Task.CompletedTask;
}
}

View File

@ -0,0 +1,27 @@
namespace TcpDotNet.Protocol;
/// <summary>
/// Represents a request packet, which forms a request/response packet pair.
/// </summary>
public abstract class RequestPacket : Packet
{
/// <summary>
/// Gets the request identifier.
/// </summary>
/// <value>The request identifier.</value>
public long CallbackId { get; internal set; }
/// <inheritdoc />
protected internal override Task DeserializeAsync(ProtocolReader reader)
{
CallbackId = reader.ReadInt64();
return Task.CompletedTask;
}
/// <inheritdoc />
protected internal override Task SerializeAsync(ProtocolWriter writer)
{
writer.Write(CallbackId);
return Task.CompletedTask;
}
}

View File

@ -0,0 +1,36 @@
namespace TcpDotNet.Protocol;
/// <summary>
/// Represents a response packet, which forms a request/response packet pair.
/// </summary>
public abstract class ResponsePacket : Packet
{
/// <summary>
/// Initializes a new instance of the <see cref="ResponsePacket" /> class.
/// </summary>
/// <param name="callbackId">The callback ID.</param>
protected ResponsePacket(long callbackId)
{
CallbackId = callbackId;
}
/// <summary>
/// Gets the response identifier.
/// </summary>
/// <value>The response identifier.</value>
public long CallbackId { get; private set; }
/// <inheritdoc />
protected internal override Task DeserializeAsync(ProtocolReader reader)
{
CallbackId = reader.ReadInt64();
return Task.CompletedTask;
}
/// <inheritdoc />
protected internal override Task SerializeAsync(ProtocolWriter writer)
{
writer.Write(CallbackId);
return Task.CompletedTask;
}
}

View File

@ -95,7 +95,7 @@ public sealed class ProtocolClient : BaseClientNode
State = ClientState.Handshaking;
var handshakeRequest = new HandshakeRequestPacket(ProtocolVersion);
var handshakeResponse =
await SendAndReceive<HandshakeRequestPacket, HandshakeResponsePacket>(handshakeRequest, cancellationToken);
await SendAndReceiveAsync<HandshakeRequestPacket, HandshakeResponsePacket>(handshakeRequest, cancellationToken);
if (handshakeResponse.HandshakeResponse != HandshakeResponse.Success)
{
@ -117,7 +117,7 @@ public sealed class ProtocolClient : BaseClientNode
Aes = CryptographyUtils.GenerateAes(key);
var encryptionResponse = new EncryptionResponsePacket(encryptedPayload, rsa.Encrypt(key, true));
var sessionPacket = await SendAndReceive<EncryptionResponsePacket, SessionExchangePacket>(encryptionResponse, cancellationToken);
var sessionPacket = await SendAndReceiveAsync<EncryptionResponsePacket, SessionExchangePacket>(encryptionResponse, cancellationToken);
SessionId = sessionPacket.Session;
UseEncryption = true;