-- |
-- Module: Symtegration.Differentiation
-- Description: Differentiate mathematical expressions.
-- Copyright: Copyright 2024 Yoo Chung
-- License: Apache-2.0
-- Maintainer: dev@chungyc.org
--
-- Differentiate symbolic representations of mathematical expressions.
-- This module does not actually implement differentiation,
-- but is rather a thin wrapper over "Numeric.AD" providing
-- derivatives for 'Expression' with some simplification applied.
module Symtegration.Differentiation (differentiate) where

import Data.Text (Text)
import Numeric.AD.Rank1.Forward
import Symtegration.Symbolic
import Symtegration.Symbolic.Simplify

-- $setup
-- >>> import Symtegration.Symbolic.Haskell

-- | Differentiates a mathematical expression.
--
-- >>> toHaskell $ differentiate "x" $ "x" ** 2
-- "2 * x"
-- >>> toHaskell $ differentiate "x" $ "a" * sin "x"
-- "a * cos x"
--
-- The expression is turned into a function over forward-mode automatic
-- differentiation values with 'toFunction', and that function is
-- differentiated with 'diff' at the point @'Symbol' v@, so the derivative
-- comes out as a symbolic expression in the variable.
-- Every symbol other than the variable is lifted in as a constant.
-- The raw result is then cleaned up with 'simplifyForVariable' and 'tidy'.
--
-- This uses [Numeric.AD](https://hackage.haskell.org/package/ad).
differentiate ::
  -- | Symbol representing the variable.
  Text ->
  -- | Symbolic representation of the mathematical expression to differentiate.
  Expression ->
  -- | The derivative.
  Expression
differentiate v e = tidy $ simplifyForVariable v $ diff f $ Symbol v
  where
    -- The expression viewed as a function of the differentiation variable.
    f :: Forward Expression -> Forward Expression
    f = toFunction e assign

    -- Decides what each symbol in the expression evaluates to.
    -- The variable being differentiated is bound to the point supplied by
    -- 'diff'. Any other symbol ignores that point and is lifted into
    -- 'Forward' with 'auto', i.e. it is treated as a constant.
    assign :: Text -> Forward Expression -> Forward Expression
    assign x
      | v == x = id
      | otherwise = const $ auto $ Symbol x