{-# LANGUAGE OverloadedStrings #-}

module Arbiter.Worker.Heartbeat
  ( withJobsHeartbeat
  ) where

import Arbiter.Core.Exceptions (throwJobStolen)
import Arbiter.Core.HighLevel (JobOperation)
import Arbiter.Core.HighLevel qualified as Arb
import Arbiter.Core.Job.Types (Job (..), JobRead, ObservabilityHooks (..))
import Control.Concurrent.MVar qualified as MVar
import Control.Monad (forever, unless)
import Control.Monad.IO.Class (liftIO)
import Data.Foldable (toList, traverse_)
import Data.List.NonEmpty (NonEmpty)
import Data.Text qualified as T
import Data.Time (NominalDiffTime, UTCTime, getCurrentTime)
import Data.Void (Void, absurd)
import UnliftIO (MonadUnliftIO)
import UnliftIO.Async (race)
import UnliftIO.Concurrent (threadDelay)

import Arbiter.Worker.Logger (LogConfig)
import Arbiter.Worker.Logger.Internal (runHook)

-- | Run an action with a heartbeat that extends visibility timeout for all jobs
--
-- The heartbeat runs in a separate thread and extends the visibility timeout at
-- regular intervals, preventing long-running jobs from becoming visible and being
-- claimed by another worker.
--
-- Uses 'race' to coordinate the heartbeat and action threads. If the heartbeat
-- detects a stolen job, its exception propagates out (cancelling the action).
-- If the action completes first, the heartbeat is cancelled cleanly — no stale
-- async exceptions can leak into the worker loop.
--
-- The heartbeat distinguishes between:
--
--   * Job successfully heartbeated - continue normally
--   * Job already completed (acked\/canceled by handler) - ignore, not an error
--   * Job stolen by another worker (attempts changed) - throw to stop duplicate work
--
-- Calls onJobHeartbeat hook at each interval for monitoring long-running jobs.
withJobsHeartbeat
  :: forall m registry payload a
   . ( JobOperation m registry payload
     , MonadUnliftIO m
     )
  => ObservabilityHooks m payload
  -- ^ Observability hooks (for heartbeat hook)
  -> Int
  -- ^ Heartbeat interval in seconds (e.g., 30)
  -> NominalDiffTime
  -- ^ Visibility timeout in seconds (e.g., 60)
  -> UTCTime
  -- ^ Start time (for calculating elapsed time in heartbeat hook)
  -> NonEmpty (JobRead payload)
  -- ^ The job(s) being processed
  -> LogConfig
  -- ^ Log configuration
  -> Maybe (MVar.MVar ())
  -- ^ Liveness signal (pulsed after each successful heartbeat)
  -> m a
  -- ^ Action to run with heartbeat protection
  -> m a
withJobsHeartbeat hooks intervalSecs timeoutSecs startTime jobs logCfg mLivenessMVar action =
  -- The heartbeat side has type 'Void', so it can never "win" the race with a
  -- value: it either runs until cancelled or throws. 'absurd' discharges that
  -- impossible 'Left' branch.
  either absurd id
    <$> race
      ( heartbeatLoop
          hooks
          intervalSecs
          timeoutSecs
          startTime
          jobs
          logCfg
          mLivenessMVar
      )
      action

-- | Heartbeat loop that extends visibility for all jobs at regular intervals
--
-- Runs forever, so the return type is 'Void' — it can only exit by throwing.
-- Uses batch operations for detailed per-job status. Only throws on stolen jobs
-- (another worker reclaimed the job). Jobs that were already acked/canceled by
-- the handler are silently ignored.
--
-- Each iteration:
--
--   1. sleeps for the configured interval,
--   2. extends the visibility timeout of every job in one batch call,
--   3. throws via 'throwJobStolen' if any job was reclaimed by another worker,
--   4. pulses the liveness 'MVar.MVar' (if one was supplied), and
--   5. runs the @onJobHeartbeat@ hook for each job whose visibility was extended.
--
-- The liveness signal is only pulsed once the stolen-job check has passed, so
-- an iteration that is about to throw does not report itself as healthy.
heartbeatLoop
  :: forall m registry payload
   . ( JobOperation m registry payload
     , MonadUnliftIO m
     )
  => ObservabilityHooks m payload
  -- ^ Observability hooks
  -> Int
  -- ^ Interval in seconds
  -> NominalDiffTime
  -- ^ Timeout in seconds
  -> UTCTime
  -- ^ Start time
  -> NonEmpty (JobRead payload)
  -- ^ Job(s) to heartbeat
  -> LogConfig
  -- ^ Log configuration
  -> Maybe (MVar.MVar ())
  -- ^ Liveness signal
  -> m Void
heartbeatLoop hooks intervalSecs timeoutSecs startTime jobs logCfg mLivenessMVar = forever $ do
  -- Wait for the interval ('threadDelay' takes microseconds)
  threadDelay (intervalSecs * 1_000_000)

  -- Extend visibility and get detailed status for each job
  results <- Arb.setVisibilityTimeoutBatch timeoutSecs (toList jobs)

  -- Check for stolen jobs (another worker reclaimed them)
  let stolenJobs = [jobId | Arb.JobReclaimed jobId _ _ <- results]
  unless (null stolenJobs) $
    throwJobStolen $
      "Heartbeat detected stolen jobs: "
        <> T.intercalate ", " (map (T.pack . show) stolenJobs)
        <> " (another worker reclaimed them, stopping to prevent duplicate processing)"

  -- Signal liveness only after the heartbeat succeeded without detecting a
  -- stolen job. 'tryPutMVar' never blocks; if a previous pulse has not been
  -- consumed yet, this one is simply dropped.
  traverse_ (\mv -> liftIO $ MVar.tryPutMVar mv ()) mLivenessMVar

  -- Call heartbeat hook only for jobs that are still active (successfully heartbeated)
  -- Jobs that were acked/canceled (JobNotFound) are no longer being worked
  let activeJobIds = [jobId | Arb.VisibilityExtended jobId <- results]
      activeJobs = filter ((`elem` activeJobIds) . primaryKey) (toList jobs)
  currentTime <- liftIO getCurrentTime
  traverse_
    ( \job ->
        runHook logCfg "onJobHeartbeat" $
          onJobHeartbeat hooks job currentTime startTime
    )
    activeJobs