Long Short Term Memory (LSTM) Networks using PyTorch

Last Updated : 9 Oct, 2025

Long Short-Term Memory (LSTM) networks are a special type of Recurrent Neural Network (RNN) designed to address the vanishing gradient problem, which makes it difficult for traditional RNNs to learn long-term dependencies in sequential data.

Figure: LSTM Networks using PyTorch

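LSTMs use memory cells controlled by three gates:

1. Forget Gate: decides which information to discard from the cell state.

2. Input Gate: decides which new information to write into the cell state.

3. Output Gate: decides which part of the cell state is exposed as the hidden state.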
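To make the gates concrete, here is a minimal sketch that recomputes one time step of PyTorch's nn.LSTMCell by hand. It assumes PyTorch's documented weight layout, in which the four gate blocks are stacked in the order input, forget, cell candidate, output. The helper name lstm_cell_step is hypothetical and only used for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical helper for illustration: one LSTM step computed from an nn.LSTMCell's weights
def lstm_cell_step(x, h, c, cell):
    # Assumed gate order in PyTorch's stacked weights: input, forget, candidate (g), output
    gates = x @ cell.weight_ih.T + cell.bias_ih + h @ cell.weight_hh.T + cell.bias_hh
    i, f, g, o = gates.chunk(4, dim=1)
    i = torch.sigmoid(i)  # input gate: how much new information to write
    f = torch.sigmoid(f)  # forget gate: how much of the old cell state to keep
    g = torch.tanh(g)     # candidate values that could be written to the cell state
    o = torch.sigmoid(o)  # output gate: how much of the cell state to expose
    c_next = f * c + i * g           # update the memory cell
    h_next = o * torch.tanh(c_next)  # new hidden state
    return h_next, c_next

cell = nn.LSTMCell(input_size=1, hidden_size=4)
x = torch.randn(2, 1)                        # batch of 2 samples, 1 feature each
h, c = torch.zeros(2, 4), torch.zeros(2, 4)  # zero initial hidden and cell states

with torch.no_grad():
    h_manual, c_manual = lstm_cell_step(x, h, c, cell)
    h_ref, c_ref = cell(x, (h, c))

print(torch.allclose(h_manual, h_ref, atol=1e-6), torch.allclose(c_manual, c_ref, atol=1e-6))
```

If the gate-order assumption holds, both comparisons print True, which shows that the cell update c_t = f * c_(t-1) + i * g and the hidden update h_t = o * tanh(c_t) are what the built-in cell computes at each step.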

This structure allows LSTMs to remember useful information for long periods while ignoring irrelevant details. In this article, we will learn how to implement an LSTM in PyTorch for sequence prediction on synthetic sine wave data.

Long Short-Term Memory (LSTM) Networks using PyTorch

LSTMs are widely used for sequence modeling tasks because of their ability to capture long-term dependencies. PyTorch provides a clean and flexible API to build and train LSTM models. In PyTorch, the nn.LSTM module handles the recurrence logic, while the rest of the architecture, such as fully connected output layers or dropout, can be customized as needed.

Key Components

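1. Input Size (input_size): Number of features in the input sequence at each time step.

2. Hidden Size (hidden_size): Number of features in the hidden state.

3. Number of Layers (num_layers): Number of stacked LSTM layers; stacking layers makes the model deeper.

4. Batch First (batch_first): If set to True, input and output tensors are shaped (batch, seq_len, features) instead of (seq_len, batch, features).

5. Outputs: nn.LSTM returns output, which holds the last layer's hidden state at every time step, and a tuple (h_n, c_n) with the final hidden state and cell state of every layer.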
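A quick way to see these components together is to run a small stacked nn.LSTM on random data and inspect the returned shapes. The sizes below (batch of 4, sequence length 10, hidden size 16, two layers) are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, features, hidden = 4, 10, 1, 16

lstm = nn.LSTM(input_size=features, hidden_size=hidden, num_layers=2, batch_first=True)
x = torch.randn(batch, seq_len, features)  # (batch, seq_len, features) because batch_first=True

output, (h_n, c_n) = lstm(x)
print(output.shape)  # torch.Size([4, 10, 16]): last layer's hidden state at every time step
print(h_n.shape)     # torch.Size([2, 4, 16]): final hidden state of each layer
print(c_n.shape)     # torch.Size([2, 4, 16]): final cell state of each layer

# For a unidirectional LSTM, the last time step of `output` is the top layer's final hidden state
print(torch.allclose(output[:, -1, :], h_n[-1]))  # True
```

Note that h_n and c_n are shaped (num_layers, batch, hidden_size) even when batch_first=True; the flag only changes the layout of the input and output tensors. The last check is why taking out[:, -1, :], as the model in Step 2 does, is equivalent to using the top layer's final hidden state for a unidirectional LSTM.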

Implementation

Let's implement an LSTM network using PyTorch.

Step 1: Import Libraries and Prepare Data

We first import the necessary libraries (torch, numpy and matplotlib) and create a sine wave dataset. The data is split into input sequences of length 10, and the model learns to predict the value that follows each sequence.

```python
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(0)
torch.manual_seed(0)

# 1,000 evenly spaced points of a sine wave
t = np.linspace(0, 100, 1000)
data = np.sin(t)

def create_sequences(data, seq_length):
    # Each input is seq_length consecutive values; the target is the value that follows
    xs, ys = [], []
    for i in range(len(data) - seq_length):
        x = data[i:(i + seq_length)]
        y = data[i + seq_length]
        xs.append(x)
        ys.append(y)
    return np.array(xs), np.array(ys)

seq_length = 10
X, y = create_sequences(data, seq_length)

# Add a feature dimension: trainX is (samples, seq_length, 1), trainY is (samples, 1)
trainX = torch.tensor(X[:, :, None], dtype=torch.float32)
trainY = torch.tensor(y[:, None], dtype=torch.float32)
```
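The loop in create_sequences builds a sliding window over the series. As an alternative (not the approach used in this article), the same windows can be produced without a Python loop using NumPy's sliding_window_view, available in NumPy 1.20 and later. The function name create_sequences_vectorized is hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# Hypothetical vectorized variant of create_sequences
def create_sequences_vectorized(data, seq_length):
    # Every window of seq_length consecutive values; the last window has no next value, so drop it
    xs = sliding_window_view(data, seq_length)[:-1]
    # The target for window i is the value immediately after it
    ys = data[seq_length:]
    return xs, ys

X_toy, y_toy = create_sequences_vectorized(np.arange(6), 3)
print(X_toy.tolist())  # [[0, 1, 2], [1, 2, 3], [2, 3, 4]]
print(y_toy.tolist())  # [3, 4, 5]

t = np.linspace(0, 100, 1000)
X, y = create_sequences_vectorized(np.sin(t), 10)
print(X.shape, y.shape)  # (990, 10) (990,)
```

For the 1,000-point sine wave this yields 990 windows, so after adding the feature dimension trainX has shape (990, 10, 1) and trainY has shape (990, 1). sliding_window_view returns a read-only view, so call .copy() before modifying the windows in place.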

Step 2: Define the LSTM Model

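We define the model as a subclass of PyTorch's nn.Module. It passes the input through nn.LSTM and feeds the output at the last time step into a fully connected layer that produces the prediction. The forward method also returns the final hidden and cell states so they can be passed back in on the next call.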

```python
class LSTMModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(LSTMModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, h0=None, c0=None):
        # Start from zero hidden and cell states if none are provided
        if h0 is None or c0 is None:
            h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)
            c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)

        out, (hn, cn) = self.lstm(x, (h0, c0))
        out = self.fc(out[:, -1, :])  # Take last time step
        return out, hn, cn
```
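The zero-initialization branch in forward matches what nn.LSTM already does on its own: when no (h_0, c_0) tuple is passed, PyTorch starts from zero states. The sketch below checks this on a small single-layer LSTM with arbitrary sizes, so the explicit branch mainly serves to make the initial state visible in the code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer_dim, batch, hidden_dim = 1, 3, 8
lstm = nn.LSTM(input_size=1, hidden_size=hidden_dim, num_layers=layer_dim, batch_first=True)
x = torch.randn(batch, 10, 1)

# Explicit zero states, shaped (layer_dim, batch, hidden_dim) as in LSTMModel.forward
h0 = torch.zeros(layer_dim, batch, hidden_dim)
c0 = torch.zeros(layer_dim, batch, hidden_dim)

with torch.no_grad():
    out_explicit, _ = lstm(x, (h0, c0))
    out_default, _ = lstm(x)  # no states passed: nn.LSTM starts from zeros

print(torch.allclose(out_explicit, out_default))
```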

Step 3: Initialize Model, Loss Function, and Optimizer

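We create the model with one input feature, 100 hidden units, a single LSTM layer and one output value, use mean squared error as the loss function and optimize with Adam at a learning rate of 0.01.

```python
model = LSTMModel(input_dim=1, hidden_dim=100, layer_dim=1, output_dim=1)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
```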
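As a worked example, the number of trainable parameters for this configuration can be counted by hand. Each of the four gates has an input weight matrix (100 x 1), a recurrent weight matrix (100 x 100) and two bias vectors of length 100 (PyTorch keeps separate bias_ih and bias_hh), and the output layer adds 100 weights plus one bias. The sketch builds the two layers directly rather than through LSTMModel, which contains exactly these two submodules.

```python
import torch.nn as nn

lstm = nn.LSTM(input_size=1, hidden_size=100, num_layers=1, batch_first=True)
fc = nn.Linear(100, 1)

# Per gate: W_i (100x1) + W_h (100x100) + b_ih (100) + b_hh (100); 4 gates in total
expected_lstm = 4 * (100 * 1 + 100 * 100 + 100 + 100)  # 41200
expected_fc = 100 * 1 + 1                               # 101

n_lstm = sum(p.numel() for p in lstm.parameters())
n_fc = sum(p.numel() for p in fc.parameters())
print(n_lstm, n_fc, n_lstm + n_fc)  # 41200 101 41301
```

Calling sum(p.numel() for p in model.parameters()) on the model from this step should give the same total of 41,301.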

Step 4: Train the LSTM Model

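We train the model for 100 epochs, passing the whole dataset through the network as a single batch. The hidden and cell states returned by each forward pass are detached and fed into the next epoch, so gradients do not flow back into earlier epochs.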

```python
num_epochs = 100
h0, c0 = None, None

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    outputs, h0, c0 = model(trainX, h0, c0)

    loss = criterion(outputs, trainY)
    loss.backward()
    optimizer.step()

    # Detach the states so the next epoch does not backpropagate through this one
    h0, c0 = h0.detach(), c0.detach()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
```

Output:

Figure: Training
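The detach() call in the training loop matters because the hidden states are carried from one epoch into the next. The minimal sketch below uses a small standalone nn.LSTM with arbitrary sizes and a hypothetical helper two_steps that runs two forward/backward passes, each reusing the previous state. With detached states both passes succeed; without detaching, the second backward pass tries to go through the first pass's graph, which PyTorch frees after backward(), and a RuntimeError is raised.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=1, hidden_size=4, batch_first=True)
x = torch.randn(2, 5, 1)

# Hypothetical helper: two training-style passes that carry the state forward
def two_steps(detach_states):
    state = None
    for _ in range(2):
        out, (h, c) = lstm(x, state)
        out.sum().backward()  # frees this pass's graph by default
        state = (h.detach(), c.detach()) if detach_states else (h, c)
    return "ok"

print(two_steps(detach_states=True))  # ok

try:
    two_steps(detach_states=False)
except RuntimeError as err:
    print("RuntimeError:", str(err).splitlines()[0])
```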

Step 5: Evaluate and Plot Predictions

We switch the model to evaluation mode with model.eval(), predict the next value for every input sequence and plot the predictions against the original sine wave.

```python
model.eval()
with torch.no_grad():  # gradients are not needed for inference
    predicted, _, _ = model(trainX, h0, c0)

original = data[seq_length:]
time_steps = np.arange(seq_length, len(data))

plt.figure(figsize=(12, 6))
plt.plot(time_steps, original, label='Original Data')
plt.plot(time_steps, predicted.numpy(), label='Predicted Data', linestyle='--')
plt.title('LSTM Model Predictions vs. Original Data')
plt.xlabel('Time Step')
plt.ylabel('Value')
plt.legend()
plt.show()
```

Output:

Figure: LSTM Model Predictions vs. Original Data
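The plot shows the fit visually; a numeric error is often easier to compare across runs. The sketch below uses a hypothetical helper regression_errors with small toy arrays and highlights one pitfall: the model returns predictions of shape (N, 1) while data[seq_length:] has shape (N,), so subtracting them directly broadcasts to an (N, N) matrix instead of an element-wise difference.

```python
import numpy as np

# Hypothetical helper: MSE and MAE between a 1-D target array and (N, 1) model outputs
def regression_errors(original, predicted):
    # Flatten first: (N,) minus (N, 1) would broadcast to (N, N)
    predicted = np.asarray(predicted).reshape(-1)
    diff = np.asarray(original) - predicted
    return float(np.mean(diff ** 2)), float(np.mean(np.abs(diff)))

original = np.array([0.0, 0.5, 1.0])
predicted = np.array([[0.0], [0.5], [0.8]])

print((original - predicted).shape)  # (3, 3): the broadcasting pitfall
mse, mae = regression_errors(original, predicted)
print(f"MSE: {mse:.4f}, MAE: {mae:.4f}")
```

With the article's variables this could be called as regression_errors(original, predicted.numpy()). Keep in mind that Step 5 evaluates on the same sequences used for training, so these numbers measure training fit rather than generalization.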

Applications

Advantages

Limitations