Merge branch 'develop' into main

This commit is contained in:
Oliver Booth 2022-07-07 13:17:19 +01:00
commit 2d22b1599d
No known key found for this signature in database
GPG Key ID: 32A00B35503AF634
8 changed files with 106 additions and 18 deletions

View File

@ -17,7 +17,7 @@ Console.WriteLine($"Connected to {client.RemoteEndPoint}. My session is {client.
var ping = new PingPacket(); var ping = new PingPacket();
Console.WriteLine($"Sending ping packet with payload: {BitConverter.ToString(ping.Payload)}"); Console.WriteLine($"Sending ping packet with payload: {BitConverter.ToString(ping.Payload)}");
var pong = await client.SendAndReceive<PingPacket, PongPacket>(ping); var pong = await client.SendAndReceiveAsync<PingPacket, PongPacket>(ping);
Console.WriteLine($"Received pong packet with payload: {BitConverter.ToString(pong.Payload)}"); Console.WriteLine($"Received pong packet with payload: {BitConverter.ToString(pong.Payload)}");
Console.WriteLine(pong.Payload.SequenceEqual(ping.Payload) ? "Payload matches!" : "Payload does not match!"); Console.WriteLine(pong.Payload.SequenceEqual(ping.Payload) ? "Payload matches!" : "Payload does not match!");

View File

@ -19,7 +19,7 @@ internal sealed class PingPacketHandler : PacketHandler<PingPacket>
public override async Task HandleAsync(BaseClientNode recipient, PingPacket packet, CancellationToken cancellationToken = default) public override async Task HandleAsync(BaseClientNode recipient, PingPacket packet, CancellationToken cancellationToken = default)
{ {
Console.WriteLine($"Client {recipient.SessionId} sent ping with payload {BitConverter.ToString(packet.Payload)}"); Console.WriteLine($"Client {recipient.SessionId} sent ping with payload {BitConverter.ToString(packet.Payload)}");
var pong = new PongPacket(packet.Payload); var pong = new PongPacket(packet.CallbackId, packet.Payload);
await recipient.SendPacketAsync(pong, cancellationToken); await recipient.SendPacketAsync(pong, cancellationToken);
} }
} }

View File

