{-# LANGUAGE OverloadedStrings #-}

module Arbiter.Worker.Dispatcher
  ( runDispatcher
  ) where

import Arbiter.Core.HighLevel (QueueOperation)
import Arbiter.Core.HighLevel qualified as Arb
import Arbiter.Core.Job.Schema qualified as Schema
import Arbiter.Core.Job.Types (JobRead)
import Arbiter.Core.QueueRegistry (TableForPayload)
import Data.ByteString.Char8 qualified as BSC
import Data.Foldable (traverse_)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Proxy (Proxy (..))
import Data.Text qualified as T
import GHC.TypeLits (symbolVal)
import UnliftIO (MonadUnliftIO)
import UnliftIO.Exception qualified as Ex
import UnliftIO.MVar qualified as MVar
import UnliftIO.STM qualified as STM

import Arbiter.Worker.Config (HandlerMode (..), WorkerConfig (..))
import Arbiter.Worker.Logger (LogLevel (..))
import Arbiter.Worker.Logger.Internal (tryLog)
import Arbiter.Worker.NotificationListener (withNotificationLoop)

-- | Run the dispatcher loop.
--
-- The dispatcher wakes on PostgreSQL NOTIFY (when jobs are inserted), on poll
-- timer expiration, or when a worker signals that it has finished. On each
-- wakeup it claims jobs up to the currently available worker capacity and
-- pushes them onto the work queue. If no workers are free, the wakeup is a
-- no-op.
--
-- __Responsiveness__: Workers that finish between notifications may sit idle
-- until the next notification or poll. For maximum responsiveness with
-- existing jobs, configure a short @pollInterval@ (e.g., 1 second). For
-- workloads with continuous job insertion, NOTIFY provides sub-second latency.
--
-- __Failure handling__: Synchronous exceptions raised while claiming jobs are
-- caught, logged at 'Error' level, and the loop keeps running; that cycle
-- enqueues nothing and does not pulse the liveness signal.
--
-- The dispatcher runs in a dedicated thread with its own database connection.
runDispatcher
  :: forall m registry payload result
   . ( MonadUnliftIO m
     , QueueOperation m registry payload
     )
  => WorkerConfig m payload result
  -- ^ Worker configuration (handler mode, visibility timeout, poll interval,
  -- connection string, logging, worker state)
  -> Int
  -- ^ Worker capacity
  -> STM.TBQueue (NonEmpty (JobRead payload))
  -- ^ Work queue (batches of jobs)
  -> STM.TVar Int
  -- ^ Busy worker count
  -> Maybe (MVar.MVar ())
  -- ^ Liveness signal (pulsed after each successful claim cycle)
  -> STM.TVar Bool
  -- ^ Worker finished signal
  -> m ()
runDispatcher config workerCapacity workQueue busyWorkerCount mLivenessMVar workerFinishedVar =
  -- The notification loop wakes on DB notifications, poll timer, or worker
  -- completion. The same claim action is supplied both as the plain action
  -- and as the handler (which ignores its argument).
  withNotificationLoop
    (BSC.unpack (connStr config))
    notificationChannel
    (workerStateVar config)
    (pollInterval config)
    (Just (logConfig config))
    workerFinishedTrigger
    claimOnWakeup
    (const claimOnWakeup)
  where
    -- Table name for this payload type, taken from the type-level registry.
    tableNameVal :: T.Text
    tableNameVal = T.pack (symbolVal (Proxy @(TableForPayload payload registry)))

    -- Channel the notification loop listens on for this table.
    notificationChannel :: String
    notificationChannel = T.unpack (Schema.notificationChannelForTable tableNameVal)

    -- Fires once the worker-finished flag is set, and resets the flag in the
    -- same transaction so a single completion produces a single wakeup.
    workerFinishedTrigger :: Maybe (STM.STM ())
    workerFinishedTrigger = Just $ do
      finished <- STM.readTVar workerFinishedVar
      STM.checkSTM finished
      STM.writeTVar workerFinishedVar False

    -- Capacity not yet spoken for: busy workers and batches already waiting
    -- in the queue both count against it. The result may be zero or negative.
    calcFreeWorkers :: STM.STM Int
    calcFreeWorkers = do
      busyCount <- STM.readTVar busyWorkerCount
      queued <- fromIntegral <$> STM.lengthTBQueue workQueue
      pure (workerCapacity - (busyCount + queued))

    -- Get free workers if any are available (non-blocking).
    getFreeWorkers :: STM.STM (Maybe Int)
    getFreeWorkers = do
      free <- calcFreeWorkers
      pure (if free > 0 then Just free else Nothing)

    -- Claim up to @freeWorkers@ units of work and hand them to the workers.
    -- In single-job mode every claimed job becomes its own one-element batch.
    claimAndEnqueue :: Int -> m ()
    claimAndEnqueue freeWorkers = do
      claimed <- Ex.tryAny $ case handlerMode config of
        SingleJobMode _ ->
          map (:| []) <$> Arb.claimNextVisibleJobs freeWorkers (visibilityTimeout config)
        BatchedJobsMode batchSize _ ->
          Arb.claimNextVisibleJobsBatched batchSize freeWorkers (visibilityTimeout config)
      case claimed of
        Left e ->
          tryLog (logConfig config) Error $ "Dispatcher exception: " <> T.pack (show e)
        Right batches -> do
          -- All batches are written in one transaction, so this blocks until
          -- the queue has room for every batch.
          -- NOTE(review): assumes the queue bound is at least the worker
          -- capacity; otherwise this could block indefinitely - confirm at
          -- the call site that creates the queue.
          STM.atomically (traverse_ (STM.writeTBQueue workQueue) batches)
          -- Non-blocking pulse: if the MVar is already full the pulse is
          -- simply dropped.
          traverse_ (`MVar.tryPutMVar` ()) mLivenessMVar

    -- Claim jobs on wakeup if workers are available.
    claimOnWakeup :: m ()
    claimOnWakeup = STM.atomically getFreeWorkers >>= traverse_ claimAndEnqueue