TcpDotNet/TcpDotNet/BaseClientNode.cs

344 lines
12 KiB
C#

using System.Collections.Concurrent;
using System.Net;
using System.Net.Sockets;
using System.Reflection;
using System.Runtime.Serialization;
using Chilkat;
using TcpDotNet.Protocol;
using Stream = System.IO.Stream;
using Task = System.Threading.Tasks.Task;
namespace TcpDotNet;
/// <summary>
///     Represents a client node.
/// </summary>
/// <remarks>
///     Packets are framed on the wire as a 32-bit length prefix followed by the payload. The payload starts with the
///     32-bit packet ID and is encrypted with <see cref="Aes" /> when <see cref="UseEncryption" /> is
///     <see langword="true" />.
/// </remarks>
public abstract class BaseClientNode : Node
{
    private readonly ObjectIDGenerator _callbackIdGenerator = new();

    // Waiters keyed by the ID of the packet they are waiting for. Each list is guarded by a lock on the list itself.
    private readonly ConcurrentDictionary<int, List<TaskCompletionSource<Packet>>> _packetCompletionSources = new();
    private EndPoint? _remoteEP;

    /// <summary>
    ///     Initializes a new instance of the <see cref="BaseClientNode" /> class.
    /// </summary>
    protected BaseClientNode()
    {
    }

    /// <summary>
    ///     Gets a value indicating whether the client is connected.
    /// </summary>
    /// <value><see langword="true" /> if the client is connected; otherwise, <see langword="false" />.</value>
    public bool IsConnected { get; protected set; }

    /// <summary>
    ///     Gets the remote endpoint.
    /// </summary>
    /// <value>The <see cref="EndPoint" /> with which the client is communicating.</value>
    /// <exception cref="SocketException">An error occurred when attempting to access the socket.</exception>
    public EndPoint RemoteEndPoint
    {
        // The endpoint is cached after the first successful lookup on the socket.
        get => _remoteEP ??= BaseSocket.RemoteEndPoint;
        internal set => _remoteEP = value;
    }

    /// <summary>
    ///     Gets the session ID of the client.
    /// </summary>
    /// <value>The session ID.</value>
    public Guid SessionId { get; internal set; }

    /// <summary>
    ///     Gets the current state of the client.
    /// </summary>
    public ClientState State { get; protected internal set; }

    /// <summary>
    ///     Gets or sets a value indicating whether GZip compression is enabled.
    /// </summary>
    /// <value><see langword="true" /> if compression is enabled; otherwise, <see langword="false" />.</value>
    /// <remarks>This flag is not currently consulted by the read or send paths of this class.</remarks>
    internal bool UseCompression { get; set; } = true;

    /// <summary>
    ///     Gets or sets a value indicating whether encryption is enabled.
    /// </summary>
    /// <value><see langword="true" /> if encryption is enabled; otherwise, <see langword="false" />.</value>
    internal bool UseEncryption { get; set; } = false;

    /// <summary>
    ///     Gets or sets the AES implementation used by this client.
    /// </summary>
    /// <value>The AES implementation.</value>
    internal Crypt2 Aes { get; set; } = CryptographyUtils.GenerateAes(Array.Empty<byte>());

    /// <inheritdoc />
    public override void Close()
    {
        IsConnected = false;
        State = ClientState.Disconnected;
        base.Close();
    }

    /// <inheritdoc />
    public override void Dispose()
    {
        IsConnected = false;
        State = ClientState.Disconnected;
        base.Dispose();
    }

