Importing Libraries

import torch

import copy
import numpy as np
import pandas as pd
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from sklearn.model_selection import train_test_split

from torch import nn, optim

import torch.nn.functional as F
from arff2pandas import a2p

%matplotlib inline
%config InlineBackend.figure_format='retina'

sns.set(style='whitegrid', palette='muted', font_scale=1.2)

HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]

sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))

rcParams['figure.figsize'] = 12, 8

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

Dataset Description

We have 5 types of heartbeats (classes):

  • Normal (N)
  • R-on-T Premature Ventricular Contraction (R-on-T PVC)
  • Premature Ventricular Contraction (PVC)
  • Supra-ventricular Premature or Ectopic Beat (SP or EB)
  • Unclassified Beat (UB).

Assuming a healthy heart and a typical rate of 70 to 75 beats per minute, each cardiac cycle, or heartbeat, takes about 0.8 seconds to complete. In humans, the typical frequency is 60–100 beats per minute and each beat lasts roughly 0.6–1 second.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')
with open('ECG5000_TRAIN.arff') as f:
    train = a2p.load(f)
    
with open('ECG5000_TEST.arff') as f:
    test = a2p.load(f)
df = train.append(test)
df = df.sample(frac=1.0)
df.shape
(5000, 141)
df.head()
df.head() returns the first five rows (5 rows × 141 columns): columns att1@NUMERIC through att140@NUMERIC hold the 140 samples of each heartbeat, and the last column, target@{1,2,3,4,5}, holds the class label (all five rows shown belong to class 1, Normal).

CLASS_NORMAL = 1

class_names = ['Normal', 'R on T', 'PVC', 'SP', 'UB']
new_columns = list(df.columns)
new_columns[-1] = 'target'
df.columns = new_columns
df.head()
df.head() again returns 5 rows × 141 columns; the only difference is that the last column is now named target instead of target@{1,2,3,4,5}.

Exploratory Data Analysis

df.target.value_counts()
1    2919
2    1767
4     194
3      96
5      24
Name: target, dtype: int64
ax = sns.countplot(x=df.target)
ax.set_xticklabels(class_names);

The normal class has, by far, the most examples. This is great because we'll use it to train our model.

Let's have a look at an averaged (smoothed out, with a band of two standard deviations above and below it) Time Series for each class:

def plot_time_series_class(data, class_name, ax, n_steps=10):
    time_series_df = pd.DataFrame(data)
    
    smooth_path = time_series_df.rolling(n_steps).mean()
    path_deviation = 2 * time_series_df.rolling(n_steps).std()
    
    under_line = (smooth_path - path_deviation)[0]
    over_line = (smooth_path + path_deviation)[0]
    
    ax.plot(smooth_path, linewidth=2)
    ax.fill_between(
      path_deviation.index,
      under_line,
      over_line,
      alpha=.125
    )
    ax.set_title(class_name)
classes = df.target.unique()

fig, axs = plt.subplots(
  nrows = len(classes) // 3 + 1,
  ncols = 3,
  sharey=True,
  figsize=(14,8)
)

for i, cls in enumerate(classes):
    ax = axs.flat[i]
    data = df[df.target == cls] \
      .drop(labels='target', axis=1) \
      .mean(axis=0) \
      .to_numpy()
    plot_time_series_class(data, class_names[i], ax)
    
fig.delaxes(axs.flat[-1])
fig.tight_layout();

LSTM Autoencoder

Let's have a look at how to feed Time Series data to an Autoencoder. We'll use a couple of LSTM layers (hence the LSTM Autoencoder) to capture the temporal dependencies of the data.

To classify a sequence as normal or an anomaly, we'll pick a threshold on the reconstruction loss; a heartbeat whose loss is above that threshold is considered abnormal.
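A minimal sketch of that decision rule (the function name is illustrative, not part of the pipeline; the actual threshold is chosen later from the distribution of losses on normal training data):

def classify_heartbeats(reconstruction_losses, threshold):
    # loss above the threshold -> anomaly, otherwise normal
    return ['anomaly' if loss > threshold else 'normal'
            for loss in reconstruction_losses]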

Reconstruction Loss

When training an Autoencoder, the objective is to reconstruct the input as well as possible. This is done by minimizing a loss function (just like in supervised learning), known as the reconstruction loss. Mean squared error and cross-entropy are common choices.
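As a self-contained sketch (the tensors here are made up purely for illustration), the L1 reconstruction loss we'll use during training looks like this:

x = torch.randn(140, 1)                  # one heartbeat: 140 time steps, 1 feature
x_hat = x + 0.1 * torch.randn_like(x)    # a pretend reconstruction
criterion = nn.L1Loss(reduction='sum')
print(criterion(x_hat, x).item())        # sum of absolute reconstruction errors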

Data Preprocessing

normal_df = df[df.target == str(CLASS_NORMAL)].drop(labels='target', axis=1)
normal_df.shape
(2919, 140)
anomaly_df = df[df.target != str(CLASS_NORMAL)].drop(labels='target', axis=1)
anomaly_df.shape
(2081, 140)
train_df, val_df = train_test_split(
  normal_df,
  test_size=0.15,
  random_state=RANDOM_SEED
)

val_df, test_df = train_test_split(
  val_df,
  test_size=0.33,
  random_state=RANDOM_SEED
)
print(train_df.shape)
print(val_df.shape)
print(test_df.shape)
(2481, 140)
(293, 140)
(145, 140)
def create_dataset(df):
    
    sequences = df.astype(np.float32).to_numpy().tolist()
    
    # each sequence becomes a (seq_len, 1) tensor: 140 time steps, one feature per step
    dataset = [torch.tensor(s).unsqueeze(1).float() for s in sequences]
    
    n_seq, seq_len, n_features = torch.stack(dataset).shape
    
    return dataset, seq_len, n_features

Each Time Series will be converted to a 2D Tensor in the shape sequence length x number of features (140x1 in our case).

train_dataset, seq_len, n_features = create_dataset(train_df)
val_dataset, _, _ = create_dataset(val_df)
test_normal_dataset, _, _ = create_dataset(test_df)
test_anomaly_dataset, _, _ = create_dataset(anomaly_df)
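As a quick sanity check (a sketch, not part of the original pipeline), we can confirm the 140 × 1 shape mentioned above:

print(train_dataset[0].shape)  # torch.Size([140, 1])
print(seq_len, n_features)     # 140 1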

LSTM Autoencoder

[Figure: LSTM Autoencoder architecture]

The general Autoencoder architecture consists of two components: an Encoder that compresses the input and a Decoder that tries to reconstruct it.

We'll use the LSTM Autoencoder from this GitHub repo with some small tweaks. Our model's job is to reconstruct Time Series data. Let's start with the Encoder:

class Encoder(nn.Module):

  def __init__(self, seq_len, n_features, embedding_dim=64):
    super(Encoder, self).__init__()

    self.seq_len, self.n_features = seq_len, n_features
    self.embedding_dim, self.hidden_dim = embedding_dim, 2 * embedding_dim

    self.rnn1 = nn.LSTM(
      input_size=n_features,
      hidden_size=self.hidden_dim,
      num_layers=1,
      batch_first=True
    )
    
    self.rnn2 = nn.LSTM(
      input_size=self.hidden_dim,
      hidden_size=embedding_dim,
      num_layers=1,
      batch_first=True
    )

  def forward(self, x):
    # add a batch dimension: (seq_len, n_features) -> (1, seq_len, n_features)
    x = x.reshape((1, self.seq_len, self.n_features))

    x, (_, _) = self.rnn1(x)
    x, (hidden_n, _) = self.rnn2(x)

    # the final hidden state of the second LSTM is the embedding
    return hidden_n.reshape((self.n_features, self.embedding_dim))
class Decoder(nn.Module):

  def __init__(self, seq_len, input_dim=64, n_features=1):
    super(Decoder, self).__init__()

    self.seq_len, self.input_dim = seq_len, input_dim
    self.hidden_dim, self.n_features = 2 * input_dim, n_features

    self.rnn1 = nn.LSTM(
      input_size=input_dim,
      hidden_size=input_dim,
      num_layers=1,
      batch_first=True
    )

    self.rnn2 = nn.LSTM(
      input_size=input_dim,
      hidden_size=self.hidden_dim,
      num_layers=1,
      batch_first=True
    )

    self.output_layer = nn.Linear(self.hidden_dim, n_features)

  def forward(self, x):
    # repeat the embedding for every time step: (1, input_dim) -> (seq_len, input_dim)
    x = x.repeat(self.seq_len, self.n_features)
    x = x.reshape((self.n_features, self.seq_len, self.input_dim))

    x, (hidden_n, cell_n) = self.rnn1(x)
    x, (hidden_n, cell_n) = self.rnn2(x)
    x = x.reshape((self.seq_len, self.hidden_dim))

    # map each hidden state back to a single reconstructed value per time step
    return self.output_layer(x)
class RecurrentAutoencoder(nn.Module):

  def __init__(self, seq_len, n_features, embedding_dim=64):
    super(RecurrentAutoencoder, self).__init__()

    self.encoder = Encoder(seq_len, n_features, embedding_dim).to(device)
    self.decoder = Decoder(seq_len, embedding_dim, n_features).to(device)

  def forward(self, x):
    x = self.encoder(x)
    x = self.decoder(x)

    return x
model = RecurrentAutoencoder(seq_len, n_features, 128)
model = model.to(device)
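As a quick shape check (a sketch, not part of the original notebook), we can push one training example through the untrained autoencoder and confirm that the reconstruction has the same shape as the input:

example = train_dataset[0].to(device)
with torch.no_grad():
    reconstruction = model(example)
print(example.shape, reconstruction.shape)  # both torch.Size([140, 1])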

Training

def train_model(model, train_dataset, val_dataset, n_epochs):
  optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
  criterion = nn.L1Loss(reduction='sum').to(device)
  history = dict(train=[], val=[])

  best_model_wts = copy.deepcopy(model.state_dict())
  best_loss = 10000.0
  
  for epoch in range(1, n_epochs + 1):
    model = model.train()

    train_losses = []
    for seq_true in train_dataset:
      optimizer.zero_grad()

      seq_true = seq_true.to(device)
      seq_pred = model(seq_true)

      loss = criterion(seq_pred, seq_true)

      loss.backward()
      optimizer.step()

      train_losses.append(loss.item())

    val_losses = []
    model = model.eval()
    with torch.no_grad():
      for seq_true in val_dataset:

        seq_true = seq_true.to(device)
        seq_pred = model(seq_true)

        loss = criterion(seq_pred, seq_true)
        val_losses.append(loss.item())

    train_loss = np.mean(train_losses)
    val_loss = np.mean(val_losses)

    history['train'].append(train_loss)
    history['val'].append(val_loss)

    if val_loss < best_loss:
      best_loss = val_loss
      best_model_wts = copy.deepcopy(model.state_dict())

    print(f'Epoch {epoch}: train loss {train_loss} val loss {val_loss}')

  model.load_state_dict(best_model_wts)
  return model.eval(), history
model, history = train_model(
  model, 
  train_dataset, 
  val_dataset, 
  n_epochs=10
)
Epoch 1: train loss 78.96793919924237 val loss 57.008918449740364
Epoch 2: train loss 55.24031583285534 val loss 51.29970774471556
Epoch 3: train loss 50.97752316869684 val loss 50.35719702919189
Epoch 4: train loss 50.59584151271465 val loss 40.614140123230605
Epoch 5: train loss 38.15368703352449 val loss 37.547762724319824
Epoch 6: train loss 34.309216836055214 val loss 37.58409660261239
Epoch 7: train loss 31.98069694416026 val loss 34.298767298154864
Epoch 8: train loss 28.60677365553278 val loss 27.40926725628433
Epoch 9: train loss 26.80304576254141 val loss 24.187094398732885
Epoch 10: train loss 25.63979911073856 val loss 29.823875609518318
ax = plt.figure().gca()

ax.plot(history['train'])
ax.plot(history['val'])
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['train', 'val'])
plt.title('Loss over training epochs')
plt.show();
MODEL_PATH = 'Time-Series-ECG5000-Pytorch.pth'

torch.save(model, MODEL_PATH)
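Since we saved the whole module with torch.save, it can be restored later like this (a sketch; it assumes the Encoder/Decoder class definitions are importable, and map_location lets the model load on a CPU-only machine too):

model = torch.load(MODEL_PATH, map_location=device)
model = model.to(device)
model.eval()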

Choosing a threshold

def predict(model, dataset):
  predictions, losses = [], []
  criterion = nn.L1Loss(reduction='sum').to(device)
  with torch.no_grad():
    model = model.eval()
    for seq_true in dataset:
      seq_true = seq_true.to(device)
      seq_pred = model(seq_true)

      loss = criterion(seq_pred, seq_true)

      predictions.append(seq_pred.cpu().numpy().flatten())
      losses.append(loss.item())
  return predictions, losses
_, losses = predict(model, train_dataset)

sns.histplot(losses, bins=50, kde=True);
THRESHOLD = 26  # picked by inspecting the training-loss distribution plotted above

Evaluation

Normal heartbeats

predictions, pred_losses = predict(model, test_normal_dataset)
sns.histplot(pred_losses, bins=50, kde=True);
correct = sum(l <= THRESHOLD for l in pred_losses)
print(f'Correct normal predictions: {correct}/{len(test_normal_dataset)}')
Correct normal predictions: 111/145

Anomalies

anomaly_dataset = test_anomaly_dataset[:len(test_normal_dataset)]
predictions, pred_losses = predict(model, anomaly_dataset)
sns.histplot(pred_losses, bins=50, kde=True);
correct = sum(l > THRESHOLD for l in pred_losses)
print(f'Correct anomaly predictions: {correct}/{len(anomaly_dataset)}')
Correct anomaly predictions: 144/145

Looking at Examples

def plot_prediction(data, model, title, ax):
  predictions, pred_losses = predict(model, [data])

  ax.plot(data, label='true')
  ax.plot(predictions[0], label='reconstructed')
  ax.set_title(f'{title} (loss: {np.around(pred_losses[0], 2)})')
  ax.legend()
fig, axs = plt.subplots(
  nrows=2,
  ncols=6,
  sharey=True,
  sharex=True,
  figsize=(24, 8)
)

for i, data in enumerate(test_normal_dataset[:6]):
  plot_prediction(data, model, title='Normal', ax=axs[0, i])

for i, data in enumerate(test_anomaly_dataset[:6]):
  plot_prediction(data, model, title='Anomaly', ax=axs[1, i])

fig.tight_layout();