refactor: validate session separately

This commit is contained in:
Oliver Booth 2024-02-24 15:04:03 +00:00
parent 0d670554e6
commit 951500ca91
Signed by: oliverbooth
GPG Key ID: E60B570D1B7557B5
4 changed files with 40 additions and 24 deletions

View File

@ -69,7 +69,7 @@ public sealed class AdminController : ControllerBase
[HttpGet("logout")] [HttpGet("logout")]
public IActionResult Logout() public IActionResult Logout()
{ {
if (_sessionService.TryGetSession(Request, out ISession? session, true)) if (_sessionService.TryGetSession(Request, out ISession? session))
_sessionService.DeleteSession(session); _sessionService.DeleteSession(session);
return _sessionService.DeleteSessionCookie(Response); return _sessionService.DeleteSessionCookie(Response);

View File

@ -23,12 +23,18 @@ public class Index : PageModel
public IActionResult OnGet() public IActionResult OnGet()
{ {
if (!_sessionService.TryGetSession(HttpContext.Request, out ISession? session)) if (!_sessionService.TryGetSession(Request, out ISession? session))
{ {
_logger.LogDebug("Session not found; redirecting"); _logger.LogDebug("Session not found; redirecting");
return _sessionService.DeleteSessionCookie(Response); return _sessionService.DeleteSessionCookie(Response);
} }
if (!_sessionService.ValidateSession(Request, session))
{
_logger.LogDebug("Session invalid; redirecting");
return _sessionService.DeleteSessionCookie(Response);
}
if (!_userService.TryGetUser(session.UserId, out IUser? user)) if (!_userService.TryGetUser(session.UserId, out IUser? user))
{ {
_logger.LogDebug("User not found; redirecting"); _logger.LogDebug("User not found; redirecting");

View File

@ -1,4 +1,5 @@
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using Microsoft.AspNetCore.Mvc;
using OliverBooth.Data.Web; using OliverBooth.Data.Web;
using ISession = OliverBooth.Data.Blog.ISession; using ISession = OliverBooth.Data.Blog.ISession;
@ -47,10 +48,20 @@ public interface ISessionService
/// When this method returns, contains the session with the specified request, if the user is found; otherwise, /// When this method returns, contains the session with the specified request, if the user is found; otherwise,
/// <see langword="null" />. /// <see langword="null" />.
/// </param> /// </param>
/// <param name="includeInvalid">
/// <see langword="true" /> to include invalid sessions in the search; otherwise, <see langword="false" />.
/// </param>
/// <returns><see langword="true" /> if the session was found; otherwise, <see langword="false" />.</returns> /// <returns><see langword="true" /> if the session was found; otherwise, <see langword="false" />.</returns>
/// <exception cref="ArgumentNullException"><paramref name="request" /> is <see langword="null" />.</exception> /// <exception cref="ArgumentNullException"><paramref name="request" /> is <see langword="null" />.</exception>
bool TryGetSession(HttpRequest request, [NotNullWhen(true)] out ISession? session, bool includeInvalid = false); bool TryGetSession(HttpRequest request, [NotNullWhen(true)] out ISession? session);
}
/// <summary>
/// Validates the session with the incoming HTTP request.
/// </summary>
/// <param name="request">The HTTP request.</param>
/// <param name="session">The session.</param>
/// <returns><see langword="true" /> if the session is valid; otherwise, <see langword="false" />.</returns>
/// <exception cref="ArgumentNullException">
/// <para><paramref name="request" /> is <see langword="null" />.</para>
/// -or-
/// <para><paramref name="session" /> is <see langword="null" />.</para>
/// </exception>
bool ValidateSession(HttpRequest request, ISession session);
}

View File

@ -1,5 +1,6 @@
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.Net; using System.Net;
using Microsoft.AspNetCore.Mvc;
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.ChangeTracking; using Microsoft.EntityFrameworkCore.ChangeTracking;
using OliverBooth.Data.Blog; using OliverBooth.Data.Blog;
@ -71,8 +72,7 @@ internal sealed class SessionService : ISessionService
} }
/// <inheritdoc /> /// <inheritdoc />
public bool TryGetSession(HttpRequest request, [NotNullWhen(true)] out ISession? session, public bool TryGetSession(HttpRequest request, [NotNullWhen(true)] out ISession? session)
bool includeInvalid = false)
{ {
if (request is null) throw new ArgumentNullException(nameof(request)); if (request is null) throw new ArgumentNullException(nameof(request));
@ -88,12 +88,20 @@ internal sealed class SessionService : ISessionService
return false; return false;
var sessionId = new Guid(bytes); var sessionId = new Guid(bytes);
if (!TryGetSession(sessionId, out session)) return TryGetSession(sessionId, out session);
return false; }
if (!includeInvalid && session.Expires >= DateTimeOffset.UtcNow) /// <inheritdoc />
public bool ValidateSession(HttpRequest request, ISession session)
{
if (request is null) throw new ArgumentNullException(nameof(request));
if (session is null) throw new ArgumentNullException(nameof(session));
IPAddress? remoteIpAddress = request.HttpContext.Connection.RemoteIpAddress;
if (remoteIpAddress is null) return false;
if (session.Expires >= DateTimeOffset.UtcNow)
{ {
session = null;
return false; return false;
} }
@ -101,22 +109,13 @@ internal sealed class SessionService : ISessionService
Span<byte> sessionAddressBytes = stackalloc byte[16]; Span<byte> sessionAddressBytes = stackalloc byte[16];
if (!remoteIpAddress.TryWriteBytes(remoteAddressBytes, out _) || if (!remoteIpAddress.TryWriteBytes(remoteAddressBytes, out _) ||
!session.IpAddress.TryWriteBytes(sessionAddressBytes, out _)) !session.IpAddress.TryWriteBytes(sessionAddressBytes, out _))
{
session = null;
return false; return false;
}
if (!includeInvalid && !remoteAddressBytes.SequenceEqual(sessionAddressBytes)) if (!remoteAddressBytes.SequenceEqual(sessionAddressBytes))
{
session = null;
return false; return false;
}
if (!includeInvalid && _userService.TryGetUser(session.UserId, out _)) if (_userService.TryGetUser(session.UserId, out _))
{
session = null;
return false; return false;
}
return true; return true;
} }