-- | Configurable backoff strategies for job retries.
module Arbiter.Worker.BackoffStrategy
  ( BackoffStrategy (..)
  , Jitter (..)
  , ExponentialConfig (..)
  , LinearConfig (..)
  , calculateBackoff
  , applyJitter
  , exponentialBackoff
  , linearBackoff
  , constantBackoff
  ) where

import Data.Int (Int32)
import Data.Time (NominalDiffTime)
import System.Random (randomRIO)

-- | Parameters for 'Exponential' backoff.
--
-- The delay for attempt @n@ is @exponentialBase ^ n@ seconds, never
-- exceeding 'exponentialCap'.
data ExponentialConfig = ExponentialConfig
  { exponentialBase :: Double
  -- ^ Growth factor applied once per attempt (e.g. @2@ doubles each delay).
  , exponentialCap :: NominalDiffTime
  -- ^ Upper bound on the computed delay.
  }
  deriving stock (Show, Eq)

-- | Parameters for 'Linear' backoff.
--
-- The delay for attempt @n@ is @linearIncrement * n@, never exceeding
-- 'linearCap'.
data LinearConfig = LinearConfig
  { linearIncrement :: NominalDiffTime
  -- ^ Amount the delay grows by with each additional attempt.
  , linearCap :: NominalDiffTime
  -- ^ Upper bound on the computed delay.
  }
  deriving stock (Show, Eq)

-- | How long to wait before retrying a job, expressed as a function of the
-- number of attempts made so far.
--
-- No 'Eq' or 'Show' instances: the 'Custom' constructor holds a function,
-- and functions have neither.
data BackoffStrategy
  = Exponential ExponentialConfig
  -- ^ @base ^ attempts@, capped (e.g. base 2 gives 2s, 4s, 8s, ...).
  | Linear LinearConfig
  -- ^ @increment * attempts@, capped (e.g. 30s gives 30s, 60s, 90s, ...).
  | Constant NominalDiffTime
  -- ^ The same delay for every attempt.
  | Custom (Int32 -> NominalDiffTime)
  -- ^ Arbitrary user-supplied mapping from attempt count to delay.

-- | Randomisation applied on top of a computed backoff delay.
--
-- Spreading retries out avoids a thundering herd when many jobs fail at the
-- same moment and would otherwise all retry together.
data Jitter
  = NoJitter
  -- ^ Use the calculated delay unchanged.
  | FullJitter
  -- ^ A uniformly random delay in @[0, delay]@.
  | EqualJitter
  -- ^ @delay/2@ plus a uniformly random amount in @[0, delay/2]@. Recommended.
  deriving stock (Show, Eq)

-- | Calculate the retry delay for the given attempt count (1-indexed).
--
-- * 'Exponential': @min cap (base ^ attempts)@.
-- * 'Linear': @min cap (increment * attempts)@.
-- * 'Constant': the fixed delay, regardless of the attempt count.
-- * 'Custom': the user function applied to the raw attempt count.
--
-- For the built-in growth strategies, an attempt count below zero is
-- treated as zero. Without this, '(^)' throws \"Negative exponent\", and a
-- negative linear multiplier would produce a negative delay.
calculateBackoff :: BackoffStrategy -> Int32 -> NominalDiffTime
calculateBackoff strategy attempts =
  case strategy of
    Exponential (ExponentialConfig base cap) ->
      -- Computed in Double: a very large exponent overflows to Infinity,
      -- which 'min' then replaces with the cap.
      realToFrac (min (realToFrac cap) (base ^ n))
    Linear (LinearConfig increment cap) ->
      min cap (increment * fromIntegral n)
    Constant delay -> delay
    -- User code gets the attempt count untouched.
    Custom f -> f attempts
  where
    -- Attempt count clamped at zero so the formulas above are total.
    n :: Int32
    n = max 0 attempts

-- | Randomise a computed delay according to the chosen 'Jitter' strategy.
--
-- The random draw is taken as a 'Double' number of seconds via 'randomRIO'
-- and converted back to 'NominalDiffTime'.
applyJitter :: Jitter -> NominalDiffTime -> IO NominalDiffTime
applyJitter jitter delay =
  case jitter of
    NoJitter -> pure delay
    -- Uniform draw over the whole range [0, delay].
    FullJitter -> realToFrac <$> drawUpTo delay
    -- Half of the delay is kept fixed; only the other half is randomised.
    EqualJitter ->
      let half = delay / 2
       in (\extra -> half + realToFrac extra) <$> drawUpTo half
  where
    -- Uniform random Double in [0, bound], with bound expressed in seconds.
    drawUpTo :: NominalDiffTime -> IO Double
    drawUpTo bound = randomRIO (0, realToFrac bound)

-- | Convenience constructor for an 'Exponential' strategy.
--
-- @exponentialBackoff base cap@ produces delays of @base ^ attempts@ that
-- never exceed @cap@.
exponentialBackoff :: Double -> NominalDiffTime -> BackoffStrategy
exponentialBackoff base cap =
  Exponential (ExponentialConfig {exponentialBase = base, exponentialCap = cap})

-- | Convenience constructor for a 'Linear' strategy.
--
-- @linearBackoff increment cap@ produces delays that grow by @increment@
-- per attempt and never exceed @cap@.
linearBackoff :: NominalDiffTime -> NominalDiffTime -> BackoffStrategy
linearBackoff increment cap =
  Linear (LinearConfig {linearIncrement = increment, linearCap = cap})

-- | Convenience constructor for a 'Constant' strategy: every retry waits the
-- same @delay@.
constantBackoff :: NominalDiffTime -> BackoffStrategy
constantBackoff delay = Constant delay