(original) (raw)
{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE Rank2Types #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-}
module TensorFlow.Session ( Session, SessionT, Options, sessionConfig, sessionTarget, sessionTracer, runSession, runSessionWithOptions, MonadBuild(..), extend, addGraphDef, run, runWithFeeds, run_, runWithFeeds_, asyncProdNodes, ) where
import Data.ProtoLens.Message(defMessage) import Control.Monad (forever, unless, void) import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask) import Control.Monad.IO.Class (MonadIO, liftIO) import Control.Monad.Trans.Class (MonadTrans, lift) import Control.Monad.Trans.Reader (ReaderT(..), ask, asks) import Data.ByteString (ByteString) import Data.Default (Default, def) import Data.ProtoLens (showMessage) import Data.Set (Set) import Data.Text.Encoding (encodeUtf8) import Lens.Family2 (Lens', (^.), (&), (.~)) import Lens.Family2.Unchecked (lens) import Proto.Tensorflow.Core.Framework.Graph (GraphDef) import Proto.Tensorflow.Core.Framework.Graph_Fields (node) import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto) import TensorFlow.Build import TensorFlow.Nodes import TensorFlow.Output (NodeName, unNodeName) import TensorFlow.Tensor
import qualified Data.ByteString.Builder as Builder import qualified Data.Map.Strict as Map import qualified Data.Set as Set import qualified TensorFlow.Internal.FFI as FFI
-- | A logging callback: receives a message as a 'Builder.Builder' and performs
-- some 'IO' with it.  'extend' invokes it with a rendering of each graph
-- fragment before that fragment is sent to the session.
type Tracer = Builder.Builder -> IO ()
-- | Read-only state shared by every action running inside a 'SessionT'.
data SessionState = SessionState
    { rawSession :: FFI.Session
      -- ^ The underlying FFI session handle.
    , asyncCollector :: IO () -> IO ()
      -- ^ Callback supplied by 'FFI.withSession'; 'asyncProdNodes' hands it
      -- the looping action to execute.
      -- NOTE(review): presumably starts the action concurrently — confirm
      -- against "TensorFlow.Internal.FFI".
    , tracer :: Tracer
      -- ^ Logging callback taken from the session 'Options'.
    }
-- | Monad transformer for actions that talk to a TensorFlow session.
--
-- It is a 'ReaderT' carrying the 'SessionState' layered over 'BuildT', so
-- graph-building actions can be interleaved with session calls.  All listed
-- instances are obtained by newtype deriving from the wrapped stack.
newtype SessionT m a
    = Session (ReaderT SessionState (BuildT m) a)
    deriving
        ( Functor
        , Applicative
        , Monad
        , MonadIO
        , MonadThrow
        , MonadCatch
        , MonadMask
        , MonadFail
        )
-- | Lifts an action of the base monad through both wrapped layers
-- ('BuildT', then 'ReaderT') and re-wraps it in the 'Session' constructor.
instance MonadTrans SessionT where
    lift = Session . lift . lift
-- | Run a 'SessionT' action using the default 'Options' ('def').
--
-- Equivalent to @'runSessionWithOptions' 'def'@.
runSession :: (MonadMask m, MonadIO m) => SessionT m a -> m a
runSession = runSessionWithOptions def
-- | Settings applied when a session is created.  Read or modify them through
-- the lenses 'sessionTarget', 'sessionConfig' and 'sessionTracer'.
data Options = Options
    { _sessionTarget :: ByteString
      -- ^ Passed to 'FFI.setSessionTarget' when the session is opened.
    , _sessionConfig :: ConfigProto
      -- ^ Passed to 'FFI.setSessionConfig' when the session is opened.
    , _sessionTracer :: Tracer
      -- ^ Logging callback stored in the session state.
    }
-- | Default options: an empty target, the default 'ConfigProto' message, and
-- a tracer that discards every message.
instance Default Options where
    def = Options
        { _sessionTarget = ""
        , _sessionConfig = defMessage
        , _sessionTracer = const (return ())
        }
-- | Lens onto the session target stored in 'Options'; the value is handed to
-- 'FFI.setSessionTarget' by 'runSessionWithOptions'.
sessionTarget :: Lens' Options ByteString
sessionTarget = lens _sessionTarget (\g x -> g { _sessionTarget = x })
-- | Lens onto the 'ConfigProto' stored in 'Options'; the value is handed to
-- 'FFI.setSessionConfig' by 'runSessionWithOptions'.
sessionConfig :: Lens' Options ConfigProto
sessionConfig = lens _sessionConfig (\g x -> g { _sessionConfig = x })
-- | Lens onto the logging callback stored in 'Options'.  The default tracer
-- ignores its input; replace it to observe what 'extend' sends to the
-- session.
sessionTracer :: Lens' Options Tracer
sessionTracer = lens _sessionTracer (\g x -> g { _sessionTracer = x })
-- | Run a 'SessionT' action in a session configured by the given 'Options'.
--
-- The session is obtained from 'FFI.withSession'.  Its option-setup callback
-- applies the target and the config; its body callback receives the async
-- collector and the raw session, which together with the tracer form the
-- 'SessionState' the action reads from.  The build layer is discharged with
-- 'evalBuildT'.
runSessionWithOptions
    :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a
runSessionWithOptions options (Session m) =
    FFI.withSession applyOptions $ \as rs ->
        let initState = SessionState rs as (options ^. sessionTracer)
        in evalBuildT (runReaderT m initState)
  where
    -- Copy the user-visible settings onto the FFI session options.
    applyOptions opt = do
        FFI.setSessionTarget (options ^. sessionTarget) opt
        FFI.setSessionConfig (options ^. sessionConfig) opt
-- | Graph-building actions run in the underlying 'BuildT' layer and are
-- lifted through the 'ReaderT' that carries the session state.
instance Monad m => MonadBuild (SessionT m) where
    build = Session . lift . build
-- | Push everything built so far into the session.
--
-- First flushes the buffered nodes; when there are any, they are wrapped in a
-- 'GraphDef', reported to the session's tracer, and sent with
-- 'FFI.extendGraph'.  Afterwards the pending initializers are flushed and,
-- when there are any, run as targets (no feeds, no fetches) via 'FFI.run'.
extend :: MonadIO m => SessionT m ()
extend = do
    session <- Session (asks rawSession)
    trace <- Session (asks tracer)
    nodesToExtend <- build flushNodeBuffer
    unless (null nodesToExtend) $ liftIO $ do
        let graphDef = (defMessage :: GraphDef) & node .~ nodesToExtend
        trace ("Session.extend " <> Builder.string8 (showMessage graphDef))
        FFI.extendGraph session graphDef
    -- Initializers are run only after the graph extension above, so the
    -- nodes they name have already been handed to the session.
    initializers <- build flushInitializers
    unless (null initializers) $
        void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
-- | Run the given value in the session and fetch its result, supplying no
-- feeds.  Equivalent to @'runWithFeeds' []@.
run :: (MonadIO m, Fetchable t a) => t -> SessionT m a
run = runWithFeeds []
-- | Like 'run', but with the given feeds supplied to the session call.
--
-- The target node set ('getNodes') and the fetch description ('getFetch') are
-- both computed in the build monad, then handed to 'runFetchWithFeeds'.
runWithFeeds :: (MonadIO m, Fetchable t a) => [Feed] -> t -> SessionT m a
runWithFeeds feeds t = do
    ns <- build $ getNodes t
    fetch <- build $ getFetch t
    runFetchWithFeeds feeds ns fetch
-- | Shared worker behind 'runWithFeeds' and 'runWithFeeds_'.
--
-- Calls 'extend' first, then invokes 'FFI.run' with the encoded feeds, the
-- UTF-8 encoded fetch names and the encoded target node names.  The returned
-- tensors are keyed by fetch name and passed to the 'Fetch' value's restore
-- function to build the result.
runFetchWithFeeds
    :: MonadIO m => [Feed] -> Set NodeName -> Fetch a -> SessionT m a
runFetchWithFeeds feeds target (Fetch fetch restore) = do
    extend
    let feeds' = fixFeeds feeds
    let fetchNames = encodeUtf8 <$> Set.toList fetch
        targetNames = toNodeNames $ Set.toList target
    session <- Session (asks rawSession)
    runResult <- liftIO $ FFI.run session feeds' fetchNames targetNames
    -- Results are paired with fetch names positionally.
    -- NOTE(review): assumes FFI.run returns one tensor per fetch name, in
    -- the order the names were given — confirm against the FFI module.
    let resultTensorsMap = Map.fromList $ zip (Set.toList fetch) runResult
    return $ restore resultTensorsMap
-- | Convert node names to the UTF-8 encoded byte strings that 'FFI.run'
-- takes, preserving order.
toNodeNames :: [NodeName] -> [ByteString]
toNodeNames = map (encodeUtf8 . unNodeName)
-- | Run the given nodes in the session without fetching any result and
-- without feeds.  Equivalent to @'runWithFeeds_' []@.
run_ :: (MonadIO m, Nodes t) => t -> SessionT m ()
run_ = runWithFeeds_ []
-- | Like 'run_', but with the given feeds supplied to the session call.
--
-- The nodes of @t@ become the run targets; the fetch is @'pure' ()@, so
-- nothing is read back from the session.
runWithFeeds_ :: (MonadIO m, Nodes t) => [Feed] -> t -> SessionT m ()
runWithFeeds_ feeds t = do
    ns <- build $ getNodes t
    runFetchWithFeeds feeds ns (pure ())
-- | Turn each 'Feed' into the @(name, data)@ pair taken by 'FFI.run': the
-- feed's output is rendered with 'encodeOutput' and UTF-8 encoded, and the
-- tensor data is passed through unchanged.
fixFeeds :: [Feed] -> [(ByteString, FFI.TensorData)]
fixFeeds = map $ \(Feed o d) -> (encodeUtf8 $ encodeOutput o, d)
-- | Hand the session's async collector an action that runs the given nodes
-- over and over.
--
-- The nodes are resolved to target names and the graph is brought up to date
-- with 'extend'.  The looping action then calls 'FFI.run' on those targets
-- forever (no feeds, no fetches, results discarded) and is passed to the
-- 'asyncCollector' stored in the session state.
-- NOTE(review): whether this call returns immediately depends on the
-- collector supplied by 'FFI.withSession' — confirm against the FFI module.
asyncProdNodes
    :: (MonadIO m, Nodes t)
    => t  -- ^ Nodes to run repeatedly.
    -> SessionT m ()
asyncProdNodes nodes = do
    target <- build (getNodes nodes)
    extend
    let targetNames = toNodeNames $ Set.toList target
    state <- Session ask
    let loop = forever (void (FFI.run (rawSession state) [] [] targetNames))
    liftIO (asyncCollector state loop)