{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}

-- | Compatibility layer for hasql API differences.
--
-- All version-specific code lives here. The rest of arbiter-hasql
-- imports from this module and never uses CPP directly.
module Arbiter.Hasql.Compat
  ( runSQL
  , hasqlSettings
  , HasqlSettings
  ) where

import Arbiter.Core.Exceptions (throwInternal)
import Control.Monad.IO.Class (liftIO)
import Data.ByteString (ByteString)
import Data.Text qualified as T
import Data.Text.Encoding qualified as TE
import Data.Text.Encoding.Error qualified as TE
import Hasql.Connection qualified as Hasql
import Hasql.Session qualified as Session
import UnliftIO (MonadUnliftIO)

#if MIN_VERSION_hasql(1,10,0)
import Hasql.Connection.Settings qualified as Settings
#else
import Hasql.Connection.Setting qualified as Setting
import Hasql.Connection.Setting.Connection qualified as ConnSetting
#endif

-- | Run a simple SQL command on a hasql connection (e.g., BEGIN, COMMIT).
--
-- The command bytes are decoded as UTF-8 with 'TE.lenientDecode', so a
-- malformed byte sequence is replaced rather than rejected. The resulting
-- text is executed as a raw script whose result is discarded.
--
-- If hasql reports a session error, it is rethrown via 'throwInternal' with
-- its 'show' rendering prefixed by @"hasql runSQL error: "@.
runSQL :: (MonadUnliftIO m) => Hasql.Connection -> ByteString -> m ()
runSQL conn sql = do
  result <- liftIO $ Hasql.use conn (runScript sqlText)
  either rethrow pure result
  where
    sqlText = TE.decodeUtf8With TE.lenientDecode sql
    -- Surface hasql session failures as internal errors for callers.
    rethrow err = throwInternal $ "hasql runSQL error: " <> T.pack (show err)

-- | Run raw SQL text as a hasql session, producing no result.
--
-- hasql >= 1.10 exposes this operation as 'Session.script'; earlier
-- versions expose it as 'Session.sql'. Only the binding differs between
-- versions, so the signature is shared.
runScript :: T.Text -> Session.Session ()
#if MIN_VERSION_hasql(1,10,0)
runScript = Session.script
#else
runScript = Session.sql
#endif

-- | Convert a connection string ByteString to hasql settings.
--
-- The resulting 'HasqlSettings' type matches whichever hasql version this
-- package was built against. The bytes are decoded as UTF-8 leniently, so
-- malformed sequences are replaced rather than rejected.
hasqlSettings :: ByteString -> HasqlSettings
hasqlSettings = hasqlSettingsFromConnStr

#if MIN_VERSION_hasql(1,10,0)
-- | Connection settings as consumed by hasql >= 1.10.
type HasqlSettings = Settings.Settings

-- | Version-specific conversion from a connection string to settings.
hasqlSettingsFromConnStr :: ByteString -> HasqlSettings
hasqlSettingsFromConnStr = Settings.connectionString . decodeConnStr
#else
-- | Connection settings as consumed by hasql < 1.10: a list of settings.
type HasqlSettings = [Setting.Setting]

-- | Version-specific conversion from a connection string to settings.
hasqlSettingsFromConnStr :: ByteString -> HasqlSettings
hasqlSettingsFromConnStr connStr =
  [Setting.connection (ConnSetting.string (decodeConnStr connStr))]
#endif

-- | Decode connection-string bytes as UTF-8, replacing malformed sequences
-- (via 'TE.lenientDecode') instead of failing.
decodeConnStr :: ByteString -> T.Text
decodeConnStr = TE.decodeUtf8With TE.lenientDecode