Spiking Neural Networks in Deep Learning

Last Updated : 20 May, 2026

Spiking Neural Networks (SNNs) are brain-inspired neural networks that process information as discrete events called spikes, rather than the continuous-valued activations used by traditional neural networks.

Key Concepts in Spiking Neural Networks

Spiking Neural Networks use biologically inspired mechanisms to process and learn information through spikes and their timing.

1. Neurons and Spikes

Neurons communicate by generating spikes when their membrane potential reaches a threshold.

2. Temporal Coding

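In SNNs, information is carried by when spikes occur, not only by how many occur. The precise time of a spike, or the interval between spikes, can encode the strength of a stimulus.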
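One simple way to make spike timing carry information is latency (time-to-first-spike) coding, where a stronger input fires earlier. The sketch below is illustrative only: `latency_encode` is a hypothetical helper, and the linear mapping from intensity to spike time is an assumption rather than a scheme defined in this article.

```python
import numpy as np

def latency_encode(intensities, num_steps):
    """Map each intensity in [0, 1] to one spike; stronger inputs fire earlier."""
    x = np.clip(np.asarray(intensities, dtype=float), 0.0, 1.0)
    # Intensity 1.0 fires at step 0, intensity 0.0 fires at the last step
    spike_times = np.rint((1.0 - x) * (num_steps - 1)).astype(int)
    # Binary spike train of shape (num_steps, number of inputs)
    train = np.zeros((num_steps, x.size), dtype=int)
    train[spike_times, np.arange(x.size)] = 1
    return spike_times, train

spike_times, train = latency_encode([1.0, 0.5, 0.0], num_steps=5)
print(spike_times)  # [0 2 4]
```

Each input produces exactly one spike, so the value is recoverable from the spike time alone.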

3. Synaptic Weights and Plasticity

The strength of each connection between neurons is set by a synaptic weight. These weights change during learning, a process known as synaptic plasticity.

Working of Spiking Neural Networks

1. Membrane Potential and Firing Threshold

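Each neuron accumulates incoming spikes in its membrane potential and fires when the potential reaches the threshold. After firing, the potential is reset. In leaky models, the potential also decays over time when no input arrives.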
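The accumulate-and-fire behaviour can be sketched for a single leaky neuron. This is a minimal illustration, not the article's implementation: `simulate_lif` is a hypothetical helper, the constant input of 0.4 per step is an arbitrary choice, and the threshold, decay, and reset values are assumed to match those used later in the implementation.

```python
def simulate_lif(input_current, threshold=1.0, decay=0.9, reset=0.0):
    """Simulate one leaky integrate-and-fire neuron over a sequence of inputs."""
    v = reset
    potentials, spikes = [], []
    for current in input_current:
        v = decay * v + current      # leak, then integrate the new input
        fired = v >= threshold
        if fired:
            v = reset                # reset the potential after a spike
        potentials.append(v)
        spikes.append(int(fired))
    return potentials, spikes

potentials, spikes = simulate_lif([0.4] * 6)
print(spikes)  # [0, 0, 1, 0, 0, 1]
```

With a constant sub-threshold input, the potential needs three steps to reach the threshold, so the neuron fires on every third step.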

2. Synaptic Integration

Incoming spikes influence connected neurons through weighted synaptic connections.
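For one time step, synaptic integration amounts to a weighted sum: each post-synaptic neuron receives the sum of the weights of the pre-synaptic neurons that fired. The sketch below assumes a weight matrix of shape (pre, post); `synaptic_input` and the weight values are made up for illustration.

```python
import numpy as np

def synaptic_input(pre_spikes, weights):
    """Return the total input each post-synaptic neuron receives this step."""
    # pre_spikes: shape (n_pre,), entries 0 or 1
    # weights:    shape (n_pre, n_post)
    return np.asarray(pre_spikes) @ np.asarray(weights)

pre_spikes = np.array([1, 0, 1])      # neurons 0 and 2 fired
weights = np.array([[0.2, 0.5],
                    [0.7, 0.1],
                    [0.4, 0.3]])
post_input = synaptic_input(pre_spikes, weights)
print(post_input)  # approximately [0.6, 0.8]
```

The silent neuron (index 1) contributes nothing, regardless of its weights.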

3. Learning Rules

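SNNs learn by adjusting synaptic weights based on spike timing. A common rule is Spike-Timing-Dependent Plasticity (STDP): a synapse is strengthened when the pre-synaptic neuron fires shortly before the post-synaptic neuron and weakened when the order is reversed.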
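A widely used timing-based rule is STDP, which the stdp function in the implementation section follows. Written as an equation, with the symbols below being assumed notation (learning rate \eta and time constants \tau_{+}, \tau_{-}), the weight change for one pre/post spike pair is roughly:

```latex
\Delta t = t_{\text{post}} - t_{\text{pre}}, \qquad
\Delta w =
\begin{cases}
  +\eta \, e^{-\Delta t / \tau_{+}} & \text{if } \Delta t > 0 \\
  -\eta \, e^{\Delta t / \tau_{-}}  & \text{if } \Delta t \le 0
\end{cases}
```

The closer the two spikes are in time, the larger the change. The sign depends only on which neuron fired first.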

4. Neuron Models

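Different neuron models are used to simulate spiking behavior, trading biological detail against computational cost. Common choices include the Leaky Integrate-and-Fire (LIF), Izhikevich, and Hodgkin-Huxley models.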
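As an example of a neuron model other than LIF, the sketch below simulates the Izhikevich model with simple Euler integration. It is not used elsewhere in this article. The parameter values (a, b, c, d) are the commonly quoted "regular spiking" settings, and the step size `dt` and the input current are assumptions chosen for illustration.

```python
def simulate_izhikevich(current, steps, a=0.02, b=0.2, c=-65.0, d=8.0, dt=0.5):
    """Return the step indices at which an Izhikevich neuron spikes."""
    v = c          # membrane potential (mV)
    u = b * v      # recovery variable
    spike_steps = []
    for step in range(steps):
        # Euler update of the two coupled equations
        v += dt * (0.04 * v * v + 5.0 * v + 140.0 - u + current)
        u += dt * a * (b * v - u)
        if v >= 30.0:          # spike peak reached
            spike_steps.append(step)
            v = c              # reset the potential
            u += d             # bump the recovery variable
    return spike_steps

print(len(simulate_izhikevich(current=10.0, steps=2000)))  # number of spikes
print(simulate_izhikevich(current=0.0, steps=2000))        # []
```

Unlike LIF, this model has a second state variable `u` that slows the neuron down after each spike, which lets it reproduce a wider range of firing patterns.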

Implementation of Spiking Neural Network

In this section, we will implement a simple Spiking Neural Network (SNN) using the Leaky Integrate-and-Fire (LIF) neuron model for detecting a specific spike pattern.

Step 1: Define Neuron and Synapse Classes

```python
import numpy as np

class LIFNeuron:
    """Leaky Integrate-and-Fire neuron with a refractory period."""

    def __init__(self, threshold, reset_value, decay_factor, refractory_period):
        self.threshold = threshold
        self.reset_value = reset_value
        self.decay_factor = decay_factor
        self.refractory_period = refractory_period
        self.membrane_potential = 0
        self.spike_time = -1           # time of the last spike (-1 = never fired)
        self.refractory_end_time = -1

    def update(self, incoming_spikes, current_time):
        # Ignore all input while the neuron is refractory
        if current_time < self.refractory_end_time:
            return False

        # Leak, then integrate the incoming input
        self.membrane_potential *= self.decay_factor
        self.membrane_potential += np.sum(incoming_spikes)

        # Fire and reset once the threshold is reached
        if self.membrane_potential >= self.threshold:
            self.spike_time = current_time
            self.membrane_potential = self.reset_value
            self.refractory_end_time = current_time + self.refractory_period
            return True
        return False

class Synapse:
    def __init__(self, weight):
        self.weight = weight
```

Step 2: Define the STDP Learning Rule

The stdp function adjusts the synaptic weights based on the timing difference between the pre- and post-synaptic spikes.

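```python
def stdp(pre_spike_time, post_spike_time, weight, learning_rate, tau_positive, tau_negative):
    # A spike time of -1 means the neuron has not fired yet; time step 0 is a valid spike time
    if pre_spike_time >= 0 and post_spike_time >= 0:
        delta_t = post_spike_time - pre_spike_time
        if delta_t > 0:
            # Pre fired before post: strengthen the synapse
            return weight + learning_rate * np.exp(-delta_t / tau_positive)
        else:
            # Post fired before (or together with) pre: weaken the synapse
            return weight - learning_rate * np.exp(delta_t / tau_negative)
    return weight
```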
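To see how the size of the update depends on the timing difference, the same exponential rule can be evaluated over several values of delta_t at once. This is a sketch: `stdp_window` is a hypothetical helper that is not part of the article's code, and its defaults are assumed to mirror the learning_rate, tau_positive, and tau_negative values set in Step 3.

```python
import numpy as np

def stdp_window(delta_t, learning_rate=0.01, tau_positive=20.0, tau_negative=20.0):
    """Weight change for each timing difference delta_t = t_post - t_pre."""
    delta_t = np.asarray(delta_t, dtype=float)
    potentiation = learning_rate * np.exp(-delta_t / tau_positive)
    depression = -learning_rate * np.exp(delta_t / tau_negative)
    # Positive delta_t (pre before post) strengthens; otherwise weakens
    return np.where(delta_t > 0, potentiation, depression)

changes = stdp_window([-20, -5, 5, 20])
print(changes)
```

Spike pairs that are close in time (delta_t of 5 or -5) produce larger changes than pairs that are far apart (20 or -20).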

Step 3: Initialize Simulation Parameters and Network

```python
# Simulation length and layer sizes
time_steps = 100
input_size = 5
hidden_size = 3
output_size = 1

# One LIF neuron per unit in each layer
input_neurons = [LIFNeuron(threshold=1.0, reset_value=0.0, decay_factor=0.9, refractory_period=2) for _ in range(input_size)]
hidden_neurons = [LIFNeuron(threshold=1.0, reset_value=0.0, decay_factor=0.9, refractory_period=2) for _ in range(hidden_size)]
output_neurons = [LIFNeuron(threshold=1.0, reset_value=0.0, decay_factor=0.9, refractory_period=2) for _ in range(output_size)]

# Random initial weights in [0, 1)
input_to_hidden_synapses = np.random.rand(input_size, hidden_size)
hidden_to_output_synapses = np.random.rand(hidden_size, output_size)

# STDP parameters
learning_rate = 0.01
tau_positive = 20
tau_negative = 20
```

Step 4: Define the Spike Train Pattern to Detect

Set the pattern of spikes that the network should detect.

```python
pattern = [1, 0, 1, 0, 1]
```

Step 5: Simulation Loop

```python
for t in range(time_steps):
    # Random binary input: 1 = spike, 0 = no spike, one value per input neuron
    input_spikes = np.random.randint(0, 2, size=input_size)

    # Input layer: each neuron is driven by its own input spike and, when it
    # fires, sends weighted input to every hidden neuron
    hidden_currents = np.zeros(hidden_size)
    for i, neuron in enumerate(input_neurons):
        if neuron.update(input_spikes[i], t):
            hidden_currents += input_to_hidden_synapses[i]

    # Hidden layer: integrate the weighted input and forward to the output layer
    output_currents = np.zeros(output_size)
    for j, neuron in enumerate(hidden_neurons):
        if neuron.update(hidden_currents[j], t):
            output_currents += hidden_to_output_synapses[j]

    # Output layer
    for k, neuron in enumerate(output_neurons):
        neuron.update(output_currents[k], t)

    # Update both weight matrices with STDP, using each neuron's last spike time
    for i in range(input_size):
        for j in range(hidden_size):
            input_to_hidden_synapses[i, j] = stdp(input_neurons[i].spike_time, hidden_neurons[j].spike_time, input_to_hidden_synapses[i, j], learning_rate, tau_positive, tau_negative)
    for j in range(hidden_size):
        for k in range(output_size):
            hidden_to_output_synapses[j, k] = stdp(hidden_neurons[j].spike_time, output_neurons[k].spike_time, hidden_to_output_synapses[j, k], learning_rate, tau_positive, tau_negative)

    # Report a match when every input neuron marked 1 in the pattern fired at this step
    if all(neuron.spike_time == t for neuron, pat in zip(input_neurons, pattern) if pat == 1):
        print(f"Pattern detected at time step {t}")
```

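Output:

The script prints a line of the form "Pattern detected at time step t" for every time step t at which all input neurons marked 1 in the pattern spike. Because the input spikes are random, the reported time steps differ from run to run.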


Applications

Challenges