// Source: oliverbooth.dev/OliverBooth.Common/Services/SessionService.cs (C#, 7.8 KiB, 245 lines in the original listing)

using System.Diagnostics.CodeAnalysis;
using System.Net;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using OliverBooth.Common.Data.Web;
using OliverBooth.Common.Data.Web.Users;
using ISession = OliverBooth.Common.Data.Web.Users.ISession;
namespace OliverBooth.Common.Services;
/// <summary>
/// Manages user sessions: creates and deletes <see cref="Session" /> rows through the <see cref="WebContext" />
/// factory, writes and reads the <c>sid</c> session cookie, and validates a stored session against the current
/// request (expiry, client IP address, user agent, and owning user). As a <see cref="BackgroundService" /> it also
/// removes expired sessions from the database once when the host starts.
/// </summary>
internal sealed class SessionService : BackgroundService, ISessionService
{
    /// <summary>The name of the cookie that carries the Base64-encoded session ID.</summary>
    private const string SessionCookieName = "sid";

    /// <summary>The size, in bytes, of a <see cref="Guid" /> (and therefore of a decoded session ID).</summary>
    private const int GuidByteCount = 16;

    /// <summary>The largest serialized <see cref="IPAddress" /> (IPv6), in bytes.</summary>
    private const int MaxAddressByteCount = 16;

    private readonly ILogger<SessionService> _logger;
    private readonly IUserService _userService;
    private readonly IDbContextFactory<WebContext> _webContextFactory;

    /// <summary>
    /// Initializes a new instance of the <see cref="SessionService" /> class.
    /// </summary>
    /// <param name="logger">The logger.</param>
    /// <param name="userService">The user service.</param>
    /// <param name="webContextFactory">The <see cref="WebContext" /> factory.</param>
    public SessionService(ILogger<SessionService> logger,
        IUserService userService,
        IDbContextFactory<WebContext> webContextFactory)
    {
        _logger = logger;
        _userService = userService;
        _webContextFactory = webContextFactory;
    }

    /// <inheritdoc />
    /// <remarks>
    /// The new session records the caller's remote IP address and user agent, is valid for one day, and is
    /// persisted immediately.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="request" /> or <paramref name="user" /> is null.</exception>
    public ISession CreateSession(HttpRequest request, IUser user)
    {
        ArgumentNullException.ThrowIfNull(request);
        ArgumentNullException.ThrowIfNull(user);

        using WebContext context = _webContextFactory.CreateDbContext();
        DateTimeOffset now = DateTimeOffset.UtcNow;
        var session = new Session
        {
            UserId = user.Id,
            // NOTE(review): the null-forgiving operator assumes a remote address is always available here.
            // ValidateSession rejects requests without one, so such a session could never validate anyway.
            IpAddress = request.HttpContext.Connection.RemoteIpAddress!,
            Created = now,
            Updated = now,
            LastAccessed = now,
            Expires = now + TimeSpan.FromDays(1),
            // Flag the session as requiring TOTP whenever the user has a non-blank Totp value.
            RequiresTotp = !string.IsNullOrWhiteSpace(user.Totp),
            UserAgent = request.Headers.UserAgent.ToString()
        };

        EntityEntry<Session> entry = context.Sessions.Add(session);
        context.SaveChanges();
        return entry.Entity;
    }

    /// <inheritdoc />
    /// <remarks>
    /// <paramref name="session" /> must be a <see cref="Session" /> instance (such as one returned by
    /// <see cref="CreateSession" /> or <c>TryGetSession</c>); otherwise the cast throws
    /// <see cref="InvalidCastException" />.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="session" /> is null.</exception>
    public void DeleteSession(ISession session)
    {
        ArgumentNullException.ThrowIfNull(session);

        using WebContext context = _webContextFactory.CreateDbContext();
        context.Sessions.Remove((Session)session);
        context.SaveChanges();
    }

