Create Custom Environment from Class Template

You can define a custom reinforcement learning environment by creating and modifying a template environment class. You can use a custom template environment to implement more complex environment dynamics, add custom visualizations, or create an interface to third-party libraries.

For more information about creating MATLAB® classes, see User-Defined Classes.

You can create less complex custom reinforcement learning environments using custom functions, as described in Create Custom Environment Using Step and Reset Functions.

Create Template Class

To define your custom environment, first create the template class file, specifying the name of the class. For this example, name the class MyEnvironment.

rlCreateEnvTemplate("MyEnvironment")

The function rlCreateEnvTemplate creates and opens the template class file. The template class is a subclass of the rl.env.MATLABEnvironment abstract class, as shown in the class definition at the start of the template file. This abstract class is the same one used by the other MATLAB reinforcement learning environment objects.

classdef MyEnvironment < rl.env.MATLABEnvironment

By default, the template class implements a simple cart-pole balancing model similar to the cart-pole predefined environments described in Load Predefined Control System Environments.

To define your environment dynamics, save the file as MyEnvironment.m. Then modify the template class by specifying the environment properties, the required environment functions, and any optional environment functions, as described in the following sections.

Environment Properties

In the properties section of the template, specify any parameters necessary for creating and simulating the environment. These parameters can include physical constants, environment geometry, constraints on states and actions, and values used to compute the reward, as in the cart-pole example below.

properties
% Specify and initialize the necessary properties of the environment

% Acceleration due to gravity in m/s^2
Gravity = 9.8

% Mass of the cart
CartMass = 1.0

% Mass of the pole
PoleMass = 0.1

% Half the length of the pole
HalfPoleLength = 0.5

% Max force the input can apply
MaxForce = 10
       
% Sample time
Ts = 0.02

% Angle at which to fail the episode (radians)
AngleThreshold = 12 * pi/180
    
% Distance at which to fail the episode
DisplacementThreshold = 2.4
    
% Reward each time step the cart-pole is balanced
RewardForNotFalling = 1

% Penalty when the cart-pole fails to balance
PenaltyForFalling = -10 

end

properties
% Initialize system state [x,dx,theta,dtheta]'
State = zeros(4,1)
end

properties(Access = protected)
% Initialize internal flag to indicate episode termination
IsDone = false
end
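Because these properties are public, you can adjust them on an environment instance after you create it (instantiation is covered later in this topic). A brief sketch, assuming an instance named env; note that changing properties used to precompute other values (such as MaxForce) may require additional updates:

env = MyEnvironment;

% Adjust public properties before simulation or training
env.Gravity = 9.81;            % use a different gravity constant
env.RewardForNotFalling = 2;   % illustrative reward-shaping change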

Required Functions

A reinforcement learning environment requires the following functions to be defined. The getObservationInfo, getActionInfo, sim, and validateEnvironment functions are already defined in the base abstract class. To create your environment, you must define the constructor, reset, and step functions.

- getObservationInfo: Return information about the environment observations.
- getActionInfo: Return information about the environment actions.
- sim: Simulate the environment with an agent.
- validateEnvironment: Validate the environment by calling the reset function and simulating the environment for one time step using the step function.
- reset: Initialize the environment state and clean up any visualization.
- step: Apply an action, simulate the environment for one step, and output the observations and reward; also set a flag indicating whether the episode is complete.
- Constructor function: A function with the same name as the class that creates an instance of the class.
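The inherited functions are available as soon as you instantiate the environment (see Instantiate Custom Environment). A brief sketch, assuming an instance named env:

% Query the specification objects defined in the constructor
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);

disp(obsInfo.Dimension)   % observation dimensions, [4 1] for this template
disp(actInfo.Elements)    % the valid discrete actions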

Sample Constructor Function

The sample cart-pole constructor function creates the environment by defining the observation and action specifications and calling the constructor of the base abstract class.

function this = MyEnvironment()
% Initialize observation settings
ObservationInfo = rlNumericSpec([4 1]);
ObservationInfo.Name = 'CartPole States';
ObservationInfo.Description = 'x, dx, theta, dtheta';

% Initialize action settings   
ActionInfo = rlFiniteSetSpec([-1 1]);
ActionInfo.Name = 'CartPole Action';

% The following line implements built-in functions of the RL environment
this = this@rl.env.MATLABEnvironment(ObservationInfo,ActionInfo);

% Initialize property values and precompute necessary values
updateActionInfo(this);

end

This sample constructor function does not include any input arguments. However, you can add input arguments for your custom constructor.
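For instance, here is a minimal sketch of a constructor that accepts the cart and pole masses as input arguments; the argument names and property assignments are illustrative, not part of the shipped template:

function this = MyEnvironment(cartMass,poleMass)
% Define observation and action specifications as before
ObservationInfo = rlNumericSpec([4 1]);
ObservationInfo.Name = 'CartPole States';
ActionInfo = rlFiniteSetSpec([-1 1]);
ActionInfo.Name = 'CartPole Action';

% Call the base class constructor before assigning properties
this = this@rl.env.MATLABEnvironment(ObservationInfo,ActionInfo);

% Initialize properties from the input arguments
this.CartMass = cartMass;
this.PoleMass = poleMass;
updateActionInfo(this);
end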

Sample reset Function

The sample cart-pole reset function sets the initial condition of the model and returns the initial values of the observations. It also calls the notifyEnvUpdated function to signal that the environment has been updated, which in turn triggers the envUpdatedCallback function. This notification is useful for updating the environment visualization.

% Reset environment to initial state and return initial observation
function InitialObservation = reset(this)
% Theta (+- .05 rad)
T0 = 2 * 0.05 * rand - 0.05;
% Thetadot
Td0 = 0;
% X
X0 = 0;
% Xdot
Xd0 = 0;

InitialObservation = [X0;Xd0;T0;Td0];
this.State = InitialObservation;

% (Optional) Use notifyEnvUpdated to signal that the 
% environment is updated (for example, to update the visualization)
notifyEnvUpdated(this);

end
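You can modify reset to randomize other states as well. For example, a sketch that also randomizes the initial cart position (the randomization range is an arbitrary choice, not from the template):

function InitialObservation = reset(this)
% Pole angle (+- 0.05 rad) and cart position (+- 0.2 m)
T0 = 2 * 0.05 * rand - 0.05;
Td0 = 0;
X0 = 2 * 0.2 * rand - 0.2;
Xd0 = 0;

InitialObservation = [X0;Xd0;T0;Td0];
this.State = InitialObservation;
notifyEnvUpdated(this);
end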

Sample step Function

The sample cart-pole step function processes the input action, evaluates the dynamic equations of the environment for one time step, computes and returns the updated observations and the reward, sets the IsDone flag when the episode terminates, and generates a notification that the environment has been updated.

function [Observation,Reward,IsDone,Info] = step(this,Action)
Info = [];

% Get action
Force = getForce(this,Action);            

% Unpack state vector
XDot = this.State(2);
Theta = this.State(3);
ThetaDot = this.State(4);

% Cache to avoid recomputation
CosTheta = cos(Theta);
SinTheta = sin(Theta);            
SystemMass = this.CartMass + this.PoleMass;
temp = (Force + this.PoleMass*this.HalfPoleLength*ThetaDot^2*SinTheta)...
    /SystemMass;

% Apply motion equations            
ThetaDotDot = (this.Gravity*SinTheta - CosTheta*temp)...
    / (this.HalfPoleLength*(4.0/3.0 - this.PoleMass*CosTheta*CosTheta/SystemMass));
XDotDot  = temp - this.PoleMass*this.HalfPoleLength*ThetaDotDot*CosTheta/SystemMass;

% Euler integration
Observation = this.State + this.Ts.*[XDot;XDotDot;ThetaDot;ThetaDotDot];

% Update system states
this.State = Observation;

% Check terminal condition
X = Observation(1);
Theta = Observation(3);
IsDone = abs(X) > this.DisplacementThreshold || abs(Theta) > this.AngleThreshold;
this.IsDone = IsDone;

% Get reward
Reward = getReward(this);

% (Optional) Use notifyEnvUpdated to signal that the 
% environment has been updated (for example, to update the visualization)
notifyEnvUpdated(this);

end
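After you instantiate the environment, you can call reset and step directly to check the dynamics by hand. A brief sketch, assuming an instance named env and that the valid actions are the elements of the action specification:

obs = reset(env);
actInfo = getActionInfo(env);
act = actInfo.Elements(end);   % pick one valid action

for k = 1:100
    [obs,reward,isDone] = step(env,act);
    if isDone
        break
    end
end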

Optional Functions

You can define any other functions in your template class as required. For example, you can create helper functions that are called by either step or reset. The cart-pole template model implements a getReward function for computing the reward at each time step.

function Reward = getReward(this)
if ~this.IsDone
    Reward = this.RewardForNotFalling;
else
    Reward = this.PenaltyForFalling;
end
end
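The step function shown earlier also calls a getForce helper, which is not reproduced in this topic. A minimal sketch, assuming the discrete action simply scales the maximum input force:

function Force = getForce(this,Action)
% Map the discrete action (-1 or 1) to an applied force
Force = this.MaxForce*Action;
end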

Environment Visualization

You can add a visualization to your custom environment by implementing the plot function. In the plot function, create a figure (or an instance of your own visualizer class), store its handle in the environment object, and call the envUpdatedCallback function.

function plot(this)
% Initiate the visualization
this.Figure = figure('Visible','on','HandleVisibility','off');
ha = gca(this.Figure);
ha.XLimMode = 'manual';
ha.YLimMode = 'manual';
ha.XLim = [-3 3];
ha.YLim = [-1 2];
hold(ha,'on');
% Update the visualization
envUpdatedCallback(this)
end

For this example, store the handle to the figure as a protected property of the environment object.

properties(Access = protected)
% Initialize internal flag to indicate episode termination
IsDone = false

% Handle to figure
Figure

end

In the envUpdatedCallback function, plot the visualization to the figure or update your custom visualizer object. For example, check whether the figure handle has been set; if it has, plot the visualization.

function envUpdatedCallback(this)
if ~isempty(this.Figure) && isvalid(this.Figure)
    % Set visualization figure as the current figure
    ha = gca(this.Figure);

    % Extract the cart position and pole angle
    x = this.State(1);
    theta = this.State(3);

    cartplot = findobj(ha,'Tag','cartplot');
    poleplot = findobj(ha,'Tag','poleplot');
    if isempty(cartplot) || ~isvalid(cartplot) ...
            || isempty(poleplot) || ~isvalid(poleplot)
        % Initialize the cart plot
        cartpoly = polyshape([-0.25 -0.25 0.25 0.25],[-0.125 0.125 0.125 -0.125]);
        cartpoly = translate(cartpoly,[x 0]);
        cartplot = plot(ha,cartpoly,'FaceColor',[0.8500 0.3250 0.0980]);
        cartplot.Tag = 'cartplot';

        % Initialize the pole plot
        L = this.HalfPoleLength*2;
        polepoly = polyshape([-0.1 -0.1 0.1 0.1],[0 L L 0]);
        polepoly = translate(polepoly,[x,0]);
        polepoly = rotate(polepoly,rad2deg(theta),[x,0]);
        poleplot = plot(ha,polepoly,'FaceColor',[0 0.4470 0.7410]);
        poleplot.Tag = 'poleplot';
    else
        cartpoly = cartplot.Shape;
        polepoly = poleplot.Shape;
    end

    % Compute the new cart and pole position
    [cartposx,~] = centroid(cartpoly);
    [poleposx,poleposy] = centroid(polepoly);
    dx = x - cartposx;
    dtheta = theta - atan2(cartposx-poleposx,poleposy-0.25/2);
    cartpoly = translate(cartpoly,[dx,0]);
    polepoly = translate(polepoly,[dx,0]);
    polepoly = rotate(polepoly,rad2deg(dtheta),[x,0.25/2]);

    % Update the cart and pole positions on the plot
    cartplot.Shape = cartpoly;
    poleplot.Shape = polepoly;

    % Refresh rendering in the figure window
    drawnow();
end

end

The environment calls the envUpdatedCallback function, and therefore updates the visualization, whenever the environment is updated.
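To open the visualization for an environment instance, call plot. Subsequent calls to reset or step then refresh the figure through envUpdatedCallback. A brief usage sketch, assuming an instance named env:

plot(env)    % create the figure
reset(env);  % notifies the environment and redraws the cart and pole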

Instantiate Custom Environment

After you define your custom environment class, create an instance of it in the MATLAB workspace. At the command line, type the following.

env = MyEnvironment;

If your constructor has input arguments, specify them after the class name. For example, MyEnvironment(arg1,arg2).

After you create your environment, the best practice is to validate the environment dynamics. To do so, use the validateEnvironment function, which prints an error to the command window if your environment implementation has any issues.
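For example:

validateEnvironment(env)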

After validating the environment object, you can use it to train a reinforcement learning agent. For more information on training agents, see Train Reinforcement Learning Agents.
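For instance, a minimal training sketch, assuming a default DQN agent (one reasonable choice for the discrete action set in this example) and illustrative training options:

% Create a default agent from the environment specifications
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
agent = rlDQNAgent(obsInfo,actInfo);

% Train the agent on the custom environment
trainOpts = rlTrainingOptions('MaxEpisodes',500,'MaxStepsPerEpisode',200);
trainStats = train(agent,env,trainOpts);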
