How can I reduce the many false positives from a KNN classifier for ECG R-peak detection?

I'm running code in Python 3.11.7 (Jupyter notebook) that replicates the algorithm of this research paper. The researchers propose a K-nearest-neighbour classifier (K = 3, Euclidean distance metric) for detecting QRS complexes in an ECG signal after preprocessing: a Pan-Tompkins 5-12 Hz bandpass filter, gradient calculation and gradient curve extraction. However, when I apply it, the classifier outputs a large number of false positives, even though I trained it on record 100 of the MIT-BIH dataset and tested it on that same record. Testing on other records also gave results that differ from those reported in the paper. All records in the MIT-BIH dataset are sampled at 360 samples/s.
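To sanity-check the preprocessing, the sketch below rebuilds the two Butterworth stages used in `bandpass_filter()` (3rd order, 12 Hz low-pass then 5 Hz high-pass, fs = 360 Hz) and evaluates their combined magnitude response with `scipy.signal.freqz`. It assumes that `filtfilt` squares each filter's magnitude response (one forward and one backward pass) and ignores the intermediate amplitude normalization, which only rescales the signal. If the design matches the intended band, the peak gain should land between 5 and 12 Hz.

```python
import numpy as np
import scipy.signal as sg

FS = 360  # MIT-BIH sampling rate (samples/s)

# Same designs as in bandpass_filter(): 3rd-order Butterworth low-pass at 12 Hz
# and 3rd-order Butterworth high-pass at 5 Hz (cutoffs normalized to fs/2).
b_lp, a_lp = sg.butter(3, 12 * 2 / FS, btype="lowpass")
b_hp, a_hp = sg.butter(3, 5 * 2 / FS, btype="highpass")

freqs = np.linspace(0.1, 40, 2000)  # evaluation frequencies in Hz
_, h_lp = sg.freqz(b_lp, a_lp, worN=freqs, fs=FS)
_, h_hp = sg.freqz(b_hp, a_hp, worN=freqs, fs=FS)

# filtfilt applies each filter forward and backward, so the effective
# magnitude of each stage is |H|^2; the cascade multiplies the two.
mag = (np.abs(h_lp) * np.abs(h_hp)) ** 2

f_peak = freqs[np.argmax(mag)]
print(f"peak gain {mag.max():.3f} at {f_peak:.1f} Hz")
```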

Running the code below produces the following results:

2273 reference annotations, 2824 test annotations

True Positives (matched samples): 2273
False Positives (unmatched test samples): 551
False Negatives (unmatched reference samples): 0
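For reference, these counts translate into the usual beat-detection metrics, sensitivity Se = TP / (TP + FN) and positive predictivity +P = TP / (TP + FP). The small calculation below only plugs in the numbers printed above:

```python
# Counts printed by compare_annotations() above (record 100).
tp, fp, fn = 2273, 551, 0

sensitivity = tp / (tp + fn)            # Se: share of reference beats that were detected
positive_predictivity = tp / (tp + fp)  # +P: share of detections that match a reference beat

print(f"Se = {sensitivity:.2%}, +P = {positive_predictivity:.2%}")
# Se = 100.00%, +P = 80.49%
```

So every reference beat is found, but roughly one detection in five (551 of 2824) is spurious.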

Python code: Training stage

import wfdb
from wfdb import processing
import matplotlib.pyplot as plt
import numpy as np
import scipy.signal as sg
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_predict

def calculate_gradient(signal):
    """
    Calculate the gradient of a signal using central finite differences.

    Parameters:
        signal (numpy.ndarray): Input signal.

    Returns:
        numpy.ndarray: Gradient of the signal.
    """
    gradient = np.zeros_like(signal)
    n = len(signal)

    for i in range(1, n - 1):
        gradient[i] = (signal[i + 1] - signal[i - 1]) / 2.0
    
    gradient[0] = (signal[1] - signal[0])
    gradient[n - 1] = (signal[n - 1] - signal[n - 2])

    gradient /= np.max(gradient)
    return gradient

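def bandpass_filter(ecg, fs):
    """Pan-Tompkins style band-pass: 12 Hz low-pass followed by 5 Hz high-pass."""
    # sg.butter returns (numerator, denominator) = (b, a); filtfilt expects (b, a).
    # Cutoffs are normalized to the Nyquist frequency fs/2, hence the factor 2/fs.
    Wn = 12*2/fs
    N = 3
    b, a = sg.butter(N, Wn, btype='lowpass')
    ecg_l = sg.filtfilt(b, a, ecg)
    ecg_l = ecg_l/np.max(np.abs(ecg_l))  # amplitude normalization

    Wn = 5*2/fs
    N = 3
    b, a = sg.butter(N, Wn, btype='highpass')
    ecg_h = sg.filtfilt(b, a, ecg_l, padlen=3*(max(len(a), len(b))-1))
    ecg_h = ecg_h/np.max(np.abs(ecg_h))
    return ecg_h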

y1_0 = bandpass_filter(data, Fs)
y2_0 = calculate_gradient(y1_0)

m = len(y2_0) # Number of training instances
n = 2 # Number of features (only column 0 is filled below; column 1 stays zero)

feature_vector = np.zeros((m, n))
label_vector = np.zeros(m)
feature_vector[:, 0] = y2_0

# Setting labels for QRS/non-QRS regions
for sample_index in ann_ref:
    window_start = max(0, sample_index - 25)
    window_end = min(len(data), sample_index + 25)
    label_vector[window_start:window_end] = 1

neigh = KNeighborsClassifier(n_neighbors=3, p=2, metric='minkowski')
neigh.fit(feature_vector, label_vector)
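One scikit-learn detail that matters for how the two stages connect: `cross_val_predict` clones the estimator it receives and refits the clone on each training fold, so the state of the `neigh` fitted above is not used when `neigh` is passed to `cross_val_predict` in the testing stage. The sketch below illustrates this on hypothetical random data (`X_train`, `X_test` and the threshold labels are stand-ins, not ECG features): a fitted and an unfitted copy of the same classifier give identical cross-validated predictions.

```python
import numpy as np
from sklearn.base import clone
from sklearn.model_selection import cross_val_predict
from sklearn.neighbors import KNeighborsClassifier

rng = np.random.default_rng(0)

# Hypothetical stand-ins for a "training record" and a "test record":
# one feature column and binary labels derived from a threshold.
X_train = rng.normal(size=(300, 1))
y_train = (X_train[:, 0] > 0.5).astype(int)
X_test = rng.normal(size=(300, 1))
y_test = (X_test[:, 0] > 0.5).astype(int)

# Same hyperparameters as `neigh` (K=3, Minkowski with p=2, i.e. Euclidean).
fitted = KNeighborsClassifier(n_neighbors=3, p=2, metric="minkowski").fit(X_train, y_train)
unfitted = clone(fitted)  # same hyperparameters, fitted state discarded

# cross_val_predict clones whatever it receives and refits on folds of X_test,
# so fitting on X_train beforehand makes no difference to the result.
pred_fitted = cross_val_predict(fitted, X_test, y_test, cv=5)
pred_unfitted = cross_val_predict(unfitted, X_test, y_test, cv=5)

print(np.array_equal(pred_fitted, pred_unfitted))  # True
```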

Python code: Testing stage

y1_0 = bandpass_filter(data0, Fs)
y2_0 = calculate_gradient(y1_0)

m = len(y2_0) # Number of test samples
n = 2 # Number of features (only column 0 is filled below; column 1 stays zero)

X = np.zeros((m, n))
y = np.zeros(m)
X[:, 0] = y2_0

for sample_index in ann_ref:
    window_start = max(0, sample_index - 25)
    window_end = min(len(data0), sample_index + 25)
    y[window_start:window_end] = 1

predicted_labels0 = cross_val_predict(neigh, X[:, 0].reshape(-1, 1), y, cv=5) # Performing fivefold cross-validation

# Function to calculate the average pulse duration
def calculate_average_pulse_duration(predicted_labels):
    peak_durations = np.diff(np.where(predicted_labels == 1)[0])
    return np.mean(peak_durations)

# Function to detect QRS-complex based on the average pulse duration
def detect_QRS_complex(train_of_ones, average_pulse_duration):
    QRS_indices = []
    for i, train_duration in enumerate(train_of_ones):
        if train_duration > 3 * average_pulse_duration:
            QRS_indices.append(train_duration)
    return QRS_indices

# Sample indices predicted as QRS (label 1)
train_of_ones_0 = [index for index, label in enumerate(predicted_labels0) if label == 1]

average_pulse_duration_0 = calculate_average_pulse_duration(predicted_labels0)
QRS_indices_0 = detect_QRS_complex(train_of_ones_0, average_pulse_duration_0)
QRS_indices_0 = np.array(QRS_indices_0, dtype=int)

samples = 100  # search half-width around each candidate (about 0.28 s at 360 Hz)
index0 = []
for i in QRS_indices_0:
    start_index = max(0, i - samples)
    end_index = min(len(y2_0), i + samples + 1)
    signal_within_margin = y2_0[start_index:end_index]
    # Position of the largest gradient value in the window, mapped back to the full signal
    index0.append(np.argmax(signal_within_margin) + start_index)

index0 = np.unique(index0)  # removes duplicates and returns a sorted array

comparitor = processing.compare_annotations(ann_ref_indices, index0, int(0.1*Fs))
comparitor.print_summary()
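The testing stage reasons about "trains of ones" in the predicted labels, while `train_of_ones_0` in the code above holds the individual sample indices predicted as 1. As a minimal sketch, assuming a "train" means a contiguous run of samples labelled 1, the runs and their durations (in samples) can be extracted as follows; `label_runs` is a hypothetical helper, not a function from the paper, wfdb or scikit-learn.

```python
import numpy as np

def label_runs(labels):
    """Return (start, end) pairs, end exclusive, for each contiguous run of 1s.

    Hypothetical helper: assumes a "train of ones" means one such run.
    """
    # Pad with zeros so runs touching either edge still produce both transitions.
    padded = np.concatenate(([0], np.asarray(labels, dtype=int), [0]))
    edges = np.diff(padded)
    starts = np.flatnonzero(edges == 1)   # 0 -> 1 transitions
    ends = np.flatnonzero(edges == -1)    # 1 -> 0 transitions
    return np.column_stack((starts, ends))

predicted = np.array([0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0])
runs = label_runs(predicted)
durations = runs[:, 1] - runs[:, 0]
print(runs.tolist())       # [[1, 4], [6, 7], [8, 10]]
print(durations.tolist())  # [3, 1, 2]
```

Each row is one run, so per-run quantities such as its duration or the position of the largest `y2_0` value inside it can be computed directly from these boundaries.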

Note: I used the suffix 0 in variable names (y2_0, train_of_ones_0, index0, ...) to denote the first lead of the ECG record, because the algorithm is designed to classify both the MLII and V1 ECG leads. What am I getting wrong here?