    /// <inheritdoc />
    /// <remarks>Removes the <c>sid</c> cookie and returns a redirect to the <c>/Admin/Login</c> page.</remarks>
    /// <exception cref="ArgumentNullException"><paramref name="response" /> is null.</exception>
    public IActionResult DeleteSessionCookie(HttpResponse response)
    {
        ArgumentNullException.ThrowIfNull(response);

        response.Cookies.Delete(SessionCookieName);
        return new RedirectToPageResult("/Admin/Login");
    }

    /// <inheritdoc />
    /// <remarks>
    /// The cookie value is the Base64 encoding of the session ID's 16 raw bytes, which is the inverse of the decoding
    /// performed by <c>TryGetSession(HttpRequest, out ISession?)</c>.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="response" /> or <paramref name="session" /> is null.</exception>
    public void SaveSessionCookie(HttpResponse response, ISession session)
    {
        ArgumentNullException.ThrowIfNull(response);
        ArgumentNullException.ThrowIfNull(session);

        Span<byte> buffer = stackalloc byte[GuidByteCount];
        if (!session.Id.TryWriteBytes(buffer))
        {
            // A 16-byte buffer always fits a Guid; this is purely defensive.
            return;
        }

        IPAddress? remoteIpAddress = response.HttpContext.Connection.RemoteIpAddress;
        _logger.LogDebug("Writing cookie 'sid' to HTTP response for {RemoteAddr}", remoteIpAddress);
        response.Cookies.Append(SessionCookieName, Convert.ToBase64String(buffer), new CookieOptions
        {
            // The cookie outlives the one-day session. Once the session expires, ValidateSession rejects it and
            // TryGetCurrentUser deletes the stale cookie.
            Expires = DateTimeOffset.UtcNow + TimeSpan.FromDays(30),
            // The session ID is a bearer credential: hide it from page scripts so an XSS flaw cannot read it.
            HttpOnly = true,
            Secure = true,
            SameSite = SameSiteMode.Strict
        });
    }

    /// <inheritdoc />
    /// <remarks>
    /// On any failure (no session, invalid session, or missing user) the <c>sid</c> cookie is removed from
    /// <paramref name="response" />. The redirect result produced by <see cref="DeleteSessionCookie" /> is
    /// discarded here, so the caller decides how to respond.
    /// </remarks>
    public bool TryGetCurrentUser(HttpRequest request, HttpResponse response, [NotNullWhen(true)] out IUser? user)
    {
        user = null;

        if (!TryGetSession(request, out ISession? session))
        {
            _logger.LogDebug("Session not found; redirecting");
            DeleteSessionCookie(response);
            return false;
        }

        if (!ValidateSession(request, session))
        {
            _logger.LogDebug("Session invalid; redirecting");
            DeleteSessionCookie(response);
            return false;
        }

        if (!_userService.TryGetUser(session.UserId, out user))
        {
            _logger.LogDebug("User not found; redirecting");
            DeleteSessionCookie(response);
            return false;
        }

        return true;
    }

    /// <inheritdoc />
    public bool TryGetSession(Guid sessionId, [NotNullWhen(true)] out ISession? session)
    {
        using WebContext context = _webContextFactory.CreateDbContext();
        session = context.Sessions.FirstOrDefault(s => s.Id == sessionId);
        return session is not null;
    }

    /// <inheritdoc />
    /// <remarks>
    /// Returns <see langword="false" /> when the request has no remote IP address, has no <c>sid</c> cookie, or
    /// the cookie is not the Base64 encoding of exactly 16 bytes.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="request" /> is null.</exception>
    public bool TryGetSession(HttpRequest request, [NotNullWhen(true)] out ISession? session)
    {
        ArgumentNullException.ThrowIfNull(request);

        session = null;

        // Such a request could never pass ValidateSession, so skip the database lookup.
        IPAddress? remoteIpAddress = request.HttpContext.Connection.RemoteIpAddress;
        if (remoteIpAddress is null)
        {
            return false;
        }

        if (!request.Cookies.TryGetValue(SessionCookieName, out string? sessionIdCookie))
        {
            return false;
        }

        // Decoded data longer than the buffer makes TryFromBase64Chars fail; shorter data is rejected by the length
        // check, so only a full 16-byte GUID is accepted.
        Span<byte> bytes = stackalloc byte[GuidByteCount];
        if (!Convert.TryFromBase64Chars(sessionIdCookie, bytes, out int bytesWritten) || bytesWritten < GuidByteCount)
        {
            return false;
        }

        var sessionId = new Guid(bytes);
        return TryGetSession(sessionId, out session);
    }

