{-# LANGUAGE OverloadedStrings #-}

module Arbiter.Hasql.Encode
  ( buildEncoder
  , buildStatementRowCount
  , convertPlaceholders
  , encodeSomeParam
  , colEncoder
  ) where

import Arbiter.Core.Codec (Col (..), ParamType (..), Params, SomeParam (..))
import Data.Functor.Contravariant (contramap)
import Data.Int (Int64)
import Data.Text (Text)
import Data.Text qualified as T
import Hasql.Decoders qualified as D
import Hasql.Encoders qualified as E
import Hasql.Statement qualified as S

-- | Build a hasql statement that takes no Haskell-level input and returns
-- the number of rows affected ('D.rowsAffected').
--
-- The SQL is written with @?@ placeholders, which are rewritten to
-- PostgreSQL's positional @$1, $2, ...@ form by 'convertPlaceholders'. The
-- parameter values are baked into the encoder by 'buildEncoder', so the
-- statement's input type is @()@. The statement is constructed with
-- 'S.preparable'.
buildStatementRowCount :: Text -> Params -> S.Statement () Int64
buildStatementRowCount sql ps = S.preparable positionalSql encoder D.rowsAffected
  where
    positionalSql = convertPlaceholders sql
    encoder = buildEncoder ps

-- | Combine a list of parameters into a single hasql parameter encoder.
--
-- Each parameter is encoded by 'encodeSomeParam'. The encoders are appended
-- in list order, so the first element binds @$1@, the second @$2@, and so on.
-- An empty list yields the empty encoder ('mempty').
buildEncoder :: Params -> E.Params ()
buildEncoder = foldMap encodeSomeParam

-- | Encode one existentially wrapped parameter as a hasql parameter encoder.
--
-- The wrapped value is captured with @contramap (const v)@, so the resulting
-- encoder ignores its @()@ input and always emits @v@. The 'ParamType' tag
-- selects the wire shape:
--
--   * 'PScalar'    – a non-null scalar column value
--   * 'PNullable'  – a scalar that may be NULL (@Maybe@)
--   * 'PArray'     – a non-null one-dimensional array of non-null elements
--   * 'PNullArray' – a non-null one-dimensional array of nullable elements
encodeSomeParam :: SomeParam -> E.Params ()
encodeSomeParam (SomeParam pt v) = contramap (const v) (E.param (valueEncoder pt))
  where
    -- Pick the (nullable or non-nullable) value encoder for this tag; each
    -- equation refines the GADT index, hence the explicit signature.
    valueEncoder :: ParamType b -> E.NullableOrNot E.Value b
    valueEncoder (PScalar c) = E.nonNullable (colEncoder c)
    valueEncoder (PNullable c) = E.nullable (colEncoder c)
    valueEncoder (PArray c) = E.nonNullable (listArray (E.nonNullable (colEncoder c)))
    valueEncoder (PNullArray c) = E.nonNullable (listArray (E.nullable (colEncoder c)))

    -- A single-dimension array built from a Haskell list with a strict left fold.
    listArray :: E.NullableOrNot E.Value e -> E.Value [e]
    listArray element = E.array (E.dimension foldl' (E.element element))

-- | Map a column tag to the hasql value encoder for the corresponding
-- PostgreSQL type.
--
-- Per the constructor types: @int4@ ↔ @Int32@, @int8@ ↔ @Int64@, @text@ ↔
-- @Text@, @bool@ ↔ @Bool@, @timestamptz@ ↔ @UTCTime@, @float8@ ↔ @Double@, and
-- @jsonb@ ↔ a JSON @Value@ (presumably aeson's — confirm against
-- "Arbiter.Core.Codec").
colEncoder :: Col a -> E.Value a
colEncoder col = case col of
  CInt4 -> E.int4
  CInt8 -> E.int8
  CText -> E.text
  CBool -> E.bool
  CTimestamptz -> E.timestamptz
  CJsonb -> E.jsonb
  CFloat8 -> E.float8

-- | Rewrite @?@ placeholders into PostgreSQL positional parameters.
--
-- The n-th @?@ (counting from 1, left to right) becomes @$n@:
--
-- > convertPlaceholders "a = ? AND b = ?" == "a = $1 AND b = $2"
--
-- Every @?@ character is rewritten, with no awareness of SQL syntax.
-- NOTE(review): a @?@ inside a string literal, quoted identifier or comment,
-- or used as one of PostgreSQL's jsonb operators (@?@, @?|@, @?&@), will
-- also be replaced. Such SQL must not be passed through this function.
convertPlaceholders :: Text -> Text
convertPlaceholders sql = case T.splitOn "?" sql of
  -- 'T.splitOn' never returns an empty list for a non-empty separator; this
  -- branch only keeps the match total.
  [] -> ""
  (prefix : segments) -> T.concat (prefix : zipWith numbered [1 :: Int ..] segments)
  where
    -- Each segment after the first followed a '?', which becomes "$n".
    numbered n segment = "$" <> T.pack (show n) <> segment