Haskell - How to do automatic differentiation on complex datatypes?
Given a simple matrix definition based on Vector:
    import Numeric.AD
    import qualified Data.Vector as V

    newtype Mat a = Mat { unMat :: V.Vector a }

    -- element-wise helpers
    scale' f = Mat . V.map (*f) . unMat
    add' a b = Mat $ V.zipWith (+) (unMat a) (unMat b)
    sub' a b = Mat $ V.zipWith (-) (unMat a) (unMat b)
    mul' a b = Mat $ V.zipWith (*) (unMat a) (unMat b)
    pow' a e = Mat $ V.map (^e) (unMat a)

    sumElems' :: Num a => Mat a -> a
    sumElems' = V.sum . unMat
(This is for demonstration purposes only; I'm actually using hmatrix, but I thought the problem might be there somehow.)
and an error function (eq3):
    eq1' :: Num a => [a] -> [Mat a] -> Mat a
    eq1' as φs = foldl1 add' $ zipWith scale' as φs

    eq3' :: Num a => Mat a -> [a] -> [Mat a] -> a
    eq3' img as φs = negate $ sumElems' (errImg `pow'` (2::Int))
      where errImg = img `sub'` (eq1' as φs)
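To make the error function concrete, here is a small self-contained sketch that restates the question's helpers (with type signatures added, assuming an Int exponent for pow') and evaluates eq3' on a two-pixel image. The data (exampleImg, exampleBasis) is hypothetical and only there for illustration; it assumes the vector package is installed.

```haskell
import qualified Data.Vector as V

-- The question's Mat type and element-wise helpers, restated so this
-- block compiles on its own.
newtype Mat a = Mat { unMat :: V.Vector a }

scale' :: Num a => a -> Mat a -> Mat a
scale' f = Mat . V.map (* f) . unMat

add', sub' :: Num a => Mat a -> Mat a -> Mat a
add' a b = Mat $ V.zipWith (+) (unMat a) (unMat b)
sub' a b = Mat $ V.zipWith (-) (unMat a) (unMat b)

pow' :: Num a => Mat a -> Int -> Mat a
pow' a e = Mat $ V.map (^ e) (unMat a)

sumElems' :: Num a => Mat a -> a
sumElems' = V.sum . unMat

-- Weighted sum of the basis images.
eq1' :: Num a => [a] -> [Mat a] -> Mat a
eq1' as φs = foldl1 add' $ zipWith scale' as φs

-- Negated sum of squared residuals between img and the weighted sum.
eq3' :: Num a => Mat a -> [a] -> [Mat a] -> a
eq3' img as φs = negate $ sumElems' (errImg `pow'` (2 :: Int))
  where errImg = img `sub'` (eq1' as φs)

-- Hypothetical toy data: image [1, 2], basis images [1, 0] and [0, 1].
exampleImg :: Mat Double
exampleImg = Mat (V.fromList [1, 2])

exampleBasis :: [Mat Double]
exampleBasis = [Mat (V.fromList [1, 0]), Mat (V.fromList [0, 1])]

-- With coefficients [1, 1], eq1' gives [1, 1], the residual is [0, 1],
-- and eq3' is -(0^2 + 1^2) = -1.
exampleError :: Double
exampleError = eq3' exampleImg [1, 1] exampleBasis
```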
Why is the compiler not able to deduce the right types in this?
    diffTest :: forall a . (Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
    diffTest m φs as0 = gradientDescent go as0
      where go xs = eq3' m xs φs
The exact error message is this:
    src/stuff.hs:59:37:
        Could not deduce (a ~ Numeric.AD.Internal.Reverse.Reverse s a)
        from the context (Fractional a, Ord a)
          bound by the type signature for
                     diffTest :: (Fractional a, Ord a) =>
                                 Mat a -> [Mat a] -> [a] -> [[a]]
          at src/stuff.hs:58:13-69
        or from (reflection-1.5.1.2:Data.Reflection.Reifies
                   s Numeric.AD.Internal.Reverse.Tape)
          bound by a type expected by the context:
                     reflection-1.5.1.2:Data.Reflection.Reifies
                       s Numeric.AD.Internal.Reverse.Tape =>
                     [Numeric.AD.Internal.Reverse.Reverse s a]
                     -> Numeric.AD.Internal.Reverse.Reverse s a
          at src/stuff.hs:59:21-42
          ‘a’ is a rigid type variable bound by
              the type signature for
                diffTest :: (Fractional a, Ord a) =>
                            Mat a -> [Mat a] -> [a] -> [[a]]
              at src//stuff.hs:58:13
        Expected type: [Numeric.AD.Internal.Reverse.Reverse s a]
                       -> Numeric.AD.Internal.Reverse.Reverse s a
          Actual type: [a] -> a
        Relevant bindings include
          go :: [a] -> a (bound at src/stuff.hs:60:9)
          as0 :: [a] (bound at src/stuff.hs:59:15)
          φs :: [Mat a] (bound at src/stuff.hs:59:12)
          m :: Mat a (bound at src/stuff.hs:59:10)
          diffTest :: Mat a -> [Mat a] -> [a] -> [[a]]
            (bound at src/stuff.hs:59:1)
        In the first argument of ‘gradientDescent’, namely ‘go’
        In the expression: gradientDescent go as0
The gradientDescent function from ad has the type

    gradientDescent :: (Traversable f, Fractional a, Ord a)
                    => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a)
                    -> f a -> [f a]

Its first argument must be a function of type f r -> r, where r is forall s. (Reverse s a). go has the type [a] -> a, where a is the type bound in the signature of diffTest. The a in gradientDescent's type is fixed to that same a (through as0 :: [a]), but Reverse s a isn't the same as a.
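The mismatch can be reproduced in miniature. In this sketch, bowl, bowlD and gradAtOrigin are hypothetical names, and grad is used instead of gradientDescent because it takes the same kind of rank-2 argument. A polymorphic objective is accepted because it can be instantiated at Reverse s Double, while a copy fixed at Double is rejected for the same reason go is. It assumes the ad package is installed.

```haskell
import Numeric.AD (grad)

-- Polymorphic objective: usable at the Reverse s Double type that
-- grad (like gradientDescent) chooses internally.
bowl :: Num a => [a] -> a
bowl [x, y] = (x - 1) ^ (2 :: Int) + (y - 2) ^ (2 :: Int)
bowl _      = error "bowl: expected exactly two coordinates"

-- Monomorphic copy, analogous to go :: [a] -> a in diffTest.
bowlD :: [Double] -> Double
bowlD = bowl

-- Accepted: bowl is instantiated at Reverse s Double.
-- d/dx (x - 1)^2 = -2 and d/dy (y - 2)^2 = -4 at the origin.
gradAtOrigin :: [Double]
gradAtOrigin = grad bowl [0, 0]

-- Rejected with a type error of the same shape as in the question
-- (Double versus Reverse s Double), so it is left commented out:
-- badGrad = grad bowlD [0, 0]
```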
The Reverse type has instances for a number of type classes that let us convert an a into a Reverse s a or back. The most obvious is Fractional a => Fractional (Reverse s a), which would allow us to convert as into Reverse s as with realToFrac.
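A minimal sketch of the realToFrac route on a single constant, assuming a hypothetical Double constant weight that plays the role of the fixed Mat values. Inside the objective the variables have type Reverse s Double, so the constant has to be converted with realToFrac (which in general goes through Rational) before it can be multiplied with them. scaledSquare and gradScaled are also hypothetical names.

```haskell
import Numeric.AD (grad)

-- Hypothetical constant fixed at type Double, like the matrices in diffTest.
weight :: Double
weight = 3

-- weight * x would not typecheck here, since x is not a Double when grad
-- runs this function; realToFrac converts weight into the Fractional type
-- the objective is used at.
scaledSquare :: Fractional a => [a] -> a
scaledSquare xs = sum [realToFrac weight * x * x | x <- xs]

-- d/dx (3 * x^2) = 6 * x, so the gradient at [1, 2] is [6, 12].
gradScaled :: [Double]
gradScaled = grad scaledSquare [1, 2]
```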
To do this for the matrices, we'll need to be able to map a function a -> b over a Mat a to obtain a Mat b. The easiest way to do that is to derive a Functor instance for Mat.
    {-# LANGUAGE DeriveFunctor #-}

    newtype Mat a = Mat { unMat :: V.Vector a }
      deriving Functor
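With the derived instance, fmap changes the element type of a Mat while leaving the vector structure alone. A small sketch, where counts and halves are hypothetical names and Show is derived only so the values can be printed in GHCi:

```haskell
{-# LANGUAGE DeriveFunctor #-}
import qualified Data.Vector as V

newtype Mat a = Mat { unMat :: V.Vector a }
  deriving (Show, Functor)

counts :: Mat Int
counts = Mat (V.fromList [1, 2, 3])

-- fmap turns a Mat Int into a Mat Double by mapping over every element.
halves :: Mat Double
halves = fmap (\n -> fromIntegral n / 2) counts
```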
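We can convert m and fs into any Fractional a' => Mat a' with fmap realToFrac. Since realToFrac :: (Real a, Fractional b) => a -> b, this also needs a Real a constraint, which (Fractional a, Ord a) does not imply.

    diffTest :: (Real a, Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
    diffTest m fs as0 = gradientDescent go as0
      where go xs = eq3' (fmap realToFrac m) xs (fmap (fmap realToFrac) fs)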
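Putting these pieces together, here is a self-contained sketch of the realToFrac version. It restates the question's helpers, keeps the Real a constraint, gives go an explicit signature (b is a hypothetical type variable name) for readability, and takes five steps on hypothetical toy data (toyImg, toyBasis, steps). It assumes the ad and vector packages are installed.

```haskell
{-# LANGUAGE DeriveFunctor #-}
import Numeric.AD (gradientDescent)
import qualified Data.Vector as V

-- The question's Mat type, now with a derived Functor instance.
newtype Mat a = Mat { unMat :: V.Vector a }
  deriving Functor

-- The question's element-wise helpers and error function, restated.
scale' :: Num a => a -> Mat a -> Mat a
scale' f = Mat . V.map (* f) . unMat

add', sub' :: Num a => Mat a -> Mat a -> Mat a
add' a b = Mat $ V.zipWith (+) (unMat a) (unMat b)
sub' a b = Mat $ V.zipWith (-) (unMat a) (unMat b)

pow' :: Num a => Mat a -> Int -> Mat a
pow' a e = Mat $ V.map (^ e) (unMat a)

sumElems' :: Num a => Mat a -> a
sumElems' = V.sum . unMat

eq1' :: Num a => [a] -> [Mat a] -> Mat a
eq1' as φs = foldl1 add' $ zipWith scale' as φs

eq3' :: Num a => Mat a -> [a] -> [Mat a] -> a
eq3' img as φs = negate $ sumElems' (errImg `pow'` (2 :: Int))
  where errImg = img `sub'` (eq1' as φs)

-- realToFrac :: (Real a, Fractional b) => a -> b, hence the Real a constraint.
diffTest :: (Real a, Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
diffTest m fs as0 = gradientDescent go as0
  where
    -- go is polymorphic in b, so gradientDescent can pick b = Reverse s a.
    go :: Fractional b => [b] -> b
    go xs = eq3' (fmap realToFrac m) xs (fmap (fmap realToFrac) fs)

-- Hypothetical toy data: a two-pixel image and two basis images.
toyImg :: Mat Double
toyImg = Mat (V.fromList [1, 2])

toyBasis :: [Mat Double]
toyBasis = [Mat (V.fromList [1, 0]), Mat (V.fromList [0, 1])]

-- The first five iterates produced by gradientDescent.
steps :: [[Double]]
steps = take 5 (diffTest toyImg toyBasis [0, 0])
```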
But there's a better way hiding in the ad package. Reverse s a is universally quantified over s, but its a is the same a as the one bound in the type signature for diffTest. So we really only need a function a -> (forall s. Reverse s a). That function is auto from the Mode class, for which Reverse s a has an instance. auto has the slightly weird type Mode t => Scalar t -> t, but type Scalar (Reverse s a) = a. Specialized to Reverse, auto has the type

    auto :: (Reifies s Tape, Num a) => a -> Reverse s a

This allows us to convert our Mat as into Mat (Reverse s a)s without messing around with conversions to and from Rational.
    {-# LANGUAGE ScopedTypeVariables #-}
    {-# LANGUAGE TypeFamilies #-}

    diffTest :: forall a . (Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
    diffTest m fs as0 = gradientDescent go as0
      where
        -- auto lifts each constant entry of type a into the AD type t
        go :: forall t. (Scalar t ~ a, Mode t) => [t] -> t
        go xs = eq3' (fmap auto m) xs (fmap (fmap auto) fs)
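For completeness, a self-contained sketch of the auto-based version, again restating the question's helpers so it compiles on its own. It assumes the ad and vector packages are installed and that Numeric.AD exports auto, Mode and Scalar (as the snippet above relies on); toyImg, toyBasis and steps are hypothetical. Unlike the realToFrac version, it needs only (Fractional a, Ord a).

```haskell
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
import Numeric.AD
import qualified Data.Vector as V

newtype Mat a = Mat { unMat :: V.Vector a }
  deriving Functor

-- The question's element-wise helpers and error function, restated.
scale' :: Num a => a -> Mat a -> Mat a
scale' f = Mat . V.map (* f) . unMat

add', sub' :: Num a => Mat a -> Mat a -> Mat a
add' a b = Mat $ V.zipWith (+) (unMat a) (unMat b)
sub' a b = Mat $ V.zipWith (-) (unMat a) (unMat b)

pow' :: Num a => Mat a -> Int -> Mat a
pow' a e = Mat $ V.map (^ e) (unMat a)

sumElems' :: Num a => Mat a -> a
sumElems' = V.sum . unMat

eq1' :: Num a => [a] -> [Mat a] -> Mat a
eq1' as φs = foldl1 add' $ zipWith scale' as φs

eq3' :: Num a => Mat a -> [a] -> [Mat a] -> a
eq3' img as φs = negate $ sumElems' (errImg `pow'` (2 :: Int))
  where errImg = img `sub'` (eq1' as φs)

diffTest :: forall a. (Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
diffTest m fs as0 = gradientDescent go as0
  where
    -- go works for any AD mode t whose scalar type is a; gradientDescent
    -- instantiates t = Reverse s a, and auto lifts the constant entries.
    go :: forall t. (Scalar t ~ a, Mode t) => [t] -> t
    go xs = eq3' (fmap auto m) xs (fmap (fmap auto) fs)

-- Hypothetical toy data: a two-pixel image and two basis images.
toyImg :: Mat Double
toyImg = Mat (V.fromList [1, 2])

toyBasis :: [Mat Double]
toyBasis = [Mat (V.fromList [1, 0]), Mat (V.fromList [0, 1])]

-- The first five iterates produced by gradientDescent.
steps :: [[Double]]
steps = take 5 (diffTest toyImg toyBasis [0, 0])
```

Since realToFrac and auto both only lift the constant matrix entries, the two versions should produce the same iterates on Double data; the auto version simply avoids the Real constraint and the Rational round trip.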