    /// <inheritdoc />
    /// <remarks>
    /// A session is valid when all of the following hold:
    /// <list type="bullet">
    /// <item>the request has a remote IP address;</item>
    /// <item>the session has not expired;</item>
    /// <item>the request's IP address matches the one recorded on the session;</item>
    /// <item>the request's user agent matches the one recorded on the session (ordinal comparison);</item>
    /// <item>the session's user still exists.</item>
    /// </list>
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="request" /> or <paramref name="session" /> is null.</exception>
    public bool ValidateSession(HttpRequest request, ISession session)
    {
        ArgumentNullException.ThrowIfNull(request);
        ArgumentNullException.ThrowIfNull(session);

        IPAddress? remoteIpAddress = request.HttpContext.Connection.RemoteIpAddress;
        if (remoteIpAddress is null)
        {
            return false;
        }

        if (session.Expires <= DateTimeOffset.UtcNow)
        {
            _logger.LogInformation("Session {Id} has expired (client {Ip})", session.Id, remoteIpAddress);
            return false;
        }

        // IPv4 addresses serialize to 4 bytes and IPv6 addresses to 16. Compare only the bytes actually written, so
        // addresses of different lengths never compare equal (e.g. 1.2.3.4 vs 102:304::). This also keeps the
        // result independent of whatever sits in the unused tail of each buffer.
        Span<byte> remoteAddressBytes = stackalloc byte[MaxAddressByteCount];
        Span<byte> sessionAddressBytes = stackalloc byte[MaxAddressByteCount];
        if (!remoteIpAddress.TryWriteBytes(remoteAddressBytes, out int remoteLength) ||
            !session.IpAddress.TryWriteBytes(sessionAddressBytes, out int sessionLength))
        {
            _logger.LogWarning("Failed to write bytes for session {Id}", session.Id);
            return false;
        }

        ReadOnlySpan<byte> remoteAddress = remoteAddressBytes.Slice(0, remoteLength);
        ReadOnlySpan<byte> sessionAddress = sessionAddressBytes.Slice(0, sessionLength);
        if (!remoteAddress.SequenceEqual(sessionAddress))
        {
            _logger.LogInformation("Session {Id} has IP mismatch (wanted {Expected}, got {Actual})", session.Id,
                session.IpAddress, remoteIpAddress);
            return false;
        }

        var userAgent = request.Headers.UserAgent.ToString();
        if (session.UserAgent != userAgent)
        {
            _logger.LogInformation("Session {Id} has user agent mismatch (wanted {Expected}, got {Actual})", session.Id,
                session.UserAgent, userAgent);
            return false;
        }

        if (!_userService.TryGetUser(session.UserId, out _))
        {
            _logger.LogWarning("User {Id} not found for session {Session} (client {Ip})", session.UserId, session.Id,
                remoteIpAddress);
            return false;
        }

        return true;
    }

    /// <inheritdoc />
    /// <remarks>
    /// Deletes every session whose expiry time has passed. This runs a single time when the host starts; there is no
    /// loop. Sessions that expire later stay in the database until the next start-up, although ValidateSession still
    /// rejects them.
    /// </remarks>
    protected override async Task ExecuteAsync(CancellationToken stoppingToken)
    {
        await using WebContext context = await _webContextFactory.CreateDbContextAsync(stoppingToken);
        context.Sessions.RemoveRange(context.Sessions.Where(s => s.Expires <= DateTimeOffset.UtcNow));
        await context.SaveChangesAsync(stoppingToken);
    }
}