How do I implement neuronal coupling with the Brian2 simulator?

23 views

I have been trying to simulate a network of adaptive exponential integrate-and-fire (AdEx) neurons using the Brian2 simulator. The system is described by the following differential equations:

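(The original equation image is not preserved; the system below is reconstructed from the code further down. $A$ is the adjacency matrix and $g_s$ the coupling strength.)

$$\frac{dV_i}{dt} = \frac{1}{C}\left[g_L (E_L - V_i) + g_L \Delta_T \exp\!\left(\frac{V_i - V_T}{\Delta_T}\right) - w_i + I\right] + g_s\,(E_{\mathrm{rev}} - V_i) \sum_{j} A_{ij}\, g_j$$

$$\tau_w \frac{dw_i}{dt} = a\,(V_i - E_L) - w_i, \qquad \tau_g \frac{dg_i}{dt} = -g_i$$

When neuron $i$ spikes: $V_i \to V_{\mathrm{reset}}$, $w_i \to w_i + b$, $g_i \to g_i + g_{\mathrm{exc}}$.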

However, I am having difficulty implementing the coupling term in Brian2. I have read that this is possible with `brian2.Synapses`, but I have not managed to get it working. I also cannot find a way to connect the neurons according to a given adjacency matrix.

I have written the same simulation without Brian2:

# ============================ Necessary packages ====================================

import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from collections import namedtuple
# ============================ Parameters ============================================

# AdEx neuron + synapse parameters, bundled in an immutable namedtuple.
# NOTE(review): units are not stated anywhere in this script; the magnitudes look
# like pF / nS / mV / ms / pA — confirm before mixing with other parameter sets.
AdexParams = namedtuple(
    "AdexParameter",
    "c g_l e_l v_t v_peak delta_t tau_w a b v_reset current v_rev tau_g g_exc",
)

params = AdexParams(
    # Membrane
    c=200.0,        # capacitance (divides the intrinsic dv/dt terms)
    g_l=12.0,       # leak conductance
    e_l=-70,        # leak reversal potential
    v_t=-50.0,      # exponential threshold
    delta_t=2,      # slope factor of the exponential term
    v_peak=20.0,    # presumably the AdEx spike cutoff — verify against the integrator
    v_reset=-70.0,  # value v is reset to after a spike
    current=509.7,  # constant input current added to every neuron
    # Adaptation
    tau_w=300.0,    # time constant of the adaptation variable w
    a=2.0,          # subthreshold adaptation coupling
    b=50,           # increment of w at each spike
    # Synapse
    v_rev=0.0,      # synaptic reversal potential used by the coupling term
    tau_g=2.728,    # decay time constant of the synaptic variable g
    g_exc=0.05,     # increment of g at each spike
)

# Network parameters
num_neurons = 50

# Simulation parameters (t_sim and dt share one time unit)
t_sim = 1000
dt = 0.05
t = np.arange(0, t_sim, dt)
# Derive the step count from the time grid itself. int(t_sim / dt) truncates and
# can disagree with len(t) for some dt (e.g. dt=0.3 gives 3333 vs 3334), which
# would make state_final and t different lengths when they are plotted together.
iterations = len(t)

# Generating network
def creating_network(num_neurons, degree, seed=None):
    """Return the adjacency matrix of an Erdős–Rényi G(n, p) random graph.

    The edge probability is p = degree / (num_neurons - 1), so the expected
    number of neighbours of each node is ``degree``.

    Args:
        num_neurons: number of nodes in the graph.
        degree: target mean degree.
        seed: optional RNG seed forwarded to ``nx.erdos_renyi_graph`` for a
            reproducible network; ``None`` (default) keeps the unseeded behaviour.

    Returns:
        Dense ``num_neurons x num_neurons`` array from ``nx.to_numpy_array``.
    """
    # A single-node graph has no possible edges; avoid dividing by zero.
    p = degree / (num_neurons - 1) if num_neurons > 1 else 0.0
    return nx.to_numpy_array(nx.erdos_renyi_graph(num_neurons, p, seed=seed))

# differential equation
def adex_neuron(state, adj_matrix, g_array, g_s, neuron_index):
    """Return the time derivative of one AdEx neuron's state, including coupling.

    Args:
        state: length-3 array ``(v, w, g)`` of neuron ``neuron_index``.
        adj_matrix: adjacency matrix; row ``neuron_index`` weights the inputs
            this neuron receives.
        g_array: synaptic variable ``g`` of every neuron, indexed like the
            columns of ``adj_matrix``.
        g_s: coupling strength.
        neuron_index: index of this neuron (row of ``adj_matrix``).

    Returns:
        Array shaped like ``state`` holding ``(dv/dt, dw/dt, dg/dt)``.
    """
    # Evaluate the dynamics with v capped at v_peak. RK4's intermediate stages
    # can push v far beyond v_peak during the exponential upswing, where
    # exp((v - v_t) / delta_t) overflows to inf and poisons w with inf/nan.
    # Below v_peak the cap changes nothing (NEST's aeif models do the same).
    v = min(state[0], params.v_peak)
    w = state[1]
    g = state[2]

    dv_dt = (params.g_l * (params.e_l - v) + params.g_l * params.delta_t * np.exp((v - params.v_t) / params.delta_t) - w + params.current) / params.c
    dw_dt = (params.a * (v - params.e_l) - w) / params.tau_w
    dg_dt = -g / params.tau_g

    # sum_j A[i, j] * g_j as a single dot product; works for any network size.
    # NOTE(review): unlike the intrinsic terms above, the coupling is not divided
    # by C — confirm this matches the intended equation.
    coupling_term = g_s * (params.v_rev - v) * np.dot(adj_matrix[neuron_index], g_array)

    adex_output = np.zeros_like(state)
    adex_output[0] = dv_dt + coupling_term
    adex_output[1] = dw_dt
    adex_output[2] = dg_dt

    return adex_output


# Integration algorithm
def rk4(initial_state, adj_matrix, g_s, *, verbose=False):
    """Integrate the coupled AdEx network with fixed-step RK4 and spike resets.

    Each neuron is advanced one step of size ``dt`` with classic RK4. During
    that step, all neurons' synaptic variables ``g`` are held at their
    previous-step values.

    When a neuron's membrane potential reaches ``params.v_peak`` after a step:
    its spike time is recorded, ``v`` is reset to ``params.v_reset``, ``w`` is
    incremented by ``params.b``, and ``g`` is incremented by ``params.g_exc``.

    Args:
        initial_state: initial ``(v, w, g)`` per neuron; anything assignable to
            an ``(n_neurons, 3)`` array.
        adj_matrix: square adjacency matrix; its length sets the neuron count.
        g_s: coupling strength passed to ``adex_neuron``.
        verbose: if True, print per-neuron, per-step debug output. Off by
            default because a full run would print millions of lines.

    Returns:
        ``(spike, state_final)``, where:
        - ``spike[i]`` lists the spike times (from the module-level ``t``) of
          neuron ``i``;
        - ``state_final`` has shape ``(iterations, n_neurons, 3)``.
    """
    n_neurons = len(adj_matrix)
    spike = [[] for _ in range(n_neurons)]
    state_final = np.zeros((iterations, n_neurons, 3))
    state_final[0, :, :] = initial_state

    for step in range(1, iterations):
        prev = state_final[step - 1]
        # Coupling inputs come from the previous step; only row `step` is
        # written below, so this view stays valid for the whole inner loop.
        g_array = prev[:, 2]
        for neuron in range(n_neurons):
            y = prev[neuron]
            k1 = adex_neuron(y, adj_matrix, g_array, g_s, neuron)
            k2 = adex_neuron(y + 0.5 * k1 * dt, adj_matrix, g_array, g_s, neuron)
            k3 = adex_neuron(y + 0.5 * k2 * dt, adj_matrix, g_array, g_s, neuron)
            k4 = adex_neuron(y + k3 * dt, adj_matrix, g_array, g_s, neuron)

            increment = k1 + 2 * (k2 + k3) + k4
            state_final[step, neuron, :] = y + increment * (dt / 6)

            if verbose:
                print(f"previous potential: {y[0]}")
                print(f"increment: {increment}")

            # Spike detection uses the AdEx cutoff v_peak, not the exponential
            # threshold v_t (which would cut off the spike upswing).
            if state_final[step, neuron, 0] >= params.v_peak:
                state_final[step, neuron, 0] = params.v_reset
                state_final[step, neuron, 1] += params.b
                state_final[step, neuron, 2] += params.g_exc
                spike[neuron].append(t[step])
                if verbose:
                    print("Spiked.....")

    return spike, state_final
    
    
# Initial conditions: every neuron starts with v = e_l and w = g = 0.
state_ini_network = np.tile(np.array((params.e_l, 0, 0), dtype=float), (num_neurons, 1))

# Running simulation on a random network with mean degree 25 and coupling 0.2.
adj_matrix = creating_network(num_neurons, degree=25)
spike, state_final = rk4(state_ini_network, adj_matrix, 0.2)

# Top panel: membrane-potential traces of all neurons.
plt.subplot(211)
plt.plot(t, state_final[:, :, 0])
plt.xlabel("Time")
plt.ylabel("V")

# Bottom panel: spike raster, one row of dots per neuron.
plt.subplot(212)
for neuron_idx, spike_times in enumerate(spike):
    plt.plot(spike_times, np.full(len(spike_times), neuron_idx), ".")
plt.xlabel("Spike times")
plt.ylabel("Neuron index")

plt.show()

However, the results I get are not accurate, so I am trying to reproduce the simulation with Brian2. Any insight on how to use `brian2.Synapses` in this particular case would be a huge help.

0 votes — there are no answers yet.