Haskell - How to do automatic differentiation on complex datatypes?


Given a simple matrix definition based on `Vector`:

```haskell
import Numeric.AD
import qualified Data.Vector as V

newtype Mat a = Mat { unMat :: V.Vector a }

scale' f = Mat . V.map (*f) . unMat
add' a b = Mat $ V.zipWith (+) (unMat a) (unMat b)
sub' a b = Mat $ V.zipWith (-) (unMat a) (unMat b)
mul' a b = Mat $ V.zipWith (*) (unMat a) (unMat b)
pow' a e = Mat $ V.map (^e) (unMat a)

sumElems' :: Num a => Mat a -> a
sumElems' = V.sum . unMat
```

(This is for demonstration purposes only. I am actually using hmatrix, but I thought the problem was there somehow.)

And an error function (`eq3`):

```haskell
eq1' :: Num a => [a] -> [Mat a] -> Mat a
eq1' as φs = foldl1 add' $ zipWith scale' as φs

eq3' :: Num a => Mat a -> [a] -> [Mat a] -> a
eq3' img as φs = negate $ sumElems' (errImg `pow'` (2::Int))
  where errImg = img `sub'` (eq1' as φs)
```

Why is the compiler not able to deduce the right types in this?

```haskell
diffTest :: forall a . (Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
diffTest m φs as0 = gradientDescent go as0
  where go xs = eq3' m xs φs
```

The exact error message is this:

```
src/stuff.hs:59:37:
    Could not deduce (a ~ Numeric.AD.Internal.Reverse.Reverse s a)
    from the context (Fractional a, Ord a)
      bound by the type signature for
                 diffTest :: (Fractional a, Ord a) =>
                             Mat a -> [Mat a] -> [a] -> [[a]]
      at src/stuff.hs:58:13-69
    or from (reflection-1.5.1.2:Data.Reflection.Reifies
               s Numeric.AD.Internal.Reverse.Tape)
      bound by a type expected by the context:
                 reflection-1.5.1.2:Data.Reflection.Reifies
                   s Numeric.AD.Internal.Reverse.Tape =>
                 [Numeric.AD.Internal.Reverse.Reverse s a]
                 -> Numeric.AD.Internal.Reverse.Reverse s a
      at src/stuff.hs:59:21-42
      ‘a’ is a rigid type variable bound by
          the type signature for
            diffTest :: (Fractional a, Ord a) =>
                        Mat a -> [Mat a] -> [a] -> [[a]]
          at src/stuff.hs:58:13
    Expected type: [Numeric.AD.Internal.Reverse.Reverse s a]
                   -> Numeric.AD.Internal.Reverse.Reverse s a
      Actual type: [a] -> a
    Relevant bindings include
      go :: [a] -> a (bound at src/stuff.hs:60:9)
      as0 :: [a] (bound at src/stuff.hs:59:15)
      φs :: [Mat a] (bound at src/stuff.hs:59:12)
      m :: Mat a (bound at src/stuff.hs:59:10)
      diffTest :: Mat a -> [Mat a] -> [a] -> [[a]]
        (bound at src/stuff.hs:59:1)
    In the first argument of ‘gradientDescent’, namely ‘go’
    In the expression: gradientDescent go as0
```

The `gradientDescent` function from `ad` has the type

```haskell
gradientDescent :: (Traversable f, Fractional a, Ord a) =>
                   (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) ->
                   f a -> [f a]
```

Its first argument requires a function of the type `f r -> r`, where `r` is `forall s. (Reverse s a)`. `go` has the type `[a] -> a`, where `a` is the type bound in the signature of `diffTest`. These `a`s are the same, but `Reverse s a` isn't the same as `a`.
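The mismatch goes away when the function passed to `gradientDescent` is polymorphic in its number type, because `gradientDescent` can then pick `Reverse s a` for itself. The sketch below assumes the `ad` 4.x API quoted above. `quadratic` and `steps` are hypothetical names used only for illustration.

```haskell
import Numeric.AD (gradientDescent)

-- Hypothetical objective, polymorphic in its number type t. Because it works
-- for any Num t, gradientDescent can instantiate t to Reverse s Double for
-- whatever s it chooses, which is what the rank-2 argument type demands.
quadratic :: Num t => [t] -> t
quadratic [x, y] = (x - 1) * (x - 1) + (y + 2) * (y + 2)
quadratic _      = error "quadratic: expected exactly two coordinates"

-- The first 20 iterates of gradient descent, starting from (0, 0).
steps :: [[Double]]
steps = take 20 (gradientDescent quadratic [0, 0])

-- By contrast, a function fixed to one element type, e.g.
--   fixed :: [Double] -> Double
-- is rejected as the argument: it cannot accept [Reverse s Double],
-- which is the same mismatch as `go` in diffTest.
```

The same `quadratic` can still be evaluated at plain `Double`s, for example to check that the last iterate has a lower value than the starting point.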

The `Reverse` type has instances for a number of type classes that could allow us to convert an `a` into a `Reverse s a` or back. The most obvious is `Fractional a => Fractional (Reverse s a)`, which would allow us to convert `a`s into `Reverse s a`s with `realToFrac`.

To do so, we'll need to be able to map a function `a -> b` over a `Mat a` to obtain a `Mat b`. The easiest way to do this is to derive a `Functor` instance for `Mat`.

```haskell
{-# LANGUAGE DeriveFunctor #-}

newtype Mat a = Mat { unMat :: V.Vector a }
    deriving Functor
```
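With the derived instance, `fmap` can change a `Mat`'s element type, and the conversion below relies on exactly that. This is a minimal check that assumes only the `vector` package. `widen` is a hypothetical helper, and `Show` is derived only so that results can be printed.

```haskell
{-# LANGUAGE DeriveFunctor #-}

import qualified Data.Vector as V

-- Same newtype as above, with Functor derived (fmap maps over the Vector).
newtype Mat a = Mat { unMat :: V.Vector a }
  deriving (Show, Functor)

-- Hypothetical helper: fmap realToFrac turns a Mat Float into a Mat Double,
-- the same kind of element-type change used for m and fs.
widen :: Mat Float -> Mat Double
widen = fmap realToFrac
```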

Now we can convert `m` and `fs` into any `Fractional a' => Mat a'` with `fmap realToFrac`.

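```haskell
-- realToFrac :: (Real a, Fractional b) => a -> b, so the context needs Real a.
diffTest :: forall a . (Real a, Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
diffTest m fs as0 = gradientDescent go as0
    where go xs = eq3' (fmap realToFrac m) xs (fmap (fmap realToFrac) fs)
```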
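The same `realToFrac` trick also works for a single scalar constant. The sketch below assumes the `ad` 4.x `gradientDescent` shown above. `fitMean` and `objective` are hypothetical names. The objective has an explicit polymorphic signature, so it can be passed where the rank-2 type is expected.

```haskell
import Numeric.AD (gradientDescent)

-- Hypothetical example: find the c that minimises sum (d - c)^2 over fixed
-- data ds. The data are plain Doubles; realToFrac (Real on the source,
-- Fractional on the target) lifts each one into Reverse s Double.
fitMean :: [Double] -> [[Double]]
fitMean ds = gradientDescent objective [0]
  where
    objective :: Fractional t => [t] -> t
    objective [c] = sum [ (realToFrac d - c) ^ (2 :: Int) | d <- ds ]
    objective _   = error "fitMean: expected exactly one parameter"
```

`realToFrac` is defined as `fromRational . toRational`, so each constant makes a round trip through `Rational` on its way into the AD type.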

But there's a better way hiding in the `ad` package. `Reverse s a` is universally quantified over `s`, but its `a` is the same `a` as the one bound in the type signature for `diffTest`. We really only need a function `a -> (forall s. Reverse s a)`. That function is `auto` from the `Mode` class, for which `Reverse s a` has an instance. `auto` has the slightly weird type `Mode t => Scalar t -> t`, but `type Scalar (Reverse s a) = a`. Specialized for `Reverse`, `auto` has the type

```haskell
auto :: (Reifies s Tape, Num a) => a -> Reverse s a
```

This lets us convert our `Mat a`s into `Mat (Reverse s a)`s without messing around with conversions to and from `Rational`.

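```haskell
{-# LANGUAGE ScopedTypeVariables #-} -- lets go's signature mention diffTest's a
{-# LANGUAGE TypeFamilies #-}        -- for the ~ constraint; also implies MonoLocalBinds

diffTest :: forall a . (Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
diffTest m fs as0 = gradientDescent go as0
    where
        -- go must work for every AD mode t whose scalar type is a
        go :: forall t. (Scalar t ~ a, Mode t) => [t] -> t
        go xs = eq3' (fmap auto m) xs (fmap (fmap auto) fs)
```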
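Putting the question's definitions and the `auto`-based fix into one module gives a complete sketch to load and experiment with. It assumes the `ad` 4.x API quoted in this answer, and that `Numeric.AD` exports `Mode`, `Scalar`, `auto` and `gradientDescent` as the code above already relies on. It also needs the `vector` package. Top-level type signatures are added for clarity, the coefficients are named `cs`, and `img` and `basis` are hypothetical example data.

```haskell
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

import Numeric.AD
import qualified Data.Vector as V

newtype Mat a = Mat { unMat :: V.Vector a }
  deriving (Show, Functor)

scale' :: Num a => a -> Mat a -> Mat a
scale' f = Mat . V.map (* f) . unMat

add', sub' :: Num a => Mat a -> Mat a -> Mat a
add' a b = Mat $ V.zipWith (+) (unMat a) (unMat b)
sub' a b = Mat $ V.zipWith (-) (unMat a) (unMat b)

pow' :: Num a => Mat a -> Int -> Mat a
pow' a e = Mat $ V.map (^ e) (unMat a)

sumElems' :: Num a => Mat a -> a
sumElems' = V.sum . unMat

-- Weighted sum of the basis matrices.
eq1' :: Num a => [a] -> [Mat a] -> Mat a
eq1' cs φs = foldl1 add' $ zipWith scale' cs φs

-- Negated sum of squared differences between img and that weighted sum.
eq3' :: Num a => Mat a -> [a] -> [Mat a] -> a
eq3' img cs φs = negate $ sumElems' (errImg `pow'` 2)
  where errImg = img `sub'` eq1' cs φs

diffTest :: forall a. (Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
diffTest m fs as0 = gradientDescent go as0
  where
    -- auto lifts the constant a-valued matrices into the AD type t.
    go :: forall t. (Scalar t ~ a, Mode t) => [t] -> t
    go xs = eq3' (fmap auto m) xs (fmap (fmap auto) fs)

-- Hypothetical example data: a 3-element "image" and two basis vectors.
img :: Mat Double
img = Mat (V.fromList [1, 2, 3])

basis :: [Mat Double]
basis = [Mat (V.fromList [1, 0, 0]), Mat (V.fromList [0, 1, 0])]
```

In GHCi, `take 5 (diffTest img basis [0, 0])` lists the first iterates. `eq3'` returns the *negated* squared error, so `gradientDescent` makes it more negative at every step. In other words, it moves the coefficients away from a good fit. If fitting `img` is the goal, dropping the `negate` (or using `ad`'s `gradientAscent` instead) is probably what you want.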
