{-# LANGUAGE OverloadedStrings #-}

module Arbiter.Orville.MonadArbiter
  ( orvilleExecuteQuery
  , orvilleExecuteStatement
  , orvilleWithDbTransaction
  , orvilleRunHandlerWithConnection
  ) where

import Arbiter.Core.Array qualified as Array
import Arbiter.Core.Codec (Col (..), NullCol (..), ParamType (..), RowCodec, SomeParam (..), runCodec)
import Arbiter.Core.Exceptions (throwInternal)
import Arbiter.Core.MonadArbiter (Params)
import Control.Monad (foldM)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Aeson (Value, eitherDecodeStrict', encode)
import Data.ByteString (ByteString)
import Data.Int (Int64)
import Data.Text (Text)
import Data.Text qualified as T
import Data.Text.Encoding qualified as TE
import Data.Text.Lazy qualified as TL
import Data.Text.Lazy.Encoding qualified as TLE
import Database.PostgreSQL.LibPQ qualified as LibPQ
import Orville.PostgreSQL qualified as O
import Orville.PostgreSQL.Marshall.FieldDefinition qualified as FieldDef
import Orville.PostgreSQL.Marshall.SqlMarshaller qualified as O
import Orville.PostgreSQL.Raw.PgTextFormatValue qualified as PgText
import Orville.PostgreSQL.Raw.RawSql (RawSql)
import Orville.PostgreSQL.Raw.RawSql qualified as RawSql
import Orville.PostgreSQL.Raw.SqlValue (SqlValue)
import Orville.PostgreSQL.Raw.SqlValue qualified as SqlValue

-- | Run a row-returning query and decode every result row with the given
-- 'RowCodec'.
--
-- The SQL template uses @?@ placeholders, which 'validateAndBuildRawSql'
-- fills from @params@. Each column of the codec is read through an Orville
-- read-only marshaller built by 'orvilleCol'. If decoding fails, the Orville
-- error is reported through 'throwInternal'.
orvilleExecuteQuery
  :: (O.MonadOrville m)
  => Text
  -> Params
  -> RowCodec a
  -> m [a]
orvilleExecuteQuery sql params codec =
  O.withConnection $ \conn -> do
    query <- validateAndBuildRawSql sql params
    pgResult <- liftIO (RawSql.execute conn query)
    outcome <-
      liftIO
        (O.marshallResultFromSql O.defaultErrorDetailLevel rowMarshaller pgResult)
    either onDecodeError pure outcome
  where
    -- Read-only: the marshaller is only ever used to decode rows.
    rowMarshaller =
      O.annotateSqlMarshallerEmptyAnnotation
        . O.marshallReadOnly
        $ runCodec orvilleCol codec
    onDecodeError err =
      throwInternal ("orville decode error: " <> T.pack (show err))

-- | Run a statement whose rows are not needed and return the affected-row
-- count.
--
-- The count is read with 'readRowCount', which yields 0 when libpq reports
-- no count or the reported text cannot be parsed as an integer.
orvilleExecuteStatement
  :: (O.MonadOrville m)
  => Text
  -> Params
  -> m Int64
orvilleExecuteStatement sql params =
  O.withConnection $ \conn ->
    validateAndBuildRawSql sql params
      >>= liftIO . RawSql.execute conn
      >>= liftIO . readRowCount

-- | Run the given action inside a database transaction by delegating to
-- Orville's 'O.withTransaction'.
orvilleWithDbTransaction :: (O.MonadOrville m) => m a -> m a
orvilleWithDbTransaction action = O.withTransaction action

-- | Apply the handler to the jobs directly.
--
-- No connection is acquired here. NOTE(review): presumably this is because
-- the Orville monad already manages connections; confirm with callers.
orvilleRunHandlerWithConnection :: (jobs -> m result) -> jobs -> m result
orvilleRunHandlerWithConnection = id

-- | Encode one bound query parameter as an Orville 'SqlValue'.
--
-- * Scalar and nullable-scalar parameters go through the column's Orville
--   field definition, built by 'colFieldDef' with an empty column name.
-- * Array parameters are rendered element by element with 'colToBytes' or
--   'colToNullableBytes', then formatted by "Arbiter.Core.Array" and wrapped
--   as raw bytes.
--
-- Returns 'Left' with a message if any array element fails to encode.
someParamToSqlValue :: SomeParam -> Either Text SqlValue
someParamToSqlValue (SomeParam paramType value) =
  case paramType of
    PScalar col ->
      pure (FieldDef.fieldValueToSqlValue (colFieldDef "" col) value)
    PNullable col ->
      pure (FieldDef.fieldValueToSqlValue (O.nullableField (colFieldDef "" col)) value)
    PArray col ->
      SqlValue.fromRawBytes . Array.fmtArray
        <$> traverse (colToBytes col) value
    PNullArray col ->
      SqlValue.fromRawBytes . Array.fmtNullableArray
        <$> traverse (colToNullableBytes col) value

-- | Render a single non-null value of the given column type as raw text-format
-- bytes. The value is converted through the column's Orville field definition
-- and then 'sqlValueToBytes'.
colToBytes :: Col a -> a -> Either Text ByteString
colToBytes col =
  sqlValueToBytes . FieldDef.fieldValueToSqlValue (colFieldDef "" col)

-- | Like 'colToBytes', but for an optional value. 'Nothing' becomes
-- @Right Nothing@; a present value is encoded with 'colToBytes'.
colToNullableBytes :: Col a -> Maybe a -> Either Text (Maybe ByteString)
colToNullableBytes col = traverse (colToBytes col)

-- | Extract the text-format bytes of a scalar 'SqlValue'.
--
-- Composite (row) values and SQL NULL are rejected with a 'Left' message,
-- since only non-null scalars are expected here.
sqlValueToBytes :: SqlValue -> Either Text ByteString
sqlValueToBytes = SqlValue.foldSqlValue onScalar onRow onNull
  where
    onScalar = Right . PgText.toByteString
    onRow _ = Left "sqlValueToBytes: got composite row, expected scalar"
    onNull = Left "sqlValueToBytes: got NULL, expected non-null scalar"

-- | Turn a @?@-placeholder SQL template plus its parameters into a 'RawSql'.
--
-- The template is split on every @?@. The number of parameters must equal the
-- number of placeholders; otherwise 'throwInternal' is raised with a
-- count-mismatch message. Each parameter is encoded with
-- 'someParamToSqlValue' and spliced in as an Orville parameter between the
-- surrounding text fragments. An encoding failure is also raised with
-- 'throwInternal'.
--
-- NOTE(review): the split is purely textual, so a @?@ inside a string literal
-- or in a PostgreSQL operator such as jsonb @?@ / @?|@ / @?&@ is also counted
-- as a placeholder.
validateAndBuildRawSql :: (MonadIO m) => Text -> Params -> m RawSql
validateAndBuildRawSql sqlTemplate params =
  case T.splitOn "?" sqlTemplate of
    [] -> pure mempty
    (leading : fragments)
      | placeholderCount /= paramCount ->
          throwInternal $
            "SQL parameter count mismatch: expected "
              <> T.pack (show placeholderCount)
              <> " but got "
              <> T.pack (show paramCount)
      | otherwise ->
          foldM appendParam (RawSql.fromText leading) (zip params fragments)
      where
        placeholderCount = length fragments
        paramCount = length params
  where
    -- Append one encoded parameter followed by the text fragment after it.
    appendParam acc (param, fragment) =
      case someParamToSqlValue param of
        Left err -> throwInternal $ "param encoding error: " <> err
        Right sqlValue ->
          pure (acc <> RawSql.parameter sqlValue <> RawSql.fromText fragment)

-- | Build a read-only Orville marshaller for one codec column.
--
-- A 'NotNull' column reads the field defined by 'colFieldDef'. A 'Nullable'
-- column wraps that field definition with 'O.nullableField'.
orvilleCol :: NullCol a -> O.SqlMarshaller () a
orvilleCol nullCol =
  case nullCol of
    NotNull name col -> readField (colFieldDef name col)
    Nullable name col -> readField (O.nullableField (colFieldDef name col))
  where
    readField fieldDef = O.marshallReadOnly (O.marshallField id fieldDef)

-- | Map an Arbiter column type to the corresponding non-null Orville field
-- definition with the given column name.
colFieldDef :: Text -> Col a -> O.FieldDefinition O.NotNull a
colFieldDef name col =
  case col of
    CInt4 -> O.integerField columnName
    CInt8 -> O.bigIntegerField columnName
    CText -> O.unboundedTextField columnName
    CBool -> O.booleanField columnName
    CTimestamptz -> O.utcTimestampField columnName
    CJsonb -> O.fieldOfType jsonbValue columnName
    CFloat8 -> O.doubleField columnName
  where
    columnName = T.unpack name

-- | An Orville jsonb 'O.SqlType' carrying an aeson 'Value'.
--
-- Values are serialized with aeson's 'encode' and decoded as UTF-8 text.
-- Incoming text is UTF-8 encoded and parsed with 'eitherDecodeStrict''. A
-- parse failure surfaces as a conversion error through 'O.tryConvertSqlType'.
jsonbValue :: O.SqlType Value
jsonbValue = O.tryConvertSqlType renderJson parseJson O.jsonb
  where
    renderJson = TL.toStrict . TLE.decodeUtf8 . encode
    parseJson = eitherDecodeStrict' . TE.encodeUtf8

-- | Read the affected-row count that libpq reports for a command result.
--
-- Returns 0 when libpq reports no count ('Nothing') or when the reported text
-- does not parse as an integer. This is best-effort by design.
readRowCount :: LibPQ.Result -> IO Int64
readRowCount res = do
  reported <- LibPQ.cmdTuples res
  pure (maybe 0 parseCount reported)
  where
    parseCount =
      either (const 0) fromIntegral . SqlValue.toInt . SqlValue.fromRawBytes