{-# LANGUAGE OverloadedStrings #-}

-- | Retry combinators for worker infrastructure threads (notification listener,
-- cron scheduler, etc.) that should survive transient database failures.
module Arbiter.Worker.Retry
  ( retryOnException
  , retryOnExceptionForever
  ) where

import Arbiter.Core.Exceptions
  ( JobException
  , JobNotFoundException
  , JobStolenException
  )
import Control.Monad (forever)
import Data.Maybe (isJust)
import Data.Text qualified as T
import Data.Time (NominalDiffTime)
import UnliftIO (MonadUnliftIO, SomeException, fromException, liftIO, throwIO)
import UnliftIO.Async (race)
import UnliftIO.Concurrent (threadDelay)
import UnliftIO.Exception (tryAny)
import UnliftIO.STM (TVar, atomically, readTVar, readTVarIO, retrySTM)

import Arbiter.Worker.Logger (LogConfig, LogLevel (..))
import Arbiter.Worker.Logger.Internal (tryLog)
import Arbiter.Worker.WorkerState (WorkerState (..))

-- | Keep running an action until it completes, tolerating transient failures.
--
-- What happens after each attempt:
--
-- * The action returns normally: 'retryOnException' returns.
-- * The action throws a synchronous exception (caught via 'tryAny') while the
--   worker state is 'ShuttingDown': the exception is dropped and
--   'retryOnException' returns.
-- * The action throws in any other state: the exception is logged at 'Error'
--   level and the action is retried after 5 seconds. If the worker state
--   becomes 'ShuttingDown' during that wait, the wait ends early and
--   'retryOnException' returns instead of retrying.
retryOnException
  :: (MonadUnliftIO m)
  => TVar WorkerState
  -> LogConfig
  -> T.Text
  -- ^ Label for log messages (e.g. "Notification listener")
  -> m ()
  -- ^ Action to run
  -> m ()
retryOnException stateVar logCfg label action = attempt
  where
    -- One run of the action; success ends the loop, failure is triaged.
    attempt = tryAny action >>= either onFailure pure

    onFailure err = do
      current <- readTVarIO stateVar
      if isShuttingDown current
        then pure ()
        else do
          tryLog logCfg Error $
            label <> " error (retrying): " <> T.pack (show err)
          -- Left: shutdown observed first (stop); Right: delay elapsed (retry).
          waited <- race waitForShutdown (liftIO (threadDelay retryDelayMicros))
          either (const (pure ())) (const attempt) waited

    -- Blocks in STM until the worker state is 'ShuttingDown'. Re-reading the
    -- TVar inside the transaction means a shutdown that happened just after
    -- the 'readTVarIO' above is still noticed immediately.
    waitForShutdown = liftIO . atomically $ do
      st <- readTVar stateVar
      if isShuttingDown st then pure () else retrySTM

    isShuttingDown ShuttingDown = True
    isShuttingDown _ = False

    -- Pause between attempts, in microseconds (5 seconds).
    retryDelayMicros = 5_000_000 :: Int

-- | Run an action over and over, never returning on its own.
--
-- Unlike 'retryOnException', this does not consult any worker state, so it
-- keeps going even during shutdown.
--
-- * Job signals ('JobException', 'JobStolenException',
--   'JobNotFoundException') are rethrown unchanged so they propagate to the
--   caller, where they carry meaning (user decisions / reclaim signals).
-- * Any other synchronous exception (including transient DB errors) is logged
--   at 'Error' level, followed by a pause of the given delay before the next
--   attempt.
-- * If the action returns normally, it is started again immediately with no
--   pause, which is why it is expected to be a long-running loop itself.
retryOnExceptionForever
  :: (MonadUnliftIO m)
  => LogConfig
  -> T.Text
  -- ^ Label for log messages
  -> NominalDiffTime
  -- ^ Delay between retries on transient failure
  -> m a
  -- ^ Action to run (typically itself a forever loop)
  -> m b
retryOnExceptionForever logCfg label delay action = forever runOnce
  where
    -- A single attempt: normal completion is ignored, failures are triaged.
    runOnce = tryAny action >>= either handleFailure (const (pure ()))

    handleFailure err
      | isJobSignal err = throwIO err
      | otherwise = do
          tryLog logCfg Error $
            label <> " error (retrying): " <> T.pack (show err)
          liftIO (threadDelay delayMicros)

    -- 'threadDelay' takes microseconds; round the delay up.
    delayMicros = ceiling (delay * 1_000_000) :: Int

-- | 'True' when the exception is a 'JobException', a 'JobStolenException' or
-- a 'JobNotFoundException'. These are the exception types that
-- 'retryOnExceptionForever' rethrows instead of retrying.
isJobSignal :: SomeException -> Bool
isJobSignal err = or [isUserDecision, isStolen, isNotFound]
  where
    isUserDecision = isJust (fromException err :: Maybe JobException)
    isStolen = isJust (fromException err :: Maybe JobStolenException)
    isNotFound = isJust (fromException err :: Maybe JobNotFoundException)