Visualize Activations of LSTM Network


This example shows how to investigate and visualize the features learned by LSTM networks by extracting the activations.

Load the pretrained network. JapaneseVowelsNet is a pretrained LSTM network trained on the Japanese Vowels dataset as described in [1] and [2]. It was trained on the sequences sorted by sequence length with a mini-batch size of 27.

load JapaneseVowelsNet

View the network architecture.

net.Layers

ans = 4×1 Layer array with layers:

 1   'sequenceinput'   Sequence Input    Sequence input with 12 dimensions
 2   'lstm'            LSTM              LSTM with 100 hidden units
 3   'fc'              Fully Connected   9 fully connected layer
 4   'softmax'         Softmax           softmax

Load the test data.

load JapaneseVowelsTestData

Visualize the first time series in a plot. Each line corresponds to a feature.

X = XTest{1};

figure
plot(X')                                  % one line per feature (rows of X)
xlabel("Time Step")
title("Test Observation 1")
numFeatures = size(X,1);
legend("Feature " + string(1:numFeatures),Location="northeastoutside")

[Figure: "Test Observation 1", a line plot of the 12 features (Feature 1 to Feature 12) against Time Step.]

For each time step of the sequences, get the activations output by the LSTM layer (layer 2) for that time step and update the network state.

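sequenceLength = size(X,2);
idxLayer = 2;
outputSize = net.Layers(idxLayer).NumHiddenUnits;
features = zeros(sequenceLength,outputSize);   % one row per time step

for i = 1:sequenceLength
    % Predict time step i and return the "lstm" layer output and updated state.
    [features(i,:),state] = predict(net,X(:,i)',Outputs="lstm");
    net.State = state;                          % carry state to the next step
end
features = features';                           % rows: hidden units, columns: time steps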
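The following sketch shows what each iteration of that loop computes. It is written in base MATLAB (no Deep Learning Toolbox) and steps a single LSTM cell through a sequence, carrying the hidden state h and cell state c forward. That is the same role `net.State` plays above.

The sketch assumes the standard LSTM equations with sigmoid gates and tanh state activation, which are the lstmLayer defaults. It also assumes the gate order input, forget, cell candidate, output that lstmLayer documents for its stacked weights. The weights, the 20-step sequence, and the helper names `sigmoid` and `gateRows` are hypothetical placeholders, so the values are not JapaneseVowelsNet's activations.

```matlab
% Hypothetical sketch: a hand-written LSTM cell stepped one time step at a time.
% Weights and data are random placeholders, NOT the weights of JapaneseVowelsNet.
rng(0);                                  % reproducible placeholder weights and data
numFeatures = 12;                        % input size, as in the sequence input layer
numHiddenUnits = 100;                    % hidden units, as in the LSTM layer
sequenceLength = 20;                     % assumed sequence length for this sketch
X = randn(numFeatures,sequenceLength);   % placeholder sequence (features-by-time)

% Assumed gate order for the stacked weights: input, forget, cell candidate, output.
W = 0.1*randn(4*numHiddenUnits,numFeatures);     % input weights
R = 0.1*randn(4*numHiddenUnits,numHiddenUnits);  % recurrent weights
b = zeros(4*numHiddenUnits,1);                   % bias

sigmoid = @(z) 1 ./ (1 + exp(-z));                         % gate activation
gateRows = @(k) (k-1)*numHiddenUnits + (1:numHiddenUnits); % rows of gate k

h = zeros(numHiddenUnits,1);             % hidden state (the layer's output)
c = zeros(numHiddenUnits,1);             % cell state
features = zeros(numHiddenUnits,sequenceLength);

for t = 1:sequenceLength
    z = W*X(:,t) + R*h + b;              % pre-activations for all four gates
    iGate = sigmoid(z(gateRows(1)));     % input gate
    fGate = sigmoid(z(gateRows(2)));     % forget gate
    gCand = tanh(z(gateRows(3)));        % cell candidate
    oGate = sigmoid(z(gateRows(4)));     % output gate
    c = fGate.*c + iGate.*gCand;         % update cell state using previous c
    h = oGate.*tanh(c);                  % update hidden state
    features(:,t) = h;                   % store activations for time step t
end
```

Because each h depends on the previous h and c, skipping the state update (`net.State = state` in the example) would make every column describe its time step in isolation rather than in the context of the whole sequence.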

Visualize the first 10 hidden units using a heatmap.

figure
heatmap(features(1:10,:));
xlabel("Time Step")
ylabel("Hidden Unit")
title("LSTM Activations")

[Figure: heatmap titled "LSTM Activations", with Time Step on the x-axis and Hidden Unit on the y-axis.]

The heatmap shows how strongly each hidden unit activates and highlights how the activations change over time.
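The first 10 hidden units are an arbitrary choice. A possible alternative, sketched here under the assumption that "most active" means the largest absolute activation at any time step, ranks the units by that peak and keeps the top ones.

The placeholder `features` matrix (hidden units by time steps) and the names `numToShow`, `peak`, and `topUnits` are hypothetical. In the example you would drop the two placeholder lines and use the `features` matrix computed by the loop.

```matlab
% Hypothetical sketch: pick the hidden units with the largest peak |activation|.
rng(0);
features = tanh(randn(100,20));          % placeholder activations in (-1,1); remove in the example
numToShow = 10;                          % number of units to display

peak = max(abs(features),[],2);          % strongest response of each unit over time
[~,order] = sort(peak,"descend");        % unit indices, most active first
topUnits = order(1:numToShow);           % indices of the most active units
topFeatures = features(topUnits,:);      % rows to visualize
```

You could then call `heatmap(topFeatures)` in place of `heatmap(features(1:10,:))`, keeping `topUnits` to map each row back to its hidden-unit number.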

References

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

See Also

trainnet | trainingOptions | dlnetwork | predict | forward | lstmLayer | bilstmLayer | sequenceInputLayer

Topics