-- (original) (raw)
{-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE Rank2Types #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-}
module TensorFlow.Tensor where
import Data.ByteString (ByteString)
import Data.String (IsString(..))
import qualified Data.Text as Text
import Lens.Family2 ((^.))
import Lens.Family2.State ((%=), use)

import Proto.Tensorflow.Core.Framework.NodeDef_Fields (device)
import TensorFlow.Build
import TensorFlow.Output (Output, NodeName, outputNodeName, Device(..))
import TensorFlow.Types
    ( TensorType
    , TensorData(..)
    , ListOf(..)
    )
import qualified TensorFlow.Internal.FFI as FFI
-- | A tensor tagged with a kind @v@ and an element type @a@.
--
-- The only payload is a single 'Output' wrapped in @v@; the parameter @a@
-- is phantom (it does not appear in the field).  Constructing a 'Tensor'
-- requires a 'TensorKind' instance for @v@, and pattern matching on the
-- constructor brings that instance back into scope.
data Tensor v a where
    Tensor :: TensorKind v => {tensorOutput :: v Output} -> Tensor v a
-- | A plain wrapper around a value of type @a@.
--
-- 'runValue' unwraps it.  The 'Functor' instance is derived, so 'fmap'
-- applies the function to the wrapped value.
newtype Value a = Value {runValue :: a}
    deriving Functor
-- | 'pure' wraps a value; '<*>' applies the wrapped function to the
-- wrapped argument and re-wraps the result.
instance Applicative Value where
    pure = Value
    Value f <*> Value x = Value $ f x
-- | Bind unwraps the left-hand 'Value' and feeds its contents to the
-- continuation.
instance Monad Value where
    f >>= g = g $ runValue f
-- | A plain wrapper around a value of type @a@, distinct from 'Value' at
-- the type level.
--
-- 'runRef' unwraps it.  The 'Functor' instance is derived, so 'fmap'
-- applies the function to the wrapped value.
newtype Ref a = Ref {runRef :: a}
    deriving Functor
-- | 'pure' wraps a value; '<*>' applies the wrapped function to the
-- wrapped argument and re-wraps the result.
instance Applicative Ref where
    pure = Ref
    Ref f <*> Ref x = Ref $ f x
-- | Bind unwraps the left-hand 'Ref' and feeds its contents to the
-- continuation.
instance Monad Ref where
    f >>= g = g $ runRef f
-- | Convert a @'Tensor' 'Ref'@ into a @'Tensor' 'Value'@.
--
-- The underlying 'Output' is carried over unchanged; only the wrapper
-- around it switches from 'Ref' to 'Value'.
value :: Tensor Ref a -> Tensor Value a
value (Tensor o) = Tensor $ Value $ runRef o
-- | Turn a tensor of any kind into a @'Tensor' 'Value'@ inside a
-- 'MonadBuild' context.
--
-- The tensor's output is first converted to a 'Build' action with
-- 'toBuild' and the resulting @'Tensor' 'Build'@ is handed to 'render'.
renderValue :: MonadBuild m => Tensor v a -> m (Tensor Value a)
renderValue (Tensor o) = render $ Tensor $ toBuild o
-- | An 'Output' paired with the raw 'FFI.TensorData' associated with it.
--
-- Values of this type are produced by 'feed'.
-- NOTE(review): presumably consumed when running a graph to supply data
-- for the given output -- confirm against the session code.
data Feed = Feed Output FFI.TensorData
-- | Tensor-like types from which an 'Output' can be read directly,
-- without running any monadic action.
class Rendered t where
    -- | Extract the 'Output' carried by the given value.
    renderedOutput :: t a -> Output
-- | The output of a @'Tensor' 'Value'@ is obtained by unwrapping its
-- 'tensorOutput' field with 'runValue'.
instance Rendered (Tensor Value) where
    renderedOutput = runValue . tensorOutput
-- | The output of a @'Tensor' 'Ref'@ is obtained by unwrapping its
-- 'tensorOutput' field with 'runRef'.
instance Rendered (Tensor Ref) where
    renderedOutput = runRef . tensorOutput
-- | The 'NodeName' of the 'Output' carried by a rendered tensor, i.e.
-- 'outputNodeName' applied to 'renderedOutput'.
tensorNodeName :: Rendered t => t a -> NodeName
tensorNodeName = outputNodeName . renderedOutput
-- | Build a 'Feed' from a rendered tensor and data of the matching
-- element type.
--
-- The 'Feed' pairs the tensor's 'renderedOutput' with the raw payload
-- unwrapped from the 'TensorData' constructor.  The shared type parameter
-- @a@ ties the data's element type to the tensor's.
feed :: Rendered t => t a -> TensorData a -> Feed
feed t (TensorData td) = Feed (renderedOutput t) td
-- | Build a 'Tensor' of any kind from a textual name.
--
-- The name is unpacked to a 'String', converted to an 'Output' through
-- that type's 'IsString' instance, lifted into @v@ with 'pure', and
-- wrapped in the 'Tensor' constructor.  Nothing here checks that the name
-- refers to an existing node, or that the element type @a@ is correct;
-- both are chosen freely by the caller.
tensorFromName :: TensorKind v => Text.Text -> Tensor v a
tensorFromName = Tensor . pure . fromString . Text.unpack
-- | 'tensorFromName' with the tensor kind fixed to 'Value'.
tensorValueFromName :: Text.Text -> Tensor Value a
tensorValueFromName = tensorFromName
-- | 'tensorFromName' with the tensor kind fixed to 'Ref'.
tensorRefFromName :: Text.Text -> Tensor Ref a
tensorRefFromName = tensorFromName
-- | A 'ListOf' whose elements are all 'Tensor's of the same kind @v@.
type TensorList v = ListOf (Tensor v)
-- | Collect the 'renderedOutput' of every tensor in a 'TensorList', in
-- list order.  An empty list ('Nil') yields @[]@.
tensorListOutputs :: Rendered (Tensor v) => TensorList v as -> [Output]
tensorListOutputs Nil = []
tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts
-- | Run an action with the device taken from an existing rendered tensor.
--
-- The node behind the tensor's output is fetched with 'lookupNode', its
-- @device@ field is read and wrapped in 'Device', and the action @x@ is
-- then run under @'withDevice' ('Just' d)@.
-- NOTE(review): the exact effect on nodes created by @x@ depends on
-- 'withDevice', which is defined outside this module -- confirm there.
colocateWith :: (MonadBuild m, Rendered t) => t b -> m a -> m a
colocateWith t x = do
    d <- build $ Device . (^. device)
               <$> lookupNode (outputNodeName $ renderedOutput t)
    withDevice (Just d) x
-- | Turn a @'Tensor' 'Build'@ into a @'Tensor' 'Value'@.
--
-- The tensor's 'Build' action is run in the surrounding 'MonadBuild'
-- context via 'build', and the 'Output' it produces is wrapped in 'Value'.
render :: MonadBuild m => Tensor Build a -> m (Tensor Value a)
render (Tensor t) = Tensor . Value <$> build t
-- | Convert a tensor of any kind into a @'Tensor' 'Build'@ by passing its
-- output through 'toBuild'.  No 'Build' action is run here.
expr :: TensorKind v => Tensor v a -> Tensor Build a
expr (Tensor o) = Tensor $ toBuild o
-- | Record a 'ByteString'-valued tensor in the graph state's list of
-- summaries.
--
-- The tensor's output is resolved to an 'Output' with 'toBuild' and then
-- prepended to the @summaries@ field, so the most recently added summary
-- comes first.
addSummary :: (MonadBuild m, TensorKind v) => Tensor v ByteString -> m ()
addSummary t = build $ do
    o <- toBuild $ tensorOutput t
    summaries %= (o :)
-- | Return every summary output currently stored in the graph state,
-- each wrapped as a 'SummaryTensor'.
--
-- This only reads the @summaries@ field; it does not clear it.
collectAllSummaries :: MonadBuild m => m [SummaryTensor]
collectAllSummaries = build $ map (Tensor . Value) <$> use summaries
-- | A 'Value' tensor of 'ByteString' elements, as returned by
-- 'collectAllSummaries'.
-- NOTE(review): presumably each element is a serialized summary record --
-- confirm against the code that produces these tensors.
type SummaryTensor = Tensor Value ByteString
-- | The wrappers that may be used as the kind parameter of a 'Tensor'.
-- Each one is a 'Monad' that can be converted into a 'Build' action.
class Monad v => TensorKind v where
    -- | Convert a wrapped value into a 'Build' action yielding it.
    toBuild :: v a -> Build a
-- | A 'Value' becomes a 'Build' action that simply returns the wrapped
-- value.
instance TensorKind Value where
    toBuild = return . runValue
-- | A 'Ref' becomes a 'Build' action that simply returns the wrapped
-- value.
instance TensorKind Ref where
    toBuild = return . runRef
-- | A 'Build' action is already in the target form, so it is returned
-- unchanged.
instance TensorKind Build where
    toBuild = id
-- | Types that can be converted into a @'Tensor' 'Build'@.
class ToTensor t where
    -- | Convert to a @'Tensor' 'Build'@; the element type must be a
    -- 'TensorType'.
    toTensor :: TensorType a => t a -> Tensor Build a
-- | Any 'Tensor' whose kind is a 'TensorKind' converts via 'expr'.
instance TensorKind v => ToTensor (Tensor v) where
    toTensor = expr