Import Neural Network Models Using ONNX

To create function approximators for reinforcement learning, you can import pretrained deep neural networks or deep neural network layer architectures using the Deep Learning Toolbox™ network import functionality. Specifically, you can import networks saved in the Open Neural Network Exchange (ONNX™) model format by using the importNetworkFromONNX function, which requires the Deep Learning Toolbox Converter for ONNX Model Format support package.

After you import a deep neural network, you can create an actor or critic object, such as rlValueFunction or rlDiscreteCategoricalActor.
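For example, for an environment with a discrete action space, a minimal sketch of this workflow might look as follows. The file name is hypothetical, and obsInfo and actInfo are assumed to be the observation and action specifications of your environment (for a discrete categorical actor, actInfo must be an rlFiniteSetSpec object).

actorNet = importNetworkFromONNX("discreteActorNetwork.onnx");   % hypothetical ONNX file
actor = rlDiscreteCategoricalActor(actorNet,obsInfo,actInfo);    % stochastic actor for a discrete action space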

When you import deep neural network architectures, make sure that the inputs and outputs of the imported network are compatible with the observation and action channels of your environment.

For more information on the deep neural network architectures supported for reinforcement learning, see Create Policies and Value Functions.

Import Actor and Critic for Image Observation Application

As an example, assume that you have an environment with a 50-by-50 grayscale image observation signal and a continuous action space. To train a policy gradient (PG) agent, you require two function approximators: an actor, which selects an action based on the current observation, and a critic, which estimates the expected discounted cumulative long-term reward based on the current observation. Both approximators must have a single 50-by-50 image input observation layer and a single scalar output value.

Also, assume that you have the following network architectures to import: a critic network with a 50-by-50 image input layer and a scalar output, saved in the ONNX format as criticNetwork.onnx, and an actor network with the same input and output sizes, saved as actorNetwork.onnx.

To import the critic and actor networks, use the importNetworkFromONNX function.

criticNetwork = importNetworkFromONNX("criticNetwork.onnx");
actorNetwork = importNetworkFromONNX("actorNetwork.onnx");
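The importNetworkFromONNX function returns a dlnetwork object, so, after importing, you can inspect the networks before wrapping them in actor and critic objects, for example as in the following sketch.

summary(criticNetwork)         % display layer and learnable parameter summary
analyzeNetwork(actorNetwork)   % open an interactive network analysis report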

After you import the networks, if you already have an appropriate agent for your environment, you can update it in place. Use getActor and getCritic to extract the actor and critic function approximators from the agent, use setModel to set the imported networks as the approximation models of the actor and critic, and then use setActor and setCritic to put the updated actor and critic back into your agent.
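For example, the following sketch outlines this sequence, assuming agent is an existing agent whose actor and critic have the same input and output sizes as the imported networks.

critic = getCritic(agent);                % extract the current critic
critic = setModel(critic,criticNetwork);  % use the imported network as the critic model
agent = setCritic(agent,critic);          % put the updated critic back into the agent

actor = getActor(agent);                  % extract the current actor
actor = setModel(actor,actorNetwork);     % use the imported network as the actor model
agent = setActor(agent,actor);            % put the updated actor back into the agent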

Alternatively, create new actor and critic function approximators that use the imported networks. To do so, first obtain the observation and action specifications from the environment.

obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
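If the environment is not available in your MATLAB session, you can instead construct specifications that match this example directly, as in the following sketch (the dimensions reflect the 50-by-50 image observation and scalar action assumed above).

obsInfo = rlNumericSpec([50 50 1]);   % 50-by-50 grayscale image observation
actInfo = rlNumericSpec([1 1]);       % scalar continuous action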

Create the critic. PG agents use an rlValueFunction approximator.

critic = rlValueFunction(criticNetwork,obsInfo);

If your critic has more than one input channel (for example, because your environment has more than one output channel, or because you are using a Q-value function critic, which also needs an action input), it is good practice to specify the names of the input layers that need to be connected, in sequential order, with each critic input channel. For an example, see Train DDPG Agent to Swing Up and Balance Pendulum with Image Observation.
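For instance, for a Q-value critic with one observation channel and one action channel, a sketch might look as follows. The layer names "obsInputLayer" and "actInputLayer" are hypothetical and must match the input layer names of your imported network.

critic = rlQValueFunction(criticNetwork,obsInfo,actInfo, ...
    ObservationInputNames="obsInputLayer", ...
    ActionInputNames="actInputLayer");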

Create the actor. PG agents use an rlContinuousDeterministicActor approximator.

actor = rlContinuousDeterministicActor(actorNetwork,obsInfo,actInfo);

As with the critic, if your actor has more than one input channel (because your environment has more than one output channel), it is good practice to specify the name of the input layer that needs to be connected with each actor input channel.
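For the single observation channel in this example, a sketch might look as follows (the layer name "obsInputLayer" is hypothetical).

actor = rlContinuousDeterministicActor(actorNetwork,obsInfo,actInfo, ...
    ObservationInputNames="obsInputLayer");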

To verify that your actor and critic work properly, use getAction and getValue to return an action (for the actor) and the value (for the critic) corresponding to a random observation, using the current network weights.
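For example, the following sketch evaluates both approximators on a random observation with the dimensions given by obsInfo.

obs = rand(obsInfo.Dimension);   % random 50-by-50 observation
act = getAction(actor,{obs});    % cell array containing the action from the actor
val = getValue(critic,{obs});    % estimated state value from the critic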

After you have the actor and critic that use the imported networks, you can create an agent from them and then train the agent against your environment.
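For example, once you create an agent from the actor and critic, a minimal training sketch might look as follows (the option values are placeholders).

trainOpts = rlTrainingOptions( ...
    MaxEpisodes=1000, ...
    MaxStepsPerEpisode=500, ...
    StopTrainingCriteria="AverageReward", ...
    StopTrainingValue=500);
trainResults = train(agent,env,trainOpts);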
