{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}

-- | Compatibility layer for hasql API differences.
--
-- All version-specific code lives here. The rest of arbiter-hasql
-- imports from this module and never uses CPP directly.
module Arbiter.Hasql.Compat
  ( runSQL
  , connectionInTransaction
  , hasqlSettings
  , HasqlSettings
  ) where

import Arbiter.Core.Exceptions (throwInternal)
import Control.Monad.IO.Class (liftIO)
import Data.ByteString (ByteString)
import Data.Text qualified as T
import Data.Text.Encoding qualified as TE
import Data.Text.Encoding.Error qualified as TE
import Database.PostgreSQL.LibPQ qualified as LibPQ
import Hasql.Connection qualified as Hasql
import Hasql.Session qualified as Session
import UnliftIO (MonadUnliftIO)

#if MIN_VERSION_hasql(1,10,0)
import Hasql.Connection.Settings qualified as Settings
#else
import Hasql.Connection.Setting qualified as Setting
import Hasql.Connection.Setting.Connection qualified as ConnSetting
#endif

-- | Run a simple SQL command (e.g. @BEGIN@, @COMMIT@) on a hasql connection.
--
-- The command bytes are decoded as UTF-8 with 'TE.lenientDecode', so invalid
-- byte sequences are replaced rather than rejected. If hasql reports a
-- session error, it is rethrown through 'throwInternal' with the message
-- prefixed by @"hasql runSQL error: "@.
runSQL :: (MonadUnliftIO m) => Hasql.Connection -> ByteString -> m ()
runSQL conn sql = do
  let script = TE.decodeUtf8With TE.lenientDecode sql
  result <- liftIO (Hasql.use conn (runScript script))
  either
    (\err -> throwInternal ("hasql runSQL error: " <> T.pack (show err)))
    pure
    result

-- | Execute raw SQL text as a hasql session, using whichever primitive the
-- installed hasql version provides.
--
-- NOTE(review): in some pre-1.10 hasql releases 'Session.sql' may take a
-- 'ByteString' rather than 'T.Text'; confirm the @#else@ branch type-checks
-- against the oldest supported hasql.
runScript :: T.Text -> Session.Session ()
#if MIN_VERSION_hasql(1,10,0)
runScript = Session.script
#else
runScript = Session.sql
#endif

-- | Returns 'True' if the connection is inside a transaction block, whether
-- valid or aborted (libpq status @TransInTrans@ or @TransInError@). Returns
-- 'False' otherwise. Used to skip a redundant @ROLLBACK@ when hasql has
-- already cleaned up after an interrupted session.
--
-- With hasql >= 1.10 the status is read inside a hasql session. If that
-- session fails, the connection is reported as not in a transaction
-- ('False').
connectionInTransaction :: Hasql.Connection -> IO Bool
#if MIN_VERSION_hasql(1,10,0)
connectionInTransaction conn =
  either (const False) id <$> Hasql.use conn (Session.onLibpqConnection probe)
  where
    -- Inspect the raw libpq connection, then return that same connection
    -- to hasql alongside the result.
    probe pq = do
      status <- LibPQ.transactionStatus pq
      pure (Right (txStatusNeedsRollback status), pq)
#else
connectionInTransaction conn =
  Hasql.withLibPQConnection conn $ \pq ->
    txStatusNeedsRollback <$> LibPQ.transactionStatus pq
#endif

-- | Whether a libpq transaction status calls for a @ROLLBACK@.
--
-- Only @TransInTrans@ (open, healthy transaction) and @TransInError@
-- (aborted transaction) accept a @ROLLBACK@ without warning. Every other
-- status, including idle, active, and unknown, yields 'False'.
txStatusNeedsRollback :: LibPQ.TransactionStatus -> Bool
txStatusNeedsRollback status =
  case status of
    LibPQ.TransInTrans -> True
    LibPQ.TransInError -> True
    _ -> False

-- | Convert a connection string, given as a 'ByteString', into the settings
-- value expected by the installed hasql version ('HasqlSettings').
hasqlSettings :: ByteString -> HasqlSettings
hasqlSettings = hasqlSettingsFromConnStr

#if MIN_VERSION_hasql(1,10,0)
-- | Connection settings type used by hasql >= 1.10.
type HasqlSettings = Settings.Settings
#else
-- | Connection settings type used by hasql < 1.10 (a list of settings).
type HasqlSettings = [Setting.Setting]
#endif

-- | Build version-appropriate hasql settings from a connection string.
--
-- The bytes are decoded as UTF-8 with 'TE.lenientDecode', so invalid byte
-- sequences are replaced rather than rejected.
hasqlSettingsFromConnStr :: ByteString -> HasqlSettings
hasqlSettingsFromConnStr connStr =
#if MIN_VERSION_hasql(1,10,0)
  Settings.connectionString connText
#else
  [Setting.connection (ConnSetting.string connText)]
#endif
  where
    connText = TE.decodeUtf8With TE.lenientDecode connStr