{-# LANGUAGE OverloadedStrings #-}

module Arbiter.Worker.Heartbeat
  ( withJobsHeartbeat
  ) where

import Arbiter.Core.Exceptions (throwJobStolen)
import Arbiter.Core.HighLevel (JobOperation)
import Arbiter.Core.HighLevel qualified as Arb
import Arbiter.Core.Job.Types (Job (..), JobRead, ObservabilityHooks (..))
import Control.Concurrent.MVar qualified as MVar
import Control.Monad (forever, unless)
import Control.Monad.IO.Class (liftIO)
import Data.Foldable (toList, traverse_)
import Data.List.NonEmpty (NonEmpty)
import Data.Text qualified as T
import Data.Time (NominalDiffTime, UTCTime, getCurrentTime)
import Data.Void (absurd)
import UnliftIO (MonadUnliftIO)
import UnliftIO.Async (race)
import UnliftIO.Concurrent (threadDelay)

import Arbiter.Worker.Logger (LogConfig)
import Arbiter.Worker.Logger.Internal (runHook)
import Arbiter.Worker.Retry (retryOnExceptionForever)

-- | Run an action with a heartbeat that extends visibility timeout for all jobs.
--
-- The heartbeat runs in a separate thread spawned via 'race' and extends the
-- visibility timeout at regular intervals, preventing long-running jobs from
-- becoming visible and being claimed by another worker.
--
-- The heartbeat distinguishes between:
--
--   * Job successfully heartbeated - continue normally
--   * Job already completed (acked\/canceled by handler) - ignore, not an error
--   * Job stolen by another worker (attempts changed) - throw to stop duplicate work
--
-- Calls onJobHeartbeat hook at each interval for monitoring long-running jobs.
withJobsHeartbeat
  :: forall registry m payload a
   . ( JobOperation m registry payload
     , MonadUnliftIO m
     )
  => ObservabilityHooks m payload
  -- ^ Observability hooks (only 'onJobHeartbeat' is used here)
  -> NominalDiffTime
  -- ^ Heartbeat interval (seconds between visibility extensions)
  -> NominalDiffTime
  -- ^ Visibility timeout passed to every extension request
  -> UTCTime
  -- ^ Start time (forwarded to the heartbeat hook alongside the current time)
  -> NonEmpty (JobRead payload)
  -- ^ The job(s) being processed
  -> LogConfig
  -- ^ Log configuration
  -> Maybe (MVar.MVar ())
  -- ^ Liveness signal: filled (non-blocking) each time the batch extension
  -- call returns, before the results are inspected
  -> m a
  -- ^ Action to run with heartbeat protection
  -> m a
withJobsHeartbeat hooks intervalSecs timeoutSecs startTime jobs logCfg mLivenessMVar action =
  -- The heartbeat side has result type 'Void', so 'race' can only ever finish
  -- normally with the action's result; 'absurd' discharges the impossible
  -- 'Left' case.
  either absurd id <$> race heartbeatThread action
  where
    -- Runs 'tick' in an endless loop under the retry helper.
    -- The literal 3 is a 'NominalDiffTime' argument of
    -- 'retryOnExceptionForever' (presumably the retry delay in seconds;
    -- confirm against Arbiter.Worker.Retry).
    heartbeatThread =
      retryOnExceptionForever logCfg "Heartbeat" 3 $
        forever tick

    -- One heartbeat cycle: sleep, extend visibility, pulse liveness, abort on
    -- stolen jobs, then notify the hook for every job that is still active.
    tick = do
      -- 'threadDelay' takes microseconds; round up so a fractional interval
      -- never collapses to a zero-length sleep.
      threadDelay (ceiling (intervalSecs * 1_000_000))
      results <- Arb.setVisibilityTimeoutBatch timeoutSecs (toList jobs)
      -- 'tryPutMVar' never blocks: if the previous pulse has not been
      -- consumed yet, this one is simply dropped.
      traverse_ (\mv -> liftIO $ MVar.tryPutMVar mv ()) mLivenessMVar
      let stolenJobs = [jobId | Arb.JobReclaimed jobId _ _ <- results]
      -- NOTE(review): this is thrown inside 'retryOnExceptionForever';
      -- confirm that helper lets the job-stolen exception propagate instead
      -- of retrying it.
      unless (null stolenJobs) $
        throwJobStolen $
          "Heartbeat detected stolen jobs: "
            <> T.intercalate ", " (map (T.pack . show) stolenJobs)
            <> " (another worker reclaimed them, stopping to prevent duplicate processing)"
      -- Only jobs whose visibility was actually extended get a hook call;
      -- any other result is skipped silently.
      let activeJobIds = [jobId | Arb.VisibilityExtended jobId <- results]
          activeJobs = filter (\job -> primaryKey job `elem` activeJobIds) (toList jobs)
      currentTime <- liftIO getCurrentTime
      traverse_
        ( \job ->
            runHook logCfg "onJobHeartbeat" $
              onJobHeartbeat hooks job currentTime startTime
        )
        activeJobs