    /// <summary>
    ///     Reads the next packet from the client's stream, raises its registered handlers, and completes any callers
    ///     currently waiting for a packet with the same ID.
    /// </summary>
    /// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param>
    /// <returns>
    ///     The next packet, or <see langword="null" /> if the packet ID is not registered or the packet type has no
    ///     parameterless constructor.
    /// </returns>
    /// <exception cref="DisconnectedException">The stream ended while reading the length prefix.</exception>
    /// <exception cref="EndOfStreamException">The payload was too short to contain a packet ID.</exception>
    public async Task<Packet?> ReadNextPacketAsync(CancellationToken cancellationToken = default)
    {
        await using var networkStream = new NetworkStream(BaseSocket);
        using var networkReader = new ProtocolReader(networkStream);

        int length;
        try
        {
            length = networkReader.ReadInt32();
        }
        catch (EndOfStreamException)
        {
            State = ClientState.Disconnected;
            throw new DisconnectedException();
        }

        await using var buffer = new MemoryStream();
        buffer.Write(networkReader.ReadBytes(length));
        buffer.Position = 0;

        // NOTE: compression is not applied here yet; UseCompression is ignored.
        if (UseEncryption)
        {
            byte[] encrypted = buffer.ToArray();
            buffer.SetLength(0);
            buffer.Write(Aes.DecryptBytes(encrypted));
            buffer.Position = 0;
        }

        using var bufferReader = new ProtocolReader(buffer);

        int packetHeader;
        try
        {
            packetHeader = bufferReader.ReadInt32();
        }
        catch (EndOfStreamException)
        {
            State = ClientState.Disconnected;
            throw;
        }

        if (!RegisteredPackets.TryGetValue(packetHeader, out Type? packetType))
        {
            Console.WriteLine($"Unknown packet {packetHeader:X8}");
            return null;
        }

        const BindingFlags bindingFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance;
        ConstructorInfo? constructor =
            packetType.GetConstructors(bindingFlags).FirstOrDefault(c => c.GetParameters().Length == 0);
        if (constructor is null)
            return null;

        var packet = (Packet) constructor.Invoke(null);
        packet.Deserialize(bufferReader);

        if (RegisteredPacketHandlers.TryGetValue(packetType, out IReadOnlyCollection<PacketHandler>? handlers))
            await Task.WhenAll(handlers.Select(h => h.HandleAsync(this, packet, cancellationToken)));

        if (_packetCompletionSources.TryGetValue(packet.Id, out List<TaskCompletionSource<Packet>>? completionSources))
        {
            // Snapshot under the lock and complete outside it: waiters remove themselves from the list once they
            // resume, which must never happen while the list is being enumerated.
            TaskCompletionSource<Packet>[] pending;
            lock (completionSources)
            {
                pending = completionSources.ToArray();
            }

            // TrySetResult, not SetResult: a waiter may already have been completed or cancelled.
            foreach (TaskCompletionSource<Packet> pendingSource in pending)
                pendingSource.TrySetResult(packet);
        }

        return packet;
    }

    /// <summary>
    ///     Sends a packet to the remote endpoint.
    /// </summary>
    /// <param name="packet">The packet to send.</param>
    /// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param>
    /// <typeparam name="TPacket">The type of the packet.</typeparam>
    /// <remarks>
    ///     An <see cref="IOException" /> raised while writing to the network is not rethrown. Instead the client is
    ///     marked as disconnected and, when this instance is a <see cref="ProtocolClient" />, its disconnect
    ///     notification is raised.
    /// </remarks>
    public async Task SendPacketAsync<TPacket>(TPacket packet, CancellationToken cancellationToken = default)
        where TPacket : Packet
    {
        var buffer = new MemoryStream();

        // NOTE: compression is not applied here yet; UseCompression is ignored.
        await using var bufferWriter = new ProtocolWriter(buffer);
        bufferWriter.Write(packet.Id);
        packet.Serialize(bufferWriter);
        await buffer.FlushAsync(cancellationToken);
        buffer.Position = 0;

        await using var networkStream = new NetworkStream(BaseSocket);
        await using var networkWriter = new ProtocolWriter(networkStream);

        if (UseEncryption)
        {
            byte[] plaintext = buffer.ToArray();
            buffer.SetLength(0);
            buffer.Write(Aes.EncryptBytes(plaintext));
            buffer.Position = 0;
        }

        var length = (int) buffer.Length;

        try
        {
            // The length prefix is written inside the try block so that a broken connection is handled the same
            // way whether it is detected on the prefix or on the payload.
            networkWriter.Write(length);
            await buffer.CopyToAsync(networkStream, cancellationToken);
            await networkStream.FlushAsync(cancellationToken);
        }
        catch (IOException)
        {
            State = ClientState.Disconnected;
            IsConnected = false;

            if (this is ProtocolClient client)
                client.OnDisconnect(DisconnectReason.EndOfStream);
        }
    }

