{-# LANGUAGE OverloadedStrings #-}

module Arbiter.Worker.Dispatcher
  ( runDispatcher
  ) where

import Arbiter.Core.HighLevel (QueueOperation)
import Arbiter.Core.HighLevel qualified as Arb
import Arbiter.Core.Job.Schema qualified as Schema
import Arbiter.Core.Job.Types (JobRead)
import Arbiter.Core.QueueRegistry (TableForPayload)
import Control.Monad (void, when)
import Control.Monad.IO.Class (liftIO)
import Data.ByteString.Char8 qualified as BSC
import Data.Foldable (traverse_)
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Proxy (Proxy (..))
import Data.Text qualified as T
import Data.Time (UTCTime, addUTCTime, getCurrentTime)
import GHC.TypeLits (symbolVal)
import UnliftIO (MonadUnliftIO)
import UnliftIO.Exception qualified as Ex
import UnliftIO.MVar qualified as MVar
import UnliftIO.STM qualified as STM

import Arbiter.Worker.Config (HandlerMode (..), WorkerConfig (..))
import Arbiter.Worker.Logger (LogLevel (..))
import Arbiter.Worker.Logger.Internal (logMessage)
import Arbiter.Worker.NotificationListener (withNotificationLoop)

-- | Run the dispatcher loop.
--
-- The dispatcher wakes up when any of the following happens:
--
-- * a PostgreSQL @NOTIFY@ arrives on the queue table's notification channel
--   (sent when jobs are inserted);
-- * the poll timer ('pollInterval') expires;
-- * the worker-finished signal becomes 'True' (the dispatcher resets it to
--   'False' when it consumes the wakeup).
--
-- On each wakeup it computes free capacity as
-- @capacity - (busy workers + batches already queued)@ and, if that is
-- positive, claims up to that many units of work (one job per unit in
-- 'SingleJobMode', one batch per unit in 'BatchedJobsMode'), subject to the
-- optional 'claimThrottle'. If no capacity is free, the wakeup is a no-op.
-- One claim cycle also runs immediately on startup.
--
-- __Throttling__: when 'claimThrottle' is set, at most @maxClaims@ units are
-- claimed per @window@ (the limits are fetched on every cycle). Only units
-- that were actually claimed are charged against the budget. Requested but
-- unfilled capacity (empty queue, failed claim) is given back, so idle
-- wakeups do not exhaust the budget.
--
-- __Responsiveness__: freed capacity is only used on the next wakeup. The
-- worker-finished signal provides such a wakeup when workers set it.
-- Otherwise, jobs that are already waiting (rather than newly inserted, which
-- triggers NOTIFY) are picked up at the next poll. A short poll interval
-- (e.g. @pollInterval = Just 1@ for 1 second) bounds that delay.
--
-- Claim failures are logged at 'Error' level and do not stop the loop.
-- The dispatcher runs in a dedicated thread with its own database connection.
runDispatcher
  :: forall m registry payload result
   . ( MonadUnliftIO m
     , QueueOperation m registry payload
     )
  => WorkerConfig m payload result
  -> Int
  -- ^ Worker capacity
  -> STM.TBQueue (NonEmpty (JobRead payload))
  -- ^ Work queue (batches of jobs)
  -> STM.TVar Int
  -- ^ Busy worker count
  -> Maybe (MVar.MVar ())
  -- ^ Liveness signal (pulsed after each successful claim cycle)
  -> STM.TVar Bool
  -- ^ Worker finished signal
  -> m ()
runDispatcher config workerCapacity workQueue busyWorkerCount mLivenessMVar workerFinishedVar = do
  -- Throttle state: remaining tokens and the start of the current window.
  -- Stays 'Nothing' until the first throttled cycle.
  throttleRef <- liftIO $ newIORef Nothing

  let
    tableNameVal = T.pack $ symbolVal (Proxy @(TableForPayload payload registry))

    -- Capacity not yet spoken for: every busy worker and every batch already
    -- sitting in the work queue occupies one slot.
    calcFreeWorkers :: STM.STM Int
    calcFreeWorkers = do
      busyCount <- STM.readTVar busyWorkerCount
      qLen <- fromIntegral <$> STM.lengthTBQueue workQueue
      pure $ workerCapacity - (busyCount + qLen)

    -- Get free workers if any are available (non-blocking)
    getFreeWorkers :: STM.STM (Maybe Int)
    getFreeWorkers = do
      free <- calcFreeWorkers
      pure $ if free > 0 then Just free else Nothing

    -- Claim up to @n@ units of work and enqueue them for the workers.
    -- Returns how many units were enqueued (0 if the claim threw).
    claimAndEnqueue :: Int -> m Int
    claimAndEnqueue n = do
      eJobs <- Ex.tryAny $ case handlerMode config of
        SingleJobMode _ ->
          map (:| []) <$> Arb.claimNextVisibleJobs n (visibilityTimeout config)
        BatchedJobsMode batchSize _ ->
          Arb.claimNextVisibleJobsBatched batchSize n (visibilityTimeout config)
      case eJobs of
        Left e -> do
          -- Logging is best-effort: a failing logger must not kill the loop.
          void . Ex.tryAny . liftIO $
            logMessage (logConfig config) Error $
              T.pack $ "Dispatcher exception: " <> show e
          pure 0
        Right batches -> do
          STM.atomically $ traverse_ (STM.writeTBQueue workQueue) batches
          traverse_ (\mv -> MVar.tryPutMVar mv ()) mLivenessMVar
          pure (length batches)

    -- Apply throttle limits, returning the number of units allowed to claim.
    -- Tokens are debited up front; 'refundThrottle' gives back any the claim
    -- did not use.
    applyThrottle :: IORef (Maybe (Int, UTCTime)) -> Int -> IO Int
    applyThrottle ref freeWorkers = case claimThrottle config of
      Nothing -> pure freeWorkers
      Just getThrottle -> do
        (maxClaims, window) <- getThrottle
        now <- getCurrentTime
        current <- readIORef ref
        let (tokens, windowStart) = case current of
              -- First cycle: start a fresh window
              Nothing -> (maxClaims, now)
              Just (remaining, start)
                -- Window expired: reset
                | now >= addUTCTime window start -> (maxClaims, now)
                -- Budget remains (or is exhausted: 0, next wakeup re-checks)
                | otherwise -> (max 0 remaining, start)
            -- Never negative, so a negative limit cannot inflate the budget.
            allowed = max 0 (min freeWorkers tokens)
        writeIORef ref (Just (tokens - allowed, windowStart))
        pure allowed

    -- Return tokens that 'applyThrottle' debited but the claim did not use,
    -- so wakeups that find nothing to claim do not drain the budget.
    -- No-op when throttling is disabled (the ref is then still 'Nothing').
    -- Like the read/write pair in 'applyThrottle', this assumes claim cycles
    -- do not run concurrently.
    refundThrottle :: IORef (Maybe (Int, UTCTime)) -> Int -> IO ()
    refundThrottle ref unused =
      when (unused > 0) $ do
        current <- readIORef ref
        traverse_
          (\(tokens, windowStart) -> writeIORef ref (Just (tokens + unused, windowStart)))
          current

    -- Claim jobs on wakeup if workers are available
    claimOnWakeup :: m ()
    claimOnWakeup = do
      mFree <- STM.atomically getFreeWorkers
      case mFree of
        Nothing -> pure ()
        Just freeWorkers -> do
          allowed <- liftIO $ applyThrottle throttleRef freeWorkers
          when (allowed > 0) $ do
            claimed <- claimAndEnqueue allowed
            liftIO $ refundThrottle throttleRef (allowed - claimed)

  -- Claim on startup
  claimOnWakeup

  -- The notification loop wakes on DB notifications, poll timer, or worker completion
  let notificationChannel = T.unpack $ Schema.notificationChannelForTable tableNameVal
      -- Retries (blocks) until the worker-finished flag is set, then clears
      -- it so the next completion produces a fresh wakeup.
      workerFinishedTrigger = Just $ do
        finished <- STM.readTVar workerFinishedVar
        STM.checkSTM finished
        STM.writeTVar workerFinishedVar False
  withNotificationLoop
    (BSC.unpack . connStr $ config)
    notificationChannel
    (workerStateVar config)
    (pollInterval config)
    (Just $ logConfig config)
    workerFinishedTrigger
    (const claimOnWakeup)