@ -2,6 +2,7 @@
using System.Net; using System.Net;
using System.Net.Sockets; using System.Net.Sockets;
using System.Reflection; using System.Reflection;
using System.Runtime.Serialization;
using Chilkat; using Chilkat;
using TcpDotNet.Protocol; using TcpDotNet.Protocol;
using Stream = System.IO.Stream; using Stream = System.IO.Stream;
@ -14,6 +15,7 @@ namespace TcpDotNet;
/// </summary> /// </summary>
public abstract class BaseClientNode : Node public abstract class BaseClientNode : Node
{ {
private readonly ObjectIDGenerator _callbackIdGenerator = new();
private readonly ConcurrentDictionary<int, List<TaskCompletionSource<Packet>>> _packetCompletionSources = new(); private readonly ConcurrentDictionary<int, List<TaskCompletionSource<Packet>>> _packetCompletionSources = new();
/// <summary> /// <summary>
@ -212,7 +214,8 @@ public abstract class BaseClientNode : Node
/// <remarks> /// <remarks>
/// This method will consume all incoming packets, raising their associated handlers if such packets are recognised. /// This method will consume all incoming packets, raising their associated handlers if such packets are recognised.
/// </remarks> /// </remarks>
public async Task<TReceive> SendAndReceive<TSend, TReceive>(TSend packetToSend, CancellationToken cancellationToken = default) public async Task<TReceive> SendAndReceiveAsync<TSend, TReceive>(TSend packetToSend,
CancellationToken cancellationToken = default)
where TSend : Packet where TSend : Packet
where TReceive : Packet where TReceive : Packet
{ {
@ -220,6 +223,10 @@ public abstract class BaseClientNode : Node
if (attribute is null) if (attribute is null)
throw new ArgumentException($"The packet type {typeof(TReceive).Name} is not a valid packet."); throw new ArgumentException($"The packet type {typeof(TReceive).Name} is not a valid packet.");
var requestPacket = packetToSend as RequestPacket;
if (requestPacket is not null)
requestPacket.CallbackId = _callbackIdGenerator.GetId(this, out _);
var completionSource = new TaskCompletionSource<Packet>(); var completionSource = new TaskCompletionSource<Packet>();
if (!_packetCompletionSources.TryGetValue(attribute.Id, out List<TaskCompletionSource<Packet>>? completionSources)) if (!_packetCompletionSources.TryGetValue(attribute.Id, out List<TaskCompletionSource<Packet>>? completionSources))
{ {
@ -234,7 +241,18 @@ public abstract class BaseClientNode : Node
} }
await SendPacketAsync(packetToSend, cancellationToken); await SendPacketAsync(packetToSend, cancellationToken);
return await WaitForPacketAsync<TReceive>(completionSource, cancellationToken); TReceive response;
do
{
response = await WaitForPacketAsync<TReceive>(completionSource, cancellationToken);
if (requestPacket is null)
break;
if (response is ResponsePacket responsePacket && responsePacket.CallbackId == requestPacket.CallbackId)
break;
} while (true);
return response;
} }
/// <summary> /// <summary>

View File

@ -1,17 +1,21 @@
namespace TcpDotNet.Protocol.Packets.ClientBound; namespace TcpDotNet.Protocol.Packets.ClientBound;
/// <summary>
/// Represents a packet which performs a heartbeat response.
/// </summary>
[Packet(0x7FFFFFF1)] [Packet(0x7FFFFFF1)]
public sealed class PongPacket : Packet public sealed class PongPacket : ResponsePacket
{ {
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="PongPacket" /> class. /// Initializes a new instance of the <see cref="PongPacket" /> class.
/// </summary> /// </summary>
public PongPacket(byte[] payload) public PongPacket(long callbackId, byte[] payload)
: base(callbackId)
{ {
Payload = payload[..]; Payload = payload[..];
} }
internal PongPacket() internal PongPacket() : base(0)
{ {
Payload = Array.Empty<byte>(); Payload = Array.Empty<byte>();
} }
@ -23,18 +27,18 @@ public sealed class PongPacket : Packet
public byte[] Payload { get; private set; } public byte[] Payload { get; private set; }
/// <inheritdoc /> /// <inheritdoc />
protected internal override Task DeserializeAsync(ProtocolReader reader) protected internal override async Task DeserializeAsync(ProtocolReader reader)
{ {
await base.DeserializeAsync(reader);
int length = reader.ReadInt32(); int length = reader.ReadInt32();
Payload = reader.ReadBytes(length); Payload = reader.ReadBytes(length);
return Task.CompletedTask;
} }
/// <inheritdoc /> /// <inheritdoc />
protected internal override Task SerializeAsync(ProtocolWriter writer) protected internal override async Task SerializeAsync(ProtocolWriter writer)
{ {
await base.SerializeAsync(writer);
writer.Write(Payload.Length); writer.Write(Payload.Length);
writer.Write(Payload); writer.Write(Payload);
return Task.CompletedTask;
} }
} }

View File

@ -2,8 +2,11 @@
namespace TcpDotNet.Protocol.Packets.ServerBound; namespace TcpDotNet.Protocol.Packets.ServerBound;
/// <summary>
/// Represents a packet which performs a heartbeat request.
/// </summary>
[Packet(0x7FFFFFF0)] [Packet(0x7FFFFFF0)]
public sealed class PingPacket : Packet public sealed class PingPacket : RequestPacket
{ {
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="PingPacket" /> class. /// Initializes a new instance of the <see cref="PingPacket" /> class.
@ -22,18 +25,18 @@ public sealed class PingPacket : Packet
public byte[] Payload { get; private set; } public byte[] Payload { get; private set; }
/// <inheritdoc /> /// <inheritdoc />
protected internal override Task DeserializeAsync(ProtocolReader reader) protected internal override async Task DeserializeAsync(ProtocolReader reader)
{ {
await base.DeserializeAsync(reader);
int length = reader.ReadInt32(); int length = reader.ReadInt32();
Payload = reader.ReadBytes(length); Payload = reader.ReadBytes(length);
return Task.CompletedTask;
} }
/// <inheritdoc /> /// <inheritdoc />
protected internal override Task SerializeAsync(ProtocolWriter writer) protected internal override async Task SerializeAsync(ProtocolWriter writer)
{ {
await base.SerializeAsync(writer);
writer.Write(Payload.Length); writer.Write(Payload.Length);
writer.Write(Payload); writer.Write(Payload);
return Task.CompletedTask;
} }
} }

View File

@ -0,0 +1,27 @@
namespace TcpDotNet.Protocol;
/// <summary>
/// Represents a request packet, which forms a request/response packet pair.
/// </summary>
public abstract class RequestPacket : Packet
{
/// <summary>
/// Gets the request identifier.
/// </summary>
/// <value>The request identifier.</value>
public long CallbackId { get; internal set; }
/// <inheritdoc />
protected internal override Task DeserializeAsync(ProtocolReader reader)
{
CallbackId = reader.ReadInt64();
return Task.CompletedTask;
}
/// <inheritdoc />
protected internal override Task SerializeAsync(ProtocolWriter writer)
{
writer.Write(CallbackId);
return Task.CompletedTask;
}
}

View File

@ -0,0 +1,36 @@
namespace TcpDotNet.Protocol;
/// <summary>
/// Represents a response packet, which forms a request/response packet pair.
/// </summary>
public abstract class ResponsePacket : Packet
{
/// <summary>
/// Initializes a new instance of the <see cref="ResponsePacket" /> class.
/// </summary>
/// <param name="callbackId">The callback ID.</param>
protected ResponsePacket(long callbackId)
{
CallbackId = callbackId;
}
/// <summary>
/// Gets the response identifier.
/// </summary>
/// <value>The response identifier.</value>
public long CallbackId { get; private set; }
/// <inheritdoc />
protected internal override Task DeserializeAsync(ProtocolReader reader)
{
CallbackId = reader.ReadInt64();
return Task.CompletedTask;
}
/// <inheritdoc />
protected internal override Task SerializeAsync(ProtocolWriter writer)
{
writer.Write(CallbackId);
return Task.CompletedTask;
}
}

View File

@ -95,7 +95,7 @@ public sealed class ProtocolClient : BaseClientNode
State = ClientState.Handshaking; State = ClientState.Handshaking;
var handshakeRequest = new HandshakeRequestPacket(ProtocolVersion); var handshakeRequest = new HandshakeRequestPacket(ProtocolVersion);
var handshakeResponse = var handshakeResponse =
await SendAndReceive<HandshakeRequestPacket, HandshakeResponsePacket>(handshakeRequest, cancellationToken); await SendAndReceiveAsync<HandshakeRequestPacket, HandshakeResponsePacket>(handshakeRequest, cancellationToken);
if (handshakeResponse.HandshakeResponse != HandshakeResponse.Success) if (handshakeResponse.HandshakeResponse != HandshakeResponse.Success)
{ {
@ -117,7 +117,7 @@ public sealed class ProtocolClient : BaseClientNode
Aes = CryptographyUtils.GenerateAes(key); Aes = CryptographyUtils.GenerateAes(key);
var encryptionResponse = new EncryptionResponsePacket(encryptedPayload, rsa.Encrypt(key, true)); var encryptionResponse = new EncryptionResponsePacket(encryptedPayload, rsa.Encrypt(key, true));
var sessionPacket = await SendAndReceive<EncryptionResponsePacket, SessionExchangePacket>(encryptionResponse, cancellationToken); var sessionPacket = await SendAndReceiveAsync<EncryptionResponsePacket, SessionExchangePacket>(encryptionResponse, cancellationToken);
SessionId = sessionPacket.Session; SessionId = sessionPacket.Session;
UseEncryption = true; UseEncryption = true;