{-# LANGUAGE OverloadedStrings #-}

module Arbiter.Worker.NotificationListener
  ( withNotificationLoop
  ) where

import Arbiter.Core.Job.Schema (quoteIdentifier)
import Control.Applicative (Alternative ((<|>)))
import Control.Concurrent (threadDelay)
import Control.Monad (forever, void)
import Data.ByteString.Char8 qualified as BSC
import Data.String (fromString)
import Data.Text qualified as T
import Data.Text.Encoding qualified as TE
import Data.Time (NominalDiffTime)
import Database.PostgreSQL.Simple qualified as PS
import Database.PostgreSQL.Simple.Notification qualified as PS
import UnliftIO (MonadUnliftIO, liftIO)
import UnliftIO.Async (Concurrently (..), race_)
import UnliftIO.Exception (bracket)
import UnliftIO.STM qualified as STM

import Arbiter.Worker.Logger (LogConfig, defaultLogConfig)
import Arbiter.Worker.Retry (retryOnException)
import Arbiter.Worker.WorkerState (WorkerState (..))

-- | Everything a single listener session needs, shared between the handler
-- loop ('mainLoop') and the Postgres listener loop ('notificationLoop').
-- A fresh value is built for every connection opened by
-- 'withNotificationLoop'.
data ListenerCtx = ListenerCtx
  { lcProcessStatus :: STM.TVar WorkerState
  -- ^ Worker state; read to decide whether to run, pause or stop.
  , lcPollDelay :: NominalDiffTime
  -- ^ Seconds to wait before firing the action without a notification.
  , lcNotificationVar :: STM.TVar (Maybe PS.Notification)
  -- ^ Most recent unconsumed notification: written by the listener loop,
  -- cleared by the handler loop when it picks the notification up.
  , lcConnection :: PS.Connection
  -- ^ The connection on which @LISTEN@ was issued.
  , lcWakeTrigger :: Maybe (STM.STM ())
  -- ^ Optional STM action; once it completes, the handler loop is woken.
  }

-- | The handler run by the notification loop. It receives @Just@ the
-- notification that woke the loop, or @Nothing@ when the loop was woken by
-- the poll timer or by the optional wake trigger.
type Action m a = Maybe PS.Notification -> m a

-- | Runs the provided action whenever a notification arrives on the given
-- channel, or when the poll delay elapses without one.
--
-- Each attempt opens a dedicated Postgres connection, issues @LISTEN@ on the
-- channel, and then races two loops: the handler loop, which runs the action,
-- and a listener loop, which passes incoming notifications to the handler
-- through a 'STM.TVar'. Whichever loop finishes first cancels the other, and
-- 'bracket' closes the connection afterwards.
--
-- Failures (for example a lost connection) are handed to 'retryOnException'.
-- NOTE(review): the reconnect/backoff policy lives in "Arbiter.Worker.Retry"
-- and is not visible here. The handler loop itself returns normally once the
-- worker state becomes 'ShuttingDown'.
withNotificationLoop
  :: (MonadUnliftIO m)
  => String
  -- ^ Postgres connection string
  -> String
  -- ^ Notification channel name (e.g., "email_jobs_created")
  -> STM.TVar WorkerState
  -- ^ Signal for worker state (Running, Paused, ShuttingDown)
  -> NominalDiffTime
  -- ^ Poll delay in seconds — action fires on this interval if no
  -- notifications are received. Also serves as the liveness heartbeat.
  -> Maybe LogConfig
  -- ^ Optional log configuration for internal errors
  -> Maybe (STM.STM ())
  -- ^ Optional wake trigger (e.g., worker finished signal)
  -> Action m ()
  -- ^ Action to run
  -> m ()
withNotificationLoop connStr channel pSt polDel mLogCfg mWakeTrigger action =
  retryOnException pSt logCfg "Notification listener" $
    bracket openConn closeConn runSession
  where
    -- Fall back to the default logging setup when none was supplied.
    logCfg :: LogConfig
    logCfg = case mLogCfg of
      Just cfg -> cfg
      Nothing -> defaultLogConfig

    openConn = liftIO (connectToDb connStr)

    closeConn conn = liftIO (PS.close conn)

    -- One connected session: subscribe, then run both loops until one ends.
    runSession conn = do
      notifVar <- STM.newTVarIO Nothing
      let ctx =
            ListenerCtx
              { lcProcessStatus = pSt
              , lcPollDelay = polDel
              , lcNotificationVar = notifVar
              , lcConnection = conn
              , lcWakeTrigger = mWakeTrigger
              }
      liftIO (subscribeToChannel conn channel)
      race_ (mainLoop ctx action) (notificationLoop ctx)

-- | Handler loop: repeatedly waits for something to do and runs the action.
--
-- The worker state is read at the top of every iteration, so a pending
-- shutdown is seen before waiting on anything else.
--
-- * While 'Paused', it blocks until the state changes, then exits on
--   'ShuttingDown' or starts over otherwise.
-- * While 'Running', it waits for the first of: a state change, a
--   notification, the poll timer, or the optional wake trigger. It runs the
--   action with @Just@ the notification, or with @Nothing@ for the timer or
--   trigger, and then starts over.
--
-- Returns once the state is 'ShuttingDown'.
mainLoop :: (MonadUnliftIO m) => ListenerCtx -> Action m a -> m ()
mainLoop ctx action = go
  where
    statusVar = lcProcessStatus ctx

    go = do
      current <- STM.readTVarIO statusVar
      case current of
        ShuttingDown -> pure ()
        Paused -> awaitResume
        Running -> awaitCommand >>= dispatch

    -- Block until the state leaves Paused; stop on shutdown, else re-loop.
    awaitResume = do
      next <- STM.atomically $ do
        s <- STM.readTVar statusVar
        case s of
          Paused -> STM.retrySTM
          _ -> pure s
      case next of
        ShuttingDown -> pure ()
        _ -> go

    -- The first waiter to finish wins; the others are cancelled. The timer
    -- guarantees the action still runs periodically if notifications are
    -- missed.
    awaitCommand =
      runConcurrently $
        Concurrently (checkStateChange ctx)
          <|> Concurrently (waitForNotification (lcNotificationVar ctx))
          <|> Concurrently (messageWaitTimer ctx)
          <|> Concurrently (waitForWakeTrigger ctx)

    dispatch cmd = case cmd of
      Halt -> pure ()
      PauseCmd -> go
      NotificationRecv n -> action (Just n) >> go
      TimerExpired -> action Nothing >> go

-- | What woke the handler loop, decided by the race between
-- 'checkStateChange', 'waitForNotification', 'messageWaitTimer' and
-- 'waitForWakeTrigger'.
data Command
  = -- | The worker is shutting down: leave the loop.
    Halt
  | -- | The worker was paused: return to the top of the loop and re-check.
    PauseCmd
  | -- | A notification arrived: run the action with it.
    NotificationRecv PS.Notification
  | -- | The poll delay elapsed (also reported for the wake trigger): run the
    -- action without a notification.
    TimerExpired

-- | Blocks while the worker state is 'Running'. Once the state is 'Paused' it
-- yields 'PauseCmd'; once it is 'ShuttingDown' it yields 'Halt'.
checkStateChange :: (MonadUnliftIO m) => ListenerCtx -> m Command
checkStateChange ctx =
  STM.atomically (STM.readTVar (lcProcessStatus ctx) >>= toCommand)
  where
    toCommand :: WorkerState -> STM.STM Command
    toCommand ShuttingDown = pure Halt
    toCommand Paused = pure PauseCmd
    -- Still running: retry makes the transaction wait for the TVar to change.
    toCommand Running = STM.retrySTM

-- | Blocks until the notification variable holds a notification. It then
-- atomically empties the variable and returns the notification as
-- 'NotificationRecv'.
waitForNotification :: (MonadUnliftIO m) => STM.TVar (Maybe PS.Notification) -> m Command
waitForNotification notificationVar =
  STM.atomically $ STM.readTVar notificationVar >>= maybe STM.retrySTM consume
  where
    -- Clear the slot in the same transaction so each notification is
    -- consumed exactly once.
    consume n = NotificationRecv n <$ STM.writeTVar notificationVar Nothing

-- | Listener loop: blocks on the next Postgres notification for this
-- connection and publishes it into the shared notification variable.
--
-- A notification the handler has not consumed yet is overwritten, so bursts
-- are coalesced into a single wake-up. Runs until cancelled or until
-- 'PS.getNotification' throws.
notificationLoop :: (MonadUnliftIO m) => ListenerCtx -> m ()
notificationLoop ctx = forever $ do
  n <- liftIO (PS.getNotification (lcConnection ctx))
  STM.atomically (STM.writeTVar (lcNotificationVar ctx) (Just n))

-- | Sleeps for the configured poll delay, then reports 'TimerExpired'.
--
-- This uses a plain 'threadDelay' rather than 'STM.registerDelay'.
--
-- * Race losers: this runs as one branch of the race in 'mainLoop'. When
--   another branch wins, this thread is cancelled and the pending timeout is
--   unregistered with it. A 'registerDelay' timeout can never be
--   unregistered, so every iteration that ended early used to leave a timer
--   behind until it expired.
-- * Portability: 'threadDelay' also works on the non-threaded RTS, which
--   'registerDelay' does not support.
messageWaitTimer :: (MonadUnliftIO m) => ListenerCtx -> m Command
messageWaitTimer ctx = do
  liftIO (threadDelay microSecs)
  pure TimerExpired
  where
    -- The poll delay is in seconds; threadDelay takes microseconds.
    microSecs :: Int
    microSecs = round (lcPollDelay ctx * 1_000_000)

-- | Blocks until the configured wake trigger completes, or forever if there
-- is none. A fired trigger is reported as 'TimerExpired', so the handler runs
-- the action without a notification.
waitForWakeTrigger :: (MonadUnliftIO m) => ListenerCtx -> m Command
waitForWakeTrigger ctx = maybe blockForever fire (lcWakeTrigger ctx)
  where
    -- No trigger configured: never win the race.
    blockForever = liftIO (forever (threadDelay maxBound))

    fire trigger = STM.atomically (TimerExpired <$ trigger)

-- | Opens a new Postgres connection from a libpq connection string.
--
-- The string is encoded as UTF-8 before being handed to libpq.
-- 'Data.ByteString.Char8.pack' keeps only the low 8 bits of each character,
-- which would silently corrupt any non-ASCII text (e.g. in a password).
connectToDb :: String -> IO PS.Connection
connectToDb = PS.connectPostgreSQL . TE.encodeUtf8 . T.pack

-- | Issues @LISTEN@ for the given notification channel on the connection.
-- The channel name is passed through 'quoteIdentifier' before being spliced
-- into the SQL. The command's row-count result is discarded.
subscribeToChannel :: PS.Connection -> String -> IO ()
subscribeToChannel conn channel = do
  let listenSql = "LISTEN " <> quoteIdentifier (T.pack channel)
  _ <- PS.execute_ conn (fromString (T.unpack listenSql))
  pure ()
channel)