{-# LANGUAGE OverloadedStrings #-}

-- | Orville-backed implementations of the Arbiter database primitives.
--
-- SQL is supplied as a template using @?@ as the positional placeholder;
-- each placeholder is replaced by an Orville 'RawSql.parameter' built from
-- the corresponding 'SomeParam'.
module Arbiter.Orville.MonadArbiter
  ( orvilleExecuteQuery
  , orvilleExecuteStatement
  , orvilleWithDbTransaction
  , orvilleRunHandlerWithConnection
  ) where

import Arbiter.Core.Array qualified as Array
import Arbiter.Core.Codec (Col (..), NullCol (..), ParamType (..), RowCodec, SomeParam (..), runCodec)
import Arbiter.Core.Exceptions (throwInternal)
import Arbiter.Core.MonadArbiter (Params)
import Control.Monad (foldM)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Aeson (Value, eitherDecodeStrict', encode)
import Data.ByteString (ByteString)
import Data.Int (Int64)
import Data.Text (Text)
import Data.Text qualified as T
import Data.Text.Encoding qualified as TE
import Data.Text.Lazy qualified as TL
import Data.Text.Lazy.Encoding qualified as TLE
import Database.PostgreSQL.LibPQ qualified as LibPQ
import Orville.PostgreSQL qualified as O
import Orville.PostgreSQL.Marshall.FieldDefinition qualified as FieldDef
import Orville.PostgreSQL.Marshall.SqlMarshaller qualified as O
import Orville.PostgreSQL.Raw.PgTextFormatValue qualified as PgText
import Orville.PostgreSQL.Raw.RawSql (RawSql)
import Orville.PostgreSQL.Raw.RawSql qualified as RawSql
import Orville.PostgreSQL.Raw.SqlValue (SqlValue)
import Orville.PostgreSQL.Raw.SqlValue qualified as SqlValue

-- | Run a query on the current Orville connection and decode every result
-- row with the given 'RowCodec'.
--
-- The template is validated and parameterised by 'validateAndBuildRawSql'.
-- A row that fails to decode is reported through 'throwInternal' with an
-- @"orville decode error: "@ prefix.
orvilleExecuteQuery :: (O.MonadOrville m) => Text -> Params -> RowCodec a -> m [a]
orvilleExecuteQuery sql params codec =
  O.withConnection $ \conn -> do
    rawSql <- validateAndBuildRawSql sql params
    result <- liftIO $ RawSql.execute conn rawSql
    -- The codec only describes how to read columns, so the marshaller is
    -- made read-only; the write side is never used.
    let marshaller =
          O.annotateSqlMarshallerEmptyAnnotation
            (O.marshallReadOnly (runCodec orvilleCol codec))
    decoded <- liftIO $ O.marshallResultFromSql O.defaultErrorDetailLevel marshaller result
    case decoded of
      Right rows -> pure rows
      Left err -> throwInternal $ "orville decode error: " <> T.pack (show err)

-- | Run a statement on the current Orville connection and return the number
-- of affected rows.
--
-- The count comes from 'readRowCount'. It is 0 when libpq reports no count or
-- the reported count cannot be parsed.
orvilleExecuteStatement :: (O.MonadOrville m) => Text -> Params -> m Int64
orvilleExecuteStatement sql params =
  O.withConnection $ \conn -> do
    rawSql <- validateAndBuildRawSql sql params
    result <- liftIO $ RawSql.execute conn rawSql
    liftIO $ readRowCount result

-- | Run an action inside a database transaction via 'O.withTransaction'.
orvilleWithDbTransaction :: (O.MonadOrville m) => m a -> m a
orvilleWithDbTransaction = O.withTransaction

-- | Apply the handler to the job batch directly.
--
-- This backend performs no extra connection setup here.
-- NOTE(review): presumably because the Orville functions above obtain the
-- connection through 'O.withConnection' themselves — confirm with callers.
orvilleRunHandlerWithConnection :: (jobs -> m result) -> jobs -> m result
orvilleRunHandlerWithConnection handler jobs = handler jobs

-- | Convert a query parameter into an Orville 'SqlValue'.
--
-- Scalar and nullable parameters are encoded through the Orville field
-- definition for their column type. The field name is irrelevant here, so
-- it is left empty. Array parameters are encoded element by element into
-- text-format bytes, then combined with 'Array.fmtArray' or
-- 'Array.fmtNullableArray'. Returns 'Left' when any element cannot be
-- encoded as a single non-null scalar.
someParamToSqlValue :: SomeParam -> Either Text SqlValue
someParamToSqlValue (SomeParam (PScalar c) v) =
  Right $ FieldDef.fieldValueToSqlValue (colFieldDef "" c) v
someParamToSqlValue (SomeParam (PNullable c) v) =
  Right $ FieldDef.fieldValueToSqlValue (O.nullableField (colFieldDef "" c)) v
someParamToSqlValue (SomeParam (PArray c) vs) = do
  bs <- traverse (colToBytes c) vs
  Right $ SqlValue.fromRawBytes $ Array.fmtArray bs
someParamToSqlValue (SomeParam (PNullArray c) vs) = do
  bs <- traverse (colToNullableBytes c) vs
  Right $ SqlValue.fromRawBytes $ Array.fmtNullableArray bs

-- | Encode one non-null value of the given column type as raw text-format
-- bytes.
colToBytes :: Col a -> a -> Either Text ByteString
colToBytes c v = sqlValueToBytes $ FieldDef.fieldValueToSqlValue (colFieldDef "" c) v

-- | Like 'colToBytes', but maps 'Nothing' to a SQL NULL element.
colToNullableBytes :: Col a -> Maybe a -> Either Text (Maybe ByteString)
colToNullableBytes _ Nothing = Right Nothing
colToNullableBytes c (Just v) = Just <$> colToBytes c v

-- | Extract the text-format bytes of a scalar 'SqlValue'.
--
-- Fails with a descriptive message when the value is a composite row or
-- NULL.
sqlValueToBytes :: SqlValue -> Either Text ByteString
sqlValueToBytes =
  SqlValue.foldSqlValue
    (Right . PgText.toByteString)
    (const $ Left "sqlValueToBytes: got composite row, expected scalar")
    (Left "sqlValueToBytes: got NULL, expected non-null scalar")

-- | Build a parameterised 'RawSql' from a template that uses @?@ placeholders.
--
-- The number of placeholders must equal the number of parameters; otherwise,
-- or if a parameter cannot be encoded, 'throwInternal' is raised.
--
-- NOTE(review): every literal @?@ counts as a placeholder, including one
-- inside a string literal or a jsonb operator (@?@, @?|@, @?&@). Templates
-- must avoid those.
validateAndBuildRawSql :: (MonadIO m) => Text -> Params -> m RawSql
validateAndBuildRawSql sqlTemplate params =
  case T.splitOn "?" sqlTemplate of
    -- 'T.splitOn' with a non-empty delimiter never yields an empty list;
    -- this branch only keeps the match total.
    [] -> pure mempty
    (first : rest)
      | length params /= length rest ->
          throwInternal $
            "SQL parameter count mismatch: expected "
              <> T.pack (show (length rest))
              <> " but got "
              <> T.pack (show (length params))
      | otherwise ->
          -- Interleave: first <> p1 <> chunk1 <> p2 <> chunk2 <> ...
          foldM
            ( \acc (p, txt) ->
                case someParamToSqlValue p of
                  Left err -> throwInternal $ "param encoding error: " <> err
                  Right sv -> pure $ acc <> RawSql.parameter sv <> RawSql.fromText txt
            )
            (RawSql.fromText first)
            (zip params rest)

-- | Read-only Orville marshaller for a single named column.
orvilleCol :: NullCol a -> O.SqlMarshaller () a
orvilleCol (NotNull name c) =
  O.marshallReadOnly $ O.marshallField id (colFieldDef name c)
orvilleCol (Nullable name c) =
  O.marshallReadOnly $ O.marshallField id (O.nullableField (colFieldDef name c))

-- | Orville field definition for a column of the given Arbiter type.
colFieldDef :: Text -> Col a -> O.FieldDefinition O.NotNull a
colFieldDef name CInt4 = O.integerField (T.unpack name)
colFieldDef name CInt8 = O.bigIntegerField (T.unpack name)
colFieldDef name CText = O.unboundedTextField (T.unpack name)
colFieldDef name CBool = O.booleanField (T.unpack name)
colFieldDef name CTimestamptz = O.utcTimestampField (T.unpack name)
colFieldDef name CJsonb = O.fieldOfType jsonbValue (T.unpack name)
colFieldDef name CFloat8 = O.doubleField (T.unpack name)

-- | @jsonb@ SQL type carrying an Aeson 'Value'.
--
-- Values are encoded with 'encode' and decoded with 'eitherDecodeStrict''.
-- They travel as UTF-8 'Text'.
jsonbValue :: O.SqlType Value
jsonbValue =
  O.tryConvertSqlType
    (TL.toStrict . TLE.decodeUtf8 . encode)
    (eitherDecodeStrict' . TE.encodeUtf8)
    O.jsonb

-- | Affected-row count reported by libpq for a command result.
--
-- Returns 0 when 'LibPQ.cmdTuples' yields 'Nothing' or a value that does not
-- parse as an integer.
readRowCount :: LibPQ.Result -> IO Int64
readRowCount res = do
  mbTuples <- LibPQ.cmdTuples res
  case mbTuples of
    Nothing -> pure 0
    Just bs ->
      case SqlValue.toInt (SqlValue.fromRawBytes bs) of
        Right n -> pure (fromIntegral n)
        Left _ -> pure 0