diff --git a/TcpDotNet.ListenerIntegrationTest/Program.cs b/TcpDotNet.ListenerIntegrationTest/Program.cs index c53ec68..82f83d7 100644 --- a/TcpDotNet.ListenerIntegrationTest/Program.cs +++ b/TcpDotNet.ListenerIntegrationTest/Program.cs @@ -20,10 +20,10 @@ await Task.Delay(-1); internal sealed class PingPacketHandler : PacketHandler { - public override async Task HandleAsync(BaseClientNode recipient, PingPacket packet) + public override async Task HandleAsync(BaseClientNode recipient, PingPacket packet, CancellationToken cancellationToken = default) { Console.WriteLine($"Client {recipient.SessionId} sent ping with payload {BitConverter.ToString(packet.Payload)}"); var pong = new PongPacket(packet.Payload); - await recipient.SendPacketAsync(pong); + await recipient.SendPacketAsync(pong, cancellationToken); } } diff --git a/TcpDotNet/BaseClientNode.cs b/TcpDotNet/BaseClientNode.cs index 3bb6801..2140300 100644 --- a/TcpDotNet/BaseClientNode.cs +++ b/TcpDotNet/BaseClientNode.cs @@ -53,8 +53,9 @@ public abstract class BaseClientNode : Node /// /// Reads the next packet from the client's stream. /// + /// A cancellation token that can be used to cancel the asynchronous operation. /// The next packet, or if no valid packet was read. - public async Task ReadNextPacketAsync() + public async Task ReadNextPacketAsync(CancellationToken cancellationToken = default) { await using var networkStream = new NetworkStream(BaseSocket); using var networkReader = new ProtocolReader(networkStream); @@ -97,7 +98,7 @@ public abstract class BaseClientNode : Node await targetStream.DisposeAsync(); if (RegisteredPacketHandlers.TryGetValue(packetType, out IReadOnlyCollection? handlers)) - await Task.WhenAll(handlers.Select(h => h.HandleAsync(this, packet))); + await Task.WhenAll(handlers.Select(h => h.HandleAsync(this, packet, cancellationToken))); return packet; } @@ -106,8 +107,9 @@ public abstract class BaseClientNode : Node /// Sends a packet to the remote endpoint. /// /// The packet to send. + /// A cancellation token that can be used to cancel the asynchronous operation. /// The type of the packet. - public async Task SendPacketAsync(TPacket packet) + public async Task SendPacketAsync(TPacket packet, CancellationToken cancellationToken = default) where TPacket : Packet { var buffer = new MemoryStream(); @@ -130,13 +132,13 @@ public abstract class BaseClientNode : Node break; } - await targetStream.FlushAsync(); + await targetStream.FlushAsync(cancellationToken); buffer.Position = 0; await using var networkStream = new NetworkStream(BaseSocket); await using var networkWriter = new ProtocolWriter(networkStream); networkWriter.Write((int) buffer.Length); - await buffer.CopyToAsync(networkStream); - await networkStream.FlushAsync(); + await buffer.CopyToAsync(networkStream, cancellationToken); + await networkStream.FlushAsync(cancellationToken); } -} \ No newline at end of file +} diff --git a/TcpDotNet/Protocol/PacketHandler.cs b/TcpDotNet/Protocol/PacketHandler.cs index e053517..c303e27 100644 --- a/TcpDotNet/Protocol/PacketHandler.cs +++ b/TcpDotNet/Protocol/PacketHandler.cs @@ -11,7 +11,8 @@ public abstract class PacketHandler /// /// The recipient of the packet. /// The packet to handle. - public abstract Task HandleAsync(BaseClientNode recipient, Packet packet); + /// A cancellation token that can be used to cancel the asynchronous operation. + public abstract Task HandleAsync(BaseClientNode recipient, Packet packet, CancellationToken cancellationToken = default); } /// @@ -26,9 +27,9 @@ public abstract class PacketHandler : PacketHandler public static readonly PacketHandler Empty = new NullPacketHandler(); /// - public override Task HandleAsync(BaseClientNode recipient, Packet packet) + public override Task HandleAsync(BaseClientNode recipient, Packet packet, CancellationToken cancellationToken = default) { - if (packet is T actual) return HandleAsync(recipient, actual); + if (packet is T actual) return HandleAsync(recipient, actual, cancellationToken); return Task.CompletedTask; } @@ -37,7 +38,8 @@ public abstract class PacketHandler : PacketHandler /// /// The recipient of the packet. /// The packet to handle. - public abstract Task HandleAsync(BaseClientNode recipient, T packet); + /// A cancellation token that can be used to cancel the asynchronous operation. + public abstract Task HandleAsync(BaseClientNode recipient, T packet, CancellationToken cancellationToken = default); } /// @@ -48,7 +50,7 @@ internal sealed class NullPacketHandler : PacketHandler where T : Packet { /// - public override Task HandleAsync(BaseClientNode recipient, T packet) + public override Task HandleAsync(BaseClientNode recipient, T packet, CancellationToken cancellationToken = default) { return Task.CompletedTask; } diff --git a/TcpDotNet/ProtocolClient.cs b/TcpDotNet/ProtocolClient.cs index 3e3f797..3e1f909 100644 --- a/TcpDotNet/ProtocolClient.cs +++ b/TcpDotNet/ProtocolClient.cs @@ -22,6 +22,7 @@ public sealed class ProtocolClient : BaseClientNode /// /// The remote host to which this client should connect. /// The remote port to which this client should connect. + /// A cancellation token that can be used to cancel the asynchronous operation. /// is . /// contains an empty string. /// @@ -29,9 +30,9 @@ public sealed class ProtocolClient : BaseClientNode /// than . /// /// An error occurred when attempting to access the socket. - public Task ConnectAsync(string host, int port) + public Task ConnectAsync(string host, int port, CancellationToken cancellationToken = default) { - return ConnectAsync(new DnsEndPoint(host, port)); + return ConnectAsync(new DnsEndPoint(host, port), cancellationToken); } /// @@ -39,15 +40,16 @@ public sealed class ProtocolClient : BaseClientNode /// /// The remote to which this client should connect. /// The remote port to which this client should connect. + /// A cancellation token that can be used to cancel the asynchronous operation. /// /// is less than . -or - is greater /// than . -or- is less than 0 or greater than /// 0x00000000FFFFFFFF. /// /// An error occurred when attempting to access the socket. - public Task ConnectAsync(IPAddress address, int port) + public Task ConnectAsync(IPAddress address, int port, CancellationToken cancellationToken = default) { - return ConnectAsync(new IPEndPoint(address, port)); + return ConnectAsync(new IPEndPoint(address, port), cancellationToken); } ///