{-# LANGUAGE OverloadedStrings #-}

-- | Retry combinator for worker infrastructure threads (notification listener,
-- cron scheduler, etc.) that should survive transient database failures.
module Arbiter.Worker.Retry
  ( retryOnException
  ) where

import Control.Monad (void)
import Data.Text qualified as T
import UnliftIO (MonadUnliftIO, liftIO)
import UnliftIO.Async (race)
import UnliftIO.Concurrent (threadDelay)
import UnliftIO.Exception (tryAny)
import UnliftIO.STM (TVar, atomically, readTVar, readTVarIO, retrySTM)

import Arbiter.Worker.Logger (LogConfig, LogLevel (..))
import Arbiter.Worker.Logger.Internal (logMessage)
import Arbiter.Worker.WorkerState (WorkerState (..))

-- | Run an action in a retry loop, surviving transient failures.
--
-- Behaviour, step by step:
--
-- * If the action returns normally, 'retryOnException' returns as well; it
--   does not re-run an action that finished without throwing.
-- * If the action throws a synchronous exception (as caught by 'tryAny'), the
--   worker state is read. When it is 'ShuttingDown' the function returns
--   quietly and the exception is dropped.
-- * Otherwise the exception is logged at 'Error' level. A failure of the
--   logging call itself is deliberately swallowed so that a broken logger
--   cannot take the retry loop down.
-- * The loop then waits up to 5 seconds before running the action again. The
--   wait is raced against the worker state becoming 'ShuttingDown', so a
--   shutdown request ends the wait (and the loop) immediately instead of
--   after the full delay.
retryOnException
  :: (MonadUnliftIO m)
  => TVar WorkerState
  -- ^ Shared worker state; consulted to decide whether to keep retrying
  -> LogConfig
  -- ^ Logger configuration used for the error message
  -> T.Text
  -- ^ Label for log messages (e.g. "Notification listener")
  -> m ()
  -- ^ Action to run
  -> m ()
retryOnException stateVar logCfg label action = loop
  where
    -- Delay between a failed attempt and the next one, in microseconds.
    retryDelayMicros = 5_000_000 :: Int

    loop = do
      result <- tryAny action
      case result of
        Right () -> pure ()
        Left e -> do
          status <- readTVarIO stateVar
          case status of
            -- Failures during shutdown are expected; exit without logging.
            ShuttingDown -> pure ()
            _ -> do
              logFailure e
              sleepResult <- race waitForShutdown (threadDelay retryDelayMicros)
              case sleepResult of
                -- Shutdown was requested while we were waiting: stop.
                Left () -> pure ()
                -- The full delay elapsed: try the action again.
                Right () -> loop

    -- Best-effort error log; any exception from the logger is discarded.
    logFailure e =
      void . tryAny . liftIO $
        logMessage logCfg Error $
          label <> " error (retrying): " <> T.pack (show e)

    -- Block until the worker state becomes 'ShuttingDown'.
    waitForShutdown =
      atomically $ do
        st <- readTVar stateVar
        case st of
          ShuttingDown -> pure ()
          _ -> retrySTM