    /// <summary>
    ///     Sends a packet, and waits for a specific packet to be received.
    /// </summary>
    /// <param name="packetToSend">The packet to send.</param>
    /// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param>
    /// <typeparam name="TSend">The type of the packet to send.</typeparam>
    /// <typeparam name="TReceive">The type of the packet to return.</typeparam>
    /// <returns>The received packet.</returns>
    /// <exception cref="ArgumentException">
    ///     <typeparamref name="TReceive" /> is not decorated with <see cref="PacketAttribute" />.
    /// </exception>
    /// <remarks>
    ///     <para>
    ///         This method will consume all incoming packets, raising their associated handlers if such packets are
    ///         recognised.
    ///     </para>
    ///     <para>
    ///         When <paramref name="packetToSend" /> is a <see cref="RequestPacket" />, it is assigned a callback ID and
    ///         only a <see cref="ResponsePacket" /> carrying the same callback ID is returned. Any exception raised
    ///         while reading packets is propagated to the caller.
    ///     </para>
    /// </remarks>
    public async Task<TReceive> SendAndReceiveAsync<TSend, TReceive>(TSend packetToSend,
        CancellationToken cancellationToken = default)
        where TSend : Packet
        where TReceive : Packet
    {
        var attribute = typeof(TReceive).GetCustomAttribute<PacketAttribute>();
        if (attribute is null)
            throw new ArgumentException($"The packet type {typeof(TReceive).Name} is not a valid packet.");

        var requestPacket = packetToSend as RequestPacket;
        if (requestPacket is not null)
            requestPacket.CallbackId = _callbackIdGenerator.GetId(packetToSend, out _);

        // Register the waiter before sending so that a response read by a concurrent reader is not missed.
        TaskCompletionSource<Packet> completionSource = CreateCompletionSource();
        RegisterCompletionSource(attribute.Id, completionSource);

        try
        {
            await SendPacketAsync(packetToSend, cancellationToken);
        }
        catch
        {
            // Do not leave a waiter behind that nobody will ever await.
            UnregisterCompletionSource(attribute.Id, completionSource);
            throw;
        }

        while (true)
        {
            TReceive response = await WaitForPacketAsync<TReceive>(completionSource, cancellationToken);

            if (requestPacket is null)
                return response;

            if (response is ResponsePacket responsePacket && responsePacket.CallbackId == requestPacket.CallbackId)
                return response;

            // The packet answered a different request. A completion source can only complete once, so a fresh one
            // is required to wait for the next packet of this type.
            completionSource = CreateCompletionSource();
        }
    }

    /// <summary>
    ///     Waits for a specific packet to be received.
    /// </summary>
    /// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param>
    /// <typeparam name="TPacket">The type of the packet to return.</typeparam>
    /// <returns>The received packet.</returns>
    /// <exception cref="ArgumentException">
    ///     <typeparamref name="TPacket" /> is not decorated with <see cref="PacketAttribute" />.
    /// </exception>
    /// <remarks>
    ///     This method will consume all incoming packets, raising their associated handlers if such packets are
    ///     recognised. Any exception raised while reading packets is propagated to the caller.
    /// </remarks>
    public Task<TPacket> WaitForPacketAsync<TPacket>(CancellationToken cancellationToken = default)
        where TPacket : Packet
    {
        return WaitForPacketAsync<TPacket>(CreateCompletionSource(), cancellationToken);
    }

    /// <summary>
    ///     Registers <paramref name="completionSource" /> as a waiter for <typeparamref name="TPacket" />, reads
    ///     packets in the background until it completes, and removes the registration again before returning.
    /// </summary>
    private async Task<TPacket> WaitForPacketAsync<TPacket>(
        TaskCompletionSource<Packet> completionSource,
        CancellationToken cancellationToken = default
    )
        where TPacket : Packet
    {
        var attribute = typeof(TPacket).GetCustomAttribute<PacketAttribute>();
        if (attribute is null)
            throw new ArgumentException($"The packet type {typeof(TPacket).Name} is not a valid packet.");

        RegisterCompletionSource(attribute.Id, completionSource);

        try
        {
            // ReadNextPacketAsync blocks synchronously on the socket, so the read loop runs on the thread pool.
            _ = Task.Run(() => PumpPacketsAsync<TPacket>(completionSource, cancellationToken));

            // The blocking read cannot observe the token, so cancel the waiter directly.
            using (cancellationToken.Register(() => completionSource.TrySetCanceled(cancellationToken)))
            {
                return (TPacket) await completionSource.Task;
            }
        }
        finally
        {
            // Runs on success, failure and cancellation alike so the waiter never leaks.
            UnregisterCompletionSource(attribute.Id, completionSource);
        }
    }

    /// <summary>
    ///     Reads packets until <paramref name="completionSource" /> completes, a <typeparamref name="TPacket" /> is
    ///     read, or cancellation is requested. Failures are forwarded to the waiter instead of being left unobserved.
    /// </summary>
    private async Task PumpPacketsAsync<TPacket>(
        TaskCompletionSource<Packet> completionSource,
        CancellationToken cancellationToken
    )
        where TPacket : Packet
    {
        try
        {
            while (!cancellationToken.IsCancellationRequested && !completionSource.Task.IsCompleted)
            {
                Packet? packet = await ReadNextPacketAsync(cancellationToken);
                if (packet is TPacket typedPacket)
                {
                    completionSource.TrySetResult(typedPacket);
                    return;
                }
            }

            // No-op when the waiter has already been completed by another reader.
            completionSource.TrySetCanceled(cancellationToken);
        }
        catch (OperationCanceledException)
        {
            completionSource.TrySetCanceled(cancellationToken);
        }
        catch (Exception exception)
        {
            // For example a disconnect: surface it to the waiter rather than letting it wait forever.
            completionSource.TrySetException(exception);
        }
    }

    /// <summary>
    ///     Creates a completion source whose continuations never run inline on the thread that completes it.
    /// </summary>
    private static TaskCompletionSource<Packet> CreateCompletionSource()
    {
        return new TaskCompletionSource<Packet>(TaskCreationOptions.RunContinuationsAsynchronously);
    }

    /// <summary>
    ///     Adds <paramref name="completionSource" /> to the waiters for <paramref name="packetId" />, if not already
    ///     present.
    /// </summary>
    private void RegisterCompletionSource(int packetId, TaskCompletionSource<Packet> completionSource)
    {
        // GetOrAdd guarantees every caller observes the list that is actually stored in the dictionary.
        List<TaskCompletionSource<Packet>> waiters =
            _packetCompletionSources.GetOrAdd(packetId, _ => new List<TaskCompletionSource<Packet>>());

        lock (waiters)
        {
            if (!waiters.Contains(completionSource))
                waiters.Add(completionSource);
        }
    }

    /// <summary>
    ///     Removes <paramref name="completionSource" /> from the waiters for <paramref name="packetId" />.
    /// </summary>
    private void UnregisterCompletionSource(int packetId, TaskCompletionSource<Packet> completionSource)
    {
        if (!_packetCompletionSources.TryGetValue(packetId, out List<TaskCompletionSource<Packet>>? waiters))
            return;

        // An empty list is deliberately kept in the dictionary: removing it could race with a concurrent
        // registration that has already obtained the list. At most one list exists per packet ID.
        lock (waiters)
        {
            waiters.Remove(completionSource);
        }
    }
}