{-# LANGUAGE OverloadedStrings #-}

-- | @hasql@ implementation helpers for 'MonadArbiter'.
--
-- Handlers receive a @Hasql.Connection.Connection@ for running typed hasql
-- queries inside the worker transaction:
--
-- @
-- import Arbiter.Hasql.MonadArbiter
-- import Hasql.Connection qualified as Hasql
--
-- instance MonadArbiter MyApp where
--   type Handler MyApp jobs result = Hasql.Connection -> jobs -> MyApp result
--   executeQuery             = hasqlExecuteQuery
--   executeStatement         = hasqlExecuteStatement
--   withDbTransaction        = hasqlWithDbTransaction
--   runHandlerWithConnection = hasqlRunHandlerWithConnection
-- @
module Arbiter.Hasql.MonadArbiter
  ( -- * MonadArbiter implementation
    hasqlExecuteQuery
  , hasqlExecuteStatement
  , hasqlWithDbTransaction
  , hasqlRunHandlerWithConnection

    -- * Connection pool management
  , HasqlConnectionPool (..)
  , HasHasqlPool (..)
  , localHasqlConnection
  ) where

import Arbiter.Core.Codec (RowCodec)
import Arbiter.Core.Exceptions (throwInternal)
import Arbiter.Core.MonadArbiter (Params)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.ByteString.Char8 qualified as BSC
import Data.Int (Int64)
import Data.Pool qualified as Pool
import Data.Text (Text)
import Data.Text qualified as T
import Hasql.Connection qualified as Hasql
import Hasql.Session qualified as Session
import Hasql.Statement qualified as S
import UnliftIO (MonadUnliftIO, mask, onException, throwIO, withRunInIO)
import UnliftIO.Exception (SomeException, try)

import Arbiter.Hasql.Compat qualified as Compat
import Arbiter.Hasql.Decode qualified as Decode
import Arbiter.Hasql.Encode qualified as Encode

-- | Connection state threaded through a 'HasHasqlPool' monad.
--
-- Mirrors @SimpleConnectionPool@ from @arbiter-simple@. Database operations
-- in this module run on 'activeConn' when it is set, and otherwise check a
-- connection out of 'connectionPool'.
data HasqlConnectionPool = HasqlConnectionPool
  { connectionPool :: Maybe (Pool.Pool Hasql.Connection)
  -- ^ Pool that connections are checked out of. 'Nothing' when only a
  -- pinned connection is available (connection-only mode via
  -- 'inTransaction').
  , activeConn :: Maybe Hasql.Connection
  -- ^ Connection pinned for the current transaction, if any. While set,
  -- every operation runs on it instead of the pool.
  , transactionDepth :: Int
  -- ^ Transaction nesting depth. 0 means no transaction is active, 1 a
  -- top-level transaction, and larger values nested savepoints.
  }

-- | Monads that carry a 'HasqlConnectionPool' and can adjust it for the
-- extent of a sub-computation.
class Monad m => HasHasqlPool m where
  -- | Read the current connection state.
  getHasqlPool :: m HasqlConnectionPool

  -- | Run an action under a transformed connection state. This module uses it
  -- to pin a connection or change 'transactionDepth' around an action.
  -- Instances are expected to scope the change to that action, like
  -- @Control.Monad.Reader.local@.
  localHasqlPool :: (HasqlConnectionPool -> HasqlConnectionPool) -> m a -> m a

-- | Run an action with @conn@ pinned as the active hasql connection.
--
-- Every arbiter operation inside the action uses @conn@. The transaction
-- depth is set to 1, so a nested 'hasqlWithDbTransaction' opens a savepoint
-- instead of issuing a new @BEGIN@. The caller must already have issued
-- @BEGIN@ on the connection.
localHasqlConnection :: (HasHasqlPool m) => Hasql.Connection -> m a -> m a
localHasqlConnection conn = localHasqlPool pin
  where
    pin st = st {activeConn = Just conn, transactionDepth = 1}

-- | Run a row-returning SQL query and decode every result row with @codec@.
--
-- The SQL text is passed through 'Encode.convertPlaceholders' and the
-- parameters through 'Encode.buildEncoder' before a preparable statement is
-- built. The query runs on the pinned connection when one is active.
-- Otherwise it runs on a connection checked out of the pool for the duration
-- of the query.
--
-- Any error returned by 'Hasql.use' is rethrown via 'throwInternal' with the
-- prefix @"hasql query error: "@.
hasqlExecuteQuery
  :: (HasHasqlPool m, MonadIO m)
  => Text
  -> Params
  -> RowCodec a
  -> m [a]
hasqlExecuteQuery sql params codec = withConn $ \conn -> do
  -- Already in IO here (withConn's callback), so no liftIO is needed.
  let stmt =
        S.preparable
          (Encode.convertPlaceholders sql)
          (Encode.buildEncoder params)
          (Decode.hasqlRowDecoder codec)
  result <- Hasql.use conn (Session.statement () stmt)
  either
    (\err -> throwInternal ("hasql query error: " <> T.pack (show err)))
    pure
    result

-- | Run a SQL statement that returns no rows and return the 'Int64' result of
-- the statement built by 'Encode.buildStatementRowCount' (presumably the
-- affected-row count; confirm against that helper).
--
-- The statement runs on the pinned connection when one is active. Otherwise
-- it runs on a connection checked out of the pool for the duration of the
-- call.
--
-- Any error returned by 'Hasql.use' is rethrown via 'throwInternal' with the
-- prefix @"hasql statement error: "@.
hasqlExecuteStatement
  :: (HasHasqlPool m, MonadIO m)
  => Text
  -> Params
  -> m Int64
hasqlExecuteStatement sql params = withConn $ \conn -> do
  -- Already in IO here (withConn's callback), so no liftIO is needed.
  let stmt = Encode.buildStatementRowCount sql params
  result <- Hasql.use conn (Session.statement () stmt)
  either
    (\err -> throwInternal ("hasql statement error: " <> T.pack (show err)))
    pure
    result

-- | Run a block of code within a database transaction.
--
-- Supports nested transactions via savepoints, matching @arbiter-simple@:
--
-- * With no pinned connection, a connection is checked out of the pool and a
--   top-level transaction is run on it with that connection pinned.
--
-- * With a pinned connection at depth 0, a top-level transaction is started
--   on that connection.
--
-- * With a pinned connection at depth @d > 0@, the action runs inside the
--   savepoint @arbiter_sp_\<d\>@ at depth @d + 1@. If the action throws, the
--   savepoint is rolled back /and released/ before the exception is
--   rethrown, so the enclosing transaction stays usable and no dead
--   savepoint is left behind.
--
-- Throws via 'throwInternal' when there is neither a pinned connection nor a
-- pool.
hasqlWithDbTransaction :: (HasHasqlPool m, MonadUnliftIO m) => m a -> m a
hasqlWithDbTransaction action = do
  pool <- getHasqlPool
  case (activeConn pool, transactionDepth pool) of
    (Nothing, _) -> case connectionPool pool of
      Nothing ->
        throwInternal "No active connection and no connection pool available"
      Just p ->
        withRunInIO $ \run -> Pool.withResource p $ \conn -> run (topLevel conn)
    (Just conn, 0) -> topLevel conn
    (Just conn, d) -> withSavepoint conn d
  where
    -- Open a real transaction on @conn@ and run the action with @conn@ pinned
    -- at depth 1, so nested calls take the savepoint branch.
    topLevel conn = withRunInIO $ \run ->
      beginCommitOrRollback
        conn
        ( run
            ( localHasqlPool
                (\st -> st {activeConn = Just conn, transactionDepth = 1})
                action
            )
        )

    -- Run the action inside savepoint @arbiter_sp_<d>@ at depth @d + 1@.
    withSavepoint conn d = mask $ \restore -> do
      let spName = "arbiter_sp_" <> BSC.pack (show d)
          exec sql = liftIO (Compat.runSQL conn sql)
          -- Undo the savepoint's work and drop it (as postgresql-simple's
          -- withSavepoint does); ROLLBACK TO alone keeps the savepoint alive
          -- until the outer transaction ends.
          rollbackAndRelease = do
            exec ("ROLLBACK TO SAVEPOINT " <> spName)
            exec ("RELEASE SAVEPOINT " <> spName)
      exec ("SAVEPOINT " <> spName)
      a <-
        restore (localHasqlPool (\st -> st {transactionDepth = d + 1}) action)
          `onException` rollbackAndRelease
      -- Guard the release too, so an interrupted or failed RELEASE does not
      -- leave the savepoint's work applied.
      exec ("RELEASE SAVEPOINT " <> spName) `onException` rollbackAndRelease
      pure a

-- | Run an 'IO' action inside a top-level @BEGIN@ \/ @COMMIT@ transaction on
-- the given connection.
--
-- The transaction is committed when the action returns normally. If the
-- action throws, whether synchronously or asynchronously (e.g. a timeout or
-- thread cancellation), a @ROLLBACK@ is attempted and the original exception
-- is rethrown. A failing @ROLLBACK@ is ignored so it cannot mask the original
-- error.
--
-- The previous implementation used 'try', which only catches synchronous
-- exceptions. An asynchronously interrupted action therefore skipped the
-- rollback and left a caller-pinned connection inside an open transaction.
--
-- A failing @BEGIN@ or @COMMIT@ propagates its exception unchanged.
beginCommitOrRollback :: forall a. Hasql.Connection -> IO a -> IO a
beginCommitOrRollback conn action = mask $ \restore -> do
  Compat.runSQL conn "BEGIN"
  -- Only the user action runs with the caller's masking state restored; the
  -- transaction bookkeeping around it runs masked.
  result <- restore action `onException` rollbackQuietly
  Compat.runSQL conn "COMMIT"
  pure result
  where
    -- Best-effort rollback: swallow any failure so the caller sees the
    -- exception that aborted the action, not the cleanup error.
    rollbackQuietly :: IO ()
    rollbackQuietly = do
      _ <- try (Compat.runSQL conn "ROLLBACK") :: IO (Either SomeException ())
      pure ()

-- | Call a job handler with the currently pinned hasql connection.
--
-- Passing the pinned @Hasql.Connection@ lets the handler run typed hasql
-- queries inside the worker transaction. If no connection is pinned, the
-- call fails via 'throwInternal' instead of invoking the handler.
hasqlRunHandlerWithConnection
  :: (HasHasqlPool m, MonadIO m)
  => (Hasql.Connection -> jobs -> m result)
  -> jobs
  -> m result
hasqlRunHandlerWithConnection handler jobs =
  getHasqlPool >>= maybe missing (`handler` jobs) . activeConn
  where
    missing = throwInternal "hasqlRunHandlerWithConnection: no active connection"

-- ---------------------------------------------------------------------------
-- Internal
-- ---------------------------------------------------------------------------

-- | Run an 'IO' action on a database connection.
--
-- A pinned connection ('activeConn') takes priority. Without one, a
-- connection is checked out of 'connectionPool' for the duration of the
-- action. Throws via 'throwInternal' when neither is available.
withConn :: (HasHasqlPool m, MonadIO m) => (Hasql.Connection -> IO a) -> m a
withConn f = do
  st <- getHasqlPool
  case activeConn st of
    Just pinned -> liftIO (f pinned)
    Nothing ->
      maybe
        (throwInternal "No active connection and no connection pool available")
        (\p -> liftIO (Pool.withResource p f))
        (connectionPool st)