{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "D7UvwwCM0wC7" }, "source": [ "# Notebook\n", "\n", "**Authors:** Colin Small (crs1031@wildcats.unh.edu), Matthew Argall (Matthew.Argall@unh.edu), Marek Petrik (Marek.Petrik@unh.edu)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "tdDrn6mBTUgT" }, "source": [ "[![MMS Mission Video](https://upload.wikimedia.org/wikipedia/commons/thumb/b/b3/Artist_depiction_of_MMS_spacecraft_%28SVS12239%29.png/640px-Artist_depiction_of_MMS_spacecraft_%28SVS12239%29.png)](https://upload.wikimedia.org/wikipedia/commons/c/c9/NASA_Spacecraft_Finds_New_Magnetic_Process_in_Turbulent_Space.webm)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "rL5p5pJQ39Dd" }, "source": [ "## Introduction\n", "Global-scale energy flow throughout Earth’s magnetosphere is catalyzed by processes that occur at Earth’s magnetopause (MP) in the electron diffusion region (EDR) of magnetic reconnection. Until the launch of the Magnetospheric Multiscale (MMS) mission, only rare, fortuitous circumstances permitted a glimpse of the electron dynamics that break magnetic field lines and energize plasma. MMS employs automated burst triggers onboard the spacecraft and a Scientist-in-the-Loop (SITL) on the ground to select intervals likely to contain diffusion regions. Only low-resolution survey data is available to the SITL, which is insufficient to resolve electron dynamics. A strategy for the SITL, then, is to select all MP crossings. This has resulted in over 35 potential MP EDR encounters but is labor- and resource-intensive; after manual reclassification, just ∼0.7% of MP crossings, or 0.0001% of the mission lifetime during MMS’s first two years, contained an EDR.\n", "\n", "In this notebook, we develop a Long Short-Term Memory (LSTM) neural network to detect magnetopause crossings and automate the SITL classification process. 
An LSTM developed with this notebook has been implemented in the MMS data stream to provide automated predictions to the SITL.\n", "\n", "\n", "This model facilitates EDR studies and helps reduce mission operation costs by consolidating manual classification processes into automated routines." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "xRLV_B1CtKIP" }, "source": [ "**Authors' notes:** \n", "\n", "1. This notebook was written after the original model in use at the MMS Science Data Center (SDC) was developed. We have tried our best to replicate the development steps and hyperparameters of that model, but we cannot guarantee that models developed with this notebook will exactly match the performance of the original.\n", "\n", "2. This notebook was designed on, and is best run on, Google Colab. It must be run either on Colab or on a machine with an NVIDIA GPU and cuDNN installed. If your machine does not have an NVIDIA GPU, does not have cuDNN installed, or if you run into issues running this notebook yourself, please open the notebook in Google Colab, which provides a virtual GPU to run the notebook. (If TF Keras is unable to identify a GPU to run on, make sure the notebook is set to use one by clicking the \"Runtime\" tab in the top menu bar, selecting \"Change runtime type\", selecting \"GPU\" in the dropdown menu under \"Hardware accelerator\", and clicking save. Colab will refresh your runtime, and you will need to re-run all cells.)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "4Em6RZ8Z6UFl" }, "source": [ "## Import Libraries\n", "\n", "To start, we import the necessary libraries for this notebook." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 318 }, "colab_type": "code", "id": "3WshufyyiiLd", "outputId": "fb577231-e76a-4f6e-e81d-1c191e5e164e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: nasa-pymms in /usr/local/lib/python3.6/dist-packages (0.3.1)\n", "Requirement already satisfied: tqdm>=4.36.1 in /usr/local/lib/python3.6/dist-packages (from nasa-pymms) (4.41.1)\n", "Requirement already satisfied: matplotlib>=3.1.1 in /usr/local/lib/python3.6/dist-packages (from nasa-pymms) (3.2.2)\n", "Requirement already satisfied: requests>=2.22.0 in /usr/local/lib/python3.6/dist-packages (from nasa-pymms) (2.23.0)\n", "Requirement already satisfied: scipy>=1.4.1 in /usr/local/lib/python3.6/dist-packages (from nasa-pymms) (1.4.1)\n", "Requirement already satisfied: cdflib in /usr/local/lib/python3.6/dist-packages (from nasa-pymms) (0.3.19)\n", "Requirement already satisfied: numpy>=1.8 in /usr/local/lib/python3.6/dist-packages (from nasa-pymms) (1.18.5)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=3.1.1->nasa-pymms) (1.2.0)\n", "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=3.1.1->nasa-pymms) (0.10.0)\n", "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=3.1.1->nasa-pymms) (2.4.7)\n", "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=3.1.1->nasa-pymms) (2.8.1)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests>=2.22.0->nasa-pymms) (2020.6.20)\n", "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.22.0->nasa-pymms) (3.0.4)\n", "Requirement already 
satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests>=2.22.0->nasa-pymms) (1.24.3)\n", "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.22.0->nasa-pymms) (2.10)\n", "Requirement already satisfied: attrs>=19.2.0 in /usr/local/lib/python3.6/dist-packages (from cdflib->nasa-pymms) (20.2.0)\n", "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from cycler>=0.10->matplotlib>=3.1.1->nasa-pymms) (1.15.0)\n" ] } ], "source": [ "!pip install nasa-pymms" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 87 }, "colab_type": "code", "id": "MwnLUN2cUd4S", "outputId": "784c11cf-a6ce-4e24-ad6e-2afe19051981" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "TensorFlow 1.x selected.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Creating root data directory /root/data/mms\n", "Creating root data directory /root/data/mms/dropbox\n" ] } ], "source": [ "from pathlib import Path\n", "from sklearn import preprocessing\n", "from tensorflow.keras.models import Sequential\n", "from tensorflow.keras.layers import Dense, Dropout, LSTM, CuDNNLSTM, BatchNormalization, Bidirectional, Reshape, TimeDistributed\n", "from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint\n", "from matplotlib import pyplot\n", "from sklearn.metrics import roc_curve, auc, confusion_matrix\n", "from keras import backend as K\n", "from pymms.sdc import mrmms_sdc_api as mms\n", "import keras.backend.tensorflow_backend as tfb\n", "import tensorflow as tf\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "plt.rcParams.update({'font.size': 18})\n", "import datetime as dt\n", "import os\n", "import time\n", 
"import sklearn\n", "import scipy\n", "import pickle\n", "import random\n", "import requests" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "x-LAsDIG3XD7" }, "source": [ "## Download, Preprocess, and Format MMS Data" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "gWIpfE2zSCOP" }, "source": [ "After installing and importing the necessary libraries, we download our training and validation data. " ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 407 }, "colab_type": "code", "id": "mcytuTHUR8U8", "outputId": "04175787-010b-43d0-9951-104ec5782768" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2020-09-16 16:36:14-- https://zenodo.org/record/3884266/files/original_training_data.csv?download=1\n", "Resolving zenodo.org (zenodo.org)... 188.184.117.155\n", "Connecting to zenodo.org (zenodo.org)|188.184.117.155|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 447842635 (427M) [text/plain]\n", "Saving to: ‘training_data.csv’\n", "\n", "training_data.csv 100%[===================>] 427.10M 7.45MB/s in 24s \n", "\n", "2020-09-16 16:36:39 (18.0 MB/s) - ‘training_data.csv’ saved [447842635/447842635]\n", "\n", "--2020-09-16 16:36:39-- https://zenodo.org/record/3884266/files/original_validation_data.csv?download=1\n", "Resolving zenodo.org (zenodo.org)... 188.184.117.155\n", "Connecting to zenodo.org (zenodo.org)|188.184.117.155|:443... connected.\n", "HTTP request sent, awaiting response... 
200 OK\n", "Length: 90314951 (86M) [text/plain]\n", "Saving to: ‘validation_data.csv’\n", "\n", "validation_data.csv 100%[===================>] 86.13M 8.10MB/s in 10s \n", "\n", "2020-09-16 16:36:49 (8.45 MB/s) - ‘validation_data.csv’ saved [90314951/90314951]\n", "\n" ] } ], "source": [ "!wget -O training_data.csv https://zenodo.org/record/3884266/files/original_training_data.csv?download=1\n", "!wget -O validation_data.csv https://zenodo.org/record/3884266/files/original_validation_data.csv?download=1" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "aDg-HJ0MAtjS" }, "source": [ "After downloading the training and validation data, we preprocess our training data in preparation for training the neural network." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "fp4mcnLfFZXw" }, "source": [ "We first load the data we downloaded above. The data is a table of measurements from the MMS spacecraft, where each row represents individual measurements taken at a given time and where each column represents a feature (variable) recorded at that time. There is an additional column representing the ground truths for each measurement (whether this measurement was selected by a SITL or not). Then, we will adjust the formatting and datatypes of several of the columns and sort the data by the time of the measurement." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": {}, "colab_type": "code", "id": "tojTr8i472HY" }, "outputs": [], "source": [ "mms_data = pd.read_csv('training_data.csv', index_col=0, infer_datetime_format=True,\n", "\t\t\t\t\t\t parse_dates=[0])" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 623 }, "colab_type": "code", "id": "U2yZv1MwGqWQ", "outputId": "78dc7f6b-3f8c-43e6-b7b3-0bda18e92966" }, "outputs": [ { "data": { "text/html": [ "
DataFrame preview: 302188 rows × 124 columns, indexed by Epoch, with feature columns from mms1_des_energyspectr_omni_fast_0 through mms1_plasma_beta plus the boolean selected column.
" ], "text/plain": [ " mms1_des_energyspectr_omni_fast_0 ... selected\n", "Epoch ... \n", "2017-01-01 01:49:08.736524 172560370.0 ... False\n", "2017-01-01 01:49:13.236552 160474430.0 ... False\n", "2017-01-01 01:49:17.736573 143446660.0 ... False\n", "2017-01-01 01:49:22.236602 143115380.0 ... False\n", "2017-01-01 01:49:26.736624 154996270.0 ... False\n", "... ... ... ...\n", "2017-01-31 01:59:37.612761 195522660.0 ... False\n", "2017-01-31 01:59:42.112805 188999890.0 ... False\n", "2017-01-31 01:59:46.612839 156668940.0 ... False\n", "2017-01-31 01:59:51.112882 148339340.0 ... False\n", "2017-01-31 01:59:55.612917 135440930.0 ... False\n", "\n", "[302188 rows x 124 columns]" ] }, "execution_count": 5, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "mms_data[mms_data['selected'] == False]" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Fe4HZy7CHvFF" }, "source": [ "We save references to the data's index and column names for later use and pop off the ground-truth column. We will reattach the ground-truth column after interpolating and standardizing the data." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": {}, "colab_type": "code", "id": "d2bm6URzIeFT" }, "outputs": [], "source": [ "index = mms_data.index\n", "selections = mms_data.pop(\"selected\")\n", "column_names = mms_data.columns" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "9gBjgkNsJvPV" }, "source": [ "Because the training data may contain missing values or values misreported by the MMS spacecraft (as positive or negative infinity), we replace infinite values with NaN and then fill in (interpolate) the missing data." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": {}, "colab_type": "code", "id": "IBg8Fds8KIhW" }, "outputs": [], "source": [ "mms_data = mms_data.replace([np.inf, -np.inf], np.nan)\n", "mms_data = mms_data.interpolate(method='time', limit_area='inside')" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "rdMDvmcYLQ-1" }, "source": [ "We normalize all features with standardization:\n", "\n", "![z = (x - x̄) / σ](https://wikimedia.org/api/rest_v1/media/math/render/svg/b0aa2e7d203db1526c577192f2d9102b718eafd5)\n", "\n", "where x̄ is the mean of the data and σ is the standard deviation of the data.\n", "\n", "Standardization centers every feature on zero and rescales it to unit variance, so that all features share a common scale; the values are not confined to a fixed range such as -1 to 1. Standardization improves the speed and performance of neural network training because it unifies the scale on which differences in the data are represented without changing the relative structure of the data." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": {}, "colab_type": "code", "id": "pXlE0dj_PbqI" }, "outputs": [], "source": [ "scaler = preprocessing.StandardScaler()\n", "mms_data = scaler.fit_transform(mms_data)\n", "mms_data = pd.DataFrame(mms_data, index, column_names)\n", "mms_data = mms_data.join(selections)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "HytkkKWLR6KS" }, "source": [ "Next, we calculate class weights for our two classes (selected and non-selected data points). Each class weight is the total number of data points divided by twice the number of data points in that class. Since our data is heavily skewed towards non-selected data points (just 1.9% of all data points in our training data were selected), it is important to give the selected class a higher weight when training. Without these class weights, our model would quickly achieve about 98% accuracy by naively leaving all data points unselected." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": {}, "colab_type": "code", "id": "03pNFL5zTINT" }, "outputs": [], "source": [ "false_weight = len(mms_data)/(2*np.bincount(mms_data['selected'].values)[0])\n", "true_weight = len(mms_data)/(2*np.bincount(mms_data['selected'].values)[1])" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ckN-_FOuTlI4" }, "source": [ "Our dataset is not contiguous: it contains time intervals with no observations. Therefore, we break it up into contiguous chunks, using the windows in which the SITLs reviewed the data. We keep only the windows that contain more than one selected data point." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": {}, "colab_type": "code", "id": "vuoPyoU5UV3Q" }, "outputs": [], "source": [ "sitl_windows = mms.mission_events('sroi', mms_data.index[0].to_pydatetime(), mms_data.index[-1].to_pydatetime(), sc='mms1')\n", "windows = []\n", "for start, end in zip(sitl_windows['tstart'], sitl_windows['tend']):\n", "    window = mms_data[start:end]\n", "    if not window.empty and len(window[window['selected']==True])>1:\n", "        windows.append(window)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "colab_type": "code", "id": "cyCpZWhRfIOc", "outputId": "eeab4449-7153-41b4-cfec-8c4dab623228" }, "outputs": [ { "data": { "text/plain": [ "[ mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-01 01:49:31.236651 1.883959 ... False\n", " 2017-01-01 01:49:35.736674 1.458570 ... False\n", " 2017-01-01 01:49:40.236701 1.368072 ... False\n", " 2017-01-01 01:49:44.736723 1.338055 ... False\n", " 2017-01-01 01:49:49.236750 1.660276 ... False\n", " ... ... ... ...\n", " 2017-01-01 15:42:28.613021 -0.710802 ... False\n", " 2017-01-01 15:42:33.113061 -0.251731 ... False\n", " 2017-01-01 15:42:37.613092 -0.093185 ... False\n", " 2017-01-01 15:42:42.113131 -0.247850 ...
False\n", " 2017-01-01 15:42:46.613161 -0.708025 ... False\n", " \n", " [11111 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-02 01:39:35.994680 0.320164 ... False\n", " 2017-01-02 01:39:40.494716 0.196985 ... False\n", " 2017-01-02 01:39:44.994745 0.157529 ... False\n", " 2017-01-02 01:39:49.494779 0.300838 ... False\n", " 2017-01-02 01:39:53.994809 0.211870 ... False\n", " ... ... ... ...\n", " 2017-01-02 15:32:28.885849 -0.294389 ... False\n", " 2017-01-02 15:32:33.385891 -0.736511 ... False\n", " 2017-01-02 15:32:37.885924 -1.195971 ... False\n", " 2017-01-02 15:32:42.385967 -1.370307 ... False\n", " 2017-01-02 15:32:46.886000 -1.384843 ... False\n", " \n", " [11110 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-03 01:34:28.067529 0.549291 ... False\n", " 2017-01-03 01:34:32.567559 0.407271 ... False\n", " 2017-01-03 01:34:37.067596 0.337537 ... False\n", " 2017-01-03 01:34:41.567624 0.476185 ... False\n", " 2017-01-03 01:34:46.067661 0.387942 ... False\n", " ... ... ... ...\n", " 2017-01-03 15:27:29.959758 2.154355 ... False\n", " 2017-01-03 15:27:34.459798 2.084295 ... False\n", " 2017-01-03 15:27:38.959830 2.215957 ... False\n", " 2017-01-03 15:27:43.459870 2.168310 ... False\n", " 2017-01-03 15:27:47.959902 2.102627 ... False\n", " \n", " [11112 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-04 01:29:30.169168 0.265701 ... False\n", " 2017-01-04 01:29:34.669196 0.287085 ... False\n", " 2017-01-04 01:29:39.169233 0.482386 ... False\n", " 2017-01-04 01:29:43.669261 0.021679 ... False\n", " 2017-01-04 01:29:48.169297 -0.160651 ... False\n", " ... ... ... ...\n", " 2017-01-04 15:22:27.565195 -0.666403 ... False\n", " 2017-01-04 15:22:32.065236 -0.602717 ... False\n", " 2017-01-04 15:22:36.565269 -0.642503 ... False\n", " 2017-01-04 15:22:41.065311 -0.610725 ... 
False\n", " 2017-01-04 15:22:45.565345 -0.660693 ... False\n", " \n", " [11111 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-05 01:19:35.672379 0.927498 ... False\n", " 2017-01-05 01:19:40.172417 0.804923 ... False\n", " 2017-01-05 01:19:44.672446 1.021925 ... False\n", " 2017-01-05 01:19:49.172483 0.780899 ... False\n", " 2017-01-05 01:19:53.672513 0.700766 ... False\n", " ... ... ... ...\n", " 2017-01-05 15:12:28.574822 0.022651 ... False\n", " 2017-01-05 15:12:33.074864 -0.033623 ... False\n", " 2017-01-05 15:12:37.574897 -0.141828 ... False\n", " 2017-01-05 15:12:42.074939 -0.203950 ... False\n", " 2017-01-05 15:12:46.574971 -0.221731 ... False\n", " \n", " [9510 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-06 01:14:31.176632 -0.941931 ... False\n", " 2017-01-06 01:14:35.676662 0.299075 ... False\n", " 2017-01-06 01:14:40.176699 0.678641 ... False\n", " 2017-01-06 01:14:44.676728 0.428966 ... False\n", " 2017-01-06 01:14:49.176766 0.655676 ... False\n", " ... ... ... ...\n", " 2017-01-06 13:59:40.546992 -1.342667 ... True\n", " 2017-01-06 13:59:45.047032 -1.300301 ... True\n", " 2017-01-06 13:59:49.547066 -1.278993 ... True\n", " 2017-01-06 13:59:54.047106 -1.267142 ... True\n", " 2017-01-06 13:59:58.547140 -1.197250 ... True\n", " \n", " [10207 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-07 02:00:02.850541 0.578615 ... False\n", " 2017-01-07 02:00:07.350580 0.330658 ... False\n", " 2017-01-07 02:00:11.850611 0.615171 ... False\n", " 2017-01-07 02:00:16.350650 0.575314 ... False\n", " 2017-01-07 02:00:20.850681 0.645435 ... False\n", " ... ... ... ...\n", " 2017-01-07 14:57:25.736533 0.171769 ... False\n", " 2017-01-07 14:57:30.236577 0.024648 ... False\n", " 2017-01-07 14:57:34.736611 0.036262 ... False\n", " 2017-01-07 14:57:39.236655 0.129640 ... 
False\n", " 2017-01-07 14:57:43.736689 0.209871 ... False\n", " \n", " [10370 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-08 00:59:30.743854 -0.149322 ... False\n", " 2017-01-08 00:59:35.243893 0.051192 ... False\n", " 2017-01-08 00:59:39.743923 -0.109274 ... False\n", " 2017-01-08 00:59:44.243962 -0.134377 ... False\n", " 2017-01-08 00:59:48.743993 0.040231 ... False\n", " ... ... ... ...\n", " 2017-01-08 14:52:28.160111 -1.145365 ... False\n", " 2017-01-08 14:52:32.660144 -1.120129 ... False\n", " 2017-01-08 14:52:37.160187 -0.970623 ... False\n", " 2017-01-08 14:52:41.660221 -1.047322 ... False\n", " 2017-01-08 14:52:46.160263 -1.139481 ... False\n", " \n", " [11111 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-09 00:54:34.297495 0.150986 ... False\n", " 2017-01-09 00:54:38.797525 0.054092 ... False\n", " 2017-01-09 00:54:43.297564 0.041586 ... False\n", " 2017-01-09 00:54:47.797595 0.211846 ... False\n", " 2017-01-09 00:54:52.297634 0.138488 ... False\n", " ... ... ... ...\n", " 2017-01-09 14:47:27.214472 -1.425762 ... False\n", " 2017-01-09 14:47:31.714506 -1.444290 ... False\n", " 2017-01-09 14:47:36.214549 -1.437870 ... False\n", " 2017-01-09 14:47:40.714583 -1.402435 ... False\n", " 2017-01-09 14:47:45.214625 -1.435300 ... False\n", " \n", " [11110 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-10 00:49:30.129621 0.250380 ... False\n", " 2017-01-10 00:49:34.629653 0.356107 ... False\n", " 2017-01-10 00:49:39.129693 0.265523 ... False\n", " 2017-01-10 00:49:43.629726 0.210121 ... False\n", " 2017-01-10 00:49:48.129766 0.314735 ... False\n", " ... ... ... ...\n", " 2017-01-10 14:42:27.556663 -1.232397 ... False\n", " 2017-01-10 14:42:32.056708 -1.235161 ... False\n", " 2017-01-10 14:42:36.556742 -1.265290 ... False\n", " 2017-01-10 14:42:41.056784 -1.235246 ... 
False\n", " 2017-01-10 14:42:45.556820 -1.244504 ... False\n", " \n", " [11111 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-11 00:44:31.554608 -0.438260 ... False\n", " 2017-01-11 00:44:36.054647 -0.247389 ... False\n", " 2017-01-11 00:44:40.554679 -0.360649 ... False\n", " 2017-01-11 00:44:45.054719 -0.461376 ... False\n", " 2017-01-11 00:44:49.554749 -0.250608 ... False\n", " ... ... ... ...\n", " 2017-01-11 12:39:34.919214 -1.374367 ... False\n", " 2017-01-11 12:39:39.419257 -1.388974 ... False\n", " 2017-01-11 12:39:43.919291 -1.387084 ... False\n", " 2017-01-11 12:39:48.419334 -1.379114 ... False\n", " 2017-01-11 12:39:52.919369 -1.375000 ... False\n", " \n", " [9539 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-12 00:34:26.372263 0.282036 ... False\n", " 2017-01-12 00:34:30.872292 0.207331 ... False\n", " 2017-01-12 00:34:35.372328 0.135283 ... False\n", " 2017-01-12 00:34:39.872357 0.205012 ... False\n", " 2017-01-12 00:34:44.372392 0.200723 ... False\n", " ... ... ... ...\n", " 2017-01-12 14:27:37.628688 -1.174282 ... False\n", " 2017-01-12 14:27:42.128726 -1.222883 ... False\n", " 2017-01-12 14:27:46.628757 -1.191178 ... False\n", " 2017-01-12 14:27:51.128794 -1.202539 ... False\n", " 2017-01-12 14:27:55.628825 -1.169961 ... False\n", " \n", " [9328 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-13 00:29:31.532187 0.433031 ... False\n", " 2017-01-13 00:29:36.032218 0.775808 ... False\n", " 2017-01-13 00:29:40.532246 0.681540 ... False\n", " 2017-01-13 00:29:45.032280 0.447294 ... False\n", " 2017-01-13 00:29:49.532306 0.495920 ... False\n", " ... ... ... ...\n", " 2017-01-13 14:22:37.896918 -1.122267 ... False\n", " 2017-01-13 14:22:42.396955 -1.095192 ... False\n", " 2017-01-13 14:22:46.896974 -1.078990 ... False\n", " 2017-01-13 14:22:51.397022 -1.045212 ... 
False\n", " 2017-01-13 14:22:55.897051 -1.078421 ... False\n", " \n", " [11113 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-14 00:24:33.973502 0.516664 ... False\n", " 2017-01-14 00:24:38.473535 0.454663 ... False\n", " 2017-01-14 00:24:42.973562 0.536189 ... False\n", " 2017-01-14 00:24:47.473595 0.508872 ... False\n", " 2017-01-14 00:24:51.973623 0.595519 ... False\n", " ... ... ... ...\n", " 2017-01-14 14:17:35.850795 -1.048623 ... False\n", " 2017-01-14 14:17:40.350835 -1.060130 ... False\n", " 2017-01-14 14:17:44.850865 -1.046742 ... False\n", " 2017-01-14 14:17:49.350906 -1.101527 ... False\n", " 2017-01-14 14:17:53.850938 -1.069602 ... False\n", " \n", " [11112 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-15 00:19:33.423590 0.124151 ... False\n", " 2017-01-15 00:19:37.923618 0.203979 ... False\n", " 2017-01-15 00:19:42.423652 0.194321 ... False\n", " 2017-01-15 00:19:46.923680 0.273831 ... False\n", " 2017-01-15 00:19:51.423716 0.097556 ... False\n", " ... ... ... ...\n", " 2017-01-15 14:12:35.302495 -1.251178 ... False\n", " 2017-01-15 14:12:39.802527 -1.217158 ... False\n", " 2017-01-15 14:12:44.302565 -1.277525 ... False\n", " 2017-01-15 14:12:48.802596 -1.249092 ... False\n", " 2017-01-15 14:12:53.302636 -1.251487 ... False\n", " \n", " [11112 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-16 00:14:29.606218 -0.753529 ... False\n", " 2017-01-16 00:14:34.106254 -0.608870 ... False\n", " 2017-01-16 00:14:38.606282 -0.700074 ... False\n", " 2017-01-16 00:14:43.106318 -0.569566 ... False\n", " 2017-01-16 00:14:47.606346 -0.840524 ... False\n", " ... ... ... ...\n", " 2017-01-16 14:07:35.991104 -1.021327 ... False\n", " 2017-01-16 14:07:40.491143 -1.073761 ... False\n", " 2017-01-16 14:07:44.991175 -1.060275 ... False\n", " 2017-01-16 14:07:49.491213 -1.045543 ... 
False\n", " 2017-01-16 14:07:53.991245 -1.043030 ... False\n", " \n", " [11113 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-17 00:09:33.097461 0.618697 ... False\n", " 2017-01-17 00:09:37.597490 -0.010759 ... False\n", " 2017-01-17 00:09:42.097524 -0.202842 ... False\n", " 2017-01-17 00:09:46.597553 0.128834 ... False\n", " 2017-01-17 00:09:51.097587 0.501555 ... False\n", " ... ... ... ...\n", " 2017-01-17 13:57:33.480536 -1.056261 ... False\n", " 2017-01-17 13:57:37.980569 -1.028945 ... False\n", " 2017-01-17 13:57:42.480608 -1.056892 ... False\n", " 2017-01-17 13:57:46.980639 -1.055822 ... False\n", " 2017-01-17 13:57:51.480679 -1.044705 ... False\n", " \n", " [11045 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-17 23:59:30.257868 -0.460953 ... False\n", " 2017-01-17 23:59:34.757896 -0.477479 ... True\n", " 2017-01-17 23:59:39.257933 -0.540812 ... True\n", " 2017-01-17 23:59:43.757962 -0.410302 ... True\n", " 2017-01-17 23:59:48.257997 -0.714016 ... True\n", " ... ... ... ...\n", " 2017-01-18 13:47:35.148867 -0.351026 ... False\n", " 2017-01-18 13:47:39.648900 0.028805 ... False\n", " 2017-01-18 13:47:44.148942 -0.060379 ... False\n", " 2017-01-18 13:47:48.648976 -0.036331 ... False\n", " 2017-01-18 13:47:53.149017 0.230107 ... False\n", " \n", " [11046 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-18 23:54:30.109686 -1.077935 ... False\n", " 2017-01-18 23:54:34.609716 -0.473158 ... False\n", " 2017-01-18 23:54:39.109751 -0.668999 ... False\n", " 2017-01-18 23:54:43.609781 -0.616445 ... False\n", " 2017-01-18 23:54:48.109818 -1.026126 ... False\n", " ... ... ... ...\n", " 2017-01-19 13:42:35.005498 -0.561375 ... False\n", " 2017-01-19 13:42:39.505531 -0.582718 ... False\n", " 2017-01-19 13:42:44.005572 -0.569202 ... False\n", " 2017-01-19 13:42:48.505605 -0.630146 ... 
False\n", " 2017-01-19 13:42:53.005646 -0.643737 ... False\n", " \n", " [11046 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-19 23:54:32.495962 0.887217 ... False\n", " 2017-01-19 23:54:36.995991 1.040867 ... False\n", " 2017-01-19 23:54:41.496029 0.968065 ... False\n", " 2017-01-19 23:54:45.996058 0.874947 ... False\n", " 2017-01-19 23:54:50.496094 0.996150 ... False\n", " ... ... ... ...\n", " 2017-01-20 13:42:32.893654 -1.341355 ... False\n", " 2017-01-20 13:42:37.393695 -1.340042 ... False\n", " 2017-01-20 13:42:41.893728 -1.315527 ... False\n", " 2017-01-20 13:42:46.393768 -1.372444 ... False\n", " 2017-01-20 13:42:50.893801 -1.342465 ... False\n", " \n", " [11045 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-20 23:44:28.879445 0.258103 ... False\n", " 2017-01-20 23:44:33.379480 0.165344 ... False\n", " 2017-01-20 23:44:37.879512 0.238723 ... False\n", " 2017-01-20 23:44:42.379549 0.240764 ... False\n", " 2017-01-20 23:44:46.879579 0.325817 ... False\n", " ... ... ... ...\n", " 2017-01-21 13:32:33.783922 -1.281203 ... False\n", " 2017-01-21 13:32:38.283965 -1.240908 ... False\n", " 2017-01-21 13:32:42.783999 -1.245359 ... False\n", " 2017-01-21 13:32:47.284043 -1.222616 ... False\n", " 2017-01-21 13:32:51.784076 -1.244364 ... False\n", " \n", " [11046 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-21 23:39:32.397247 0.648462 ... False\n", " 2017-01-21 23:39:36.897277 0.125554 ... False\n", " 2017-01-21 23:39:41.397315 0.644362 ... False\n", " 2017-01-21 23:39:45.897345 0.794314 ... False\n", " 2017-01-21 23:39:50.397383 0.748639 ... False\n", " ... ... ... ...\n", " 2017-01-22 13:27:32.804047 -1.281105 ... False\n", " 2017-01-22 13:27:37.304090 -1.316070 ... False\n", " 2017-01-22 13:27:41.804122 -1.296095 ... False\n", " 2017-01-22 13:27:46.304166 -1.274427 ... 
False\n", " 2017-01-22 13:27:50.804198 -1.259403 ... False\n", " \n", " [11045 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-22 23:34:32.497506 0.371171 ... False\n", " 2017-01-22 23:34:36.997536 0.306008 ... False\n", " 2017-01-22 23:34:41.497574 0.222378 ... False\n", " 2017-01-22 23:34:45.997588 0.303669 ... False\n", " 2017-01-22 23:34:50.497642 0.300867 ... False\n", " ... ... ... ...\n", " 2017-01-23 13:22:37.407060 -1.084231 ... False\n", " 2017-01-23 13:22:41.907095 -1.132926 ... False\n", " 2017-01-23 13:22:46.407137 -1.138529 ... False\n", " 2017-01-23 13:22:50.907171 -1.112983 ... False\n", " 2017-01-23 13:22:55.407213 -1.114196 ... False\n", " \n", " [11046 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-23 23:30:17.579249 0.625622 ... False\n", " 2017-01-23 23:30:22.079286 0.446295 ... False\n", " 2017-01-23 23:30:26.579317 0.340665 ... False\n", " 2017-01-23 23:30:31.079355 0.494095 ... False\n", " 2017-01-23 23:30:35.579386 0.285995 ... False\n", " ... ... ... ...\n", " 2017-01-24 08:16:56.833894 0.304322 ... False\n", " 2017-01-24 08:17:01.333935 0.400155 ... False\n", " 2017-01-24 08:17:05.833968 0.147983 ... False\n", " 2017-01-24 08:17:10.334009 0.348916 ... False\n", " 2017-01-24 08:17:14.834041 0.285743 ... False\n", " \n", " [7026 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-24 23:25:19.951727 0.397294 ... False\n", " 2017-01-24 23:25:24.451768 0.536392 ... False\n", " 2017-01-24 23:25:28.951798 0.343806 ... False\n", " 2017-01-24 23:25:33.451838 0.284095 ... False\n", " 2017-01-24 23:25:37.951870 0.483337 ... False\n", " ... ... ... ...\n", " 2017-01-25 13:06:44.364210 -1.271923 ... False\n", " 2017-01-25 13:06:48.864244 -1.243242 ... False\n", " 2017-01-25 13:06:53.364289 -1.241685 ... False\n", " 2017-01-25 13:06:57.864325 -1.216165 ... 
False\n", " 2017-01-25 13:07:02.364367 -1.243583 ... False\n", " \n", " [10957 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-25 23:20:21.915845 0.629981 ... False\n", " 2017-01-25 23:20:26.415884 0.712099 ... False\n", " 2017-01-25 23:20:30.915915 0.690743 ... False\n", " 2017-01-25 23:20:35.415953 0.771954 ... False\n", " 2017-01-25 23:20:39.915985 0.598590 ... False\n", " ... ... ... ...\n", " 2017-01-26 13:01:41.834981 1.751733 ... False\n", " 2017-01-26 13:01:46.335025 1.385139 ... False\n", " 2017-01-26 13:01:50.835061 1.505051 ... False\n", " 2017-01-26 13:01:55.335105 2.003533 ... False\n", " 2017-01-26 13:01:59.835140 2.122904 ... False\n", " \n", " [10956 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-26 23:15:17.804388 1.620798 ... False\n", " 2017-01-26 23:15:22.304427 1.499084 ... False\n", " 2017-01-26 23:15:26.804459 1.292690 ... False\n", " 2017-01-26 23:15:31.304499 1.313568 ... False\n", " 2017-01-26 23:15:35.804530 1.579847 ... False\n", " ... ... ... ...\n", " 2017-01-27 12:56:46.723774 -1.481024 ... False\n", " 2017-01-27 12:56:51.223817 -1.474493 ... False\n", " 2017-01-27 12:56:55.723852 -1.458199 ... False\n", " 2017-01-27 12:57:00.223895 -1.479174 ... False\n", " 2017-01-27 12:57:04.723930 -1.497171 ... False\n", " \n", " [10957 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-27 23:05:17.911570 0.274722 ... False\n", " 2017-01-27 23:05:22.411609 0.108195 ... False\n", " 2017-01-27 23:05:26.911641 0.005769 ... False\n", " 2017-01-27 23:05:31.411680 -0.030007 ... False\n", " 2017-01-27 23:05:35.911711 -0.060532 ... False\n", " ... ... ... ...\n", " 2017-01-28 12:46:42.333037 -1.247610 ... False\n", " 2017-01-28 12:46:46.833072 -1.237887 ... False\n", " 2017-01-28 12:46:51.333117 -1.219926 ... False\n", " 2017-01-28 12:46:55.833153 -1.233068 ... 
False\n", " 2017-01-28 12:47:00.333198 -1.255480 ... False\n", " \n", " [10957 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-28 23:00:18.906511 0.274405 ... False\n", " 2017-01-28 23:00:23.406551 0.333638 ... False\n", " 2017-01-28 23:00:27.906583 0.324077 ... False\n", " 2017-01-28 23:00:32.406623 0.413629 ... False\n", " 2017-01-28 23:00:36.906655 0.240758 ... False\n", " ... ... ... ...\n", " 2017-01-29 12:41:43.330800 -1.292784 ... False\n", " 2017-01-29 12:41:47.830835 -1.283827 ... False\n", " 2017-01-29 12:41:52.330878 -1.331312 ... False\n", " 2017-01-29 12:41:56.830914 -1.298229 ... False\n", " 2017-01-29 12:42:01.330958 -1.317063 ... False\n", " \n", " [10957 rows x 124 columns],\n", " mms1_des_energyspectr_omni_fast_0 ... selected\n", " Epoch ... \n", " 2017-01-29 23:00:22.371527 0.205416 ... False\n", " 2017-01-29 23:00:26.871558 0.070299 ... False\n", " 2017-01-29 23:00:31.371597 0.251162 ... False\n", " 2017-01-29 23:00:35.871629 0.143420 ... False\n", " 2017-01-29 23:00:40.371669 0.299162 ... False\n", " ... ... ... ...\n", " 2017-01-30 12:41:37.792835 3.766927 ... False\n", " 2017-01-30 12:41:42.292879 3.941938 ... False\n", " 2017-01-30 12:41:46.792914 4.024080 ... False\n", " 2017-01-30 12:41:51.292959 4.471779 ... False\n", " 2017-01-30 12:41:55.792994 4.125512 ... False\n", " \n", " [10954 rows x 124 columns]]" ] }, "execution_count": 11, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "windows" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "btKBuEPkUrXW" }, "source": [ "Finally, we break up our data into individual sequences that will be fed to our neural network." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "pY_SGxXFU88m" }, "source": [ "We define a SEQ_LEN variable that will determine the length of our sequences. 
This variable will also be passed to our network so that it knows how long of a data sequence to expect while training. The choice of sequence length is largely arbitrary." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": {}, "colab_type": "code", "id": "YQzSPJXkU7J8" }, "outputs": [], "source": [ "SEQ_LEN = 250" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "gtXhi2H2j7x6" }, "source": [ "For each window, we assemble pairs of sequences, each SEQ_LEN data points long: an X_sequence containing individual data points from our data and a y_sequence containing the truth values for those data points (whether or not those data points were selected by a SITL). Any leftover data points at the end of a window that do not fill a complete sequence are dropped.\n", "\n", "We add those sequences to four collections: X_train and y_train containing X_sequences and y_sequences for our training data, and X_test and y_test containing X_sequences and y_sequences for our testing data. Each window is assigned at random to training with probability 0.6 and to testing otherwise, so roughly 60% of the sequences end up in the training set. The split is repeated until the training set is larger than the test set." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": {}, "colab_type": "code", "id": "w9ajG41MVJWc" }, "outputs": [], "source": [ "# Repeat the random split until the training set is larger than the test set.\n", "while True:\n", "    X_train, X_test, y_train, y_test = [], [], [], []\n", "\n", "    sequences = []\n", "    for i in range(len(windows)):\n", "        X_sequence = []\n", "        y_sequence = []\n", "\n", "        # Assign the whole window to training (p = 0.6) or testing.\n", "        if random.random() < 0.6:\n", "            for value in windows[i].values:\n", "                X_sequence.append(value[:-1])\n", "                y_sequence.append(value[-1])\n", "                if len(X_sequence) == SEQ_LEN:\n", "                    X_train.append(X_sequence.copy())\n", "\n", "                    y_train.append(y_sequence.copy())\n", "\n", "                    X_sequence = []\n", "                    y_sequence = []\n", "\n", "        else:\n", "            for value in windows[i].values:\n", "                X_sequence.append(value[:-1])\n", "                y_sequence.append(value[-1])\n", "                if len(X_sequence) == SEQ_LEN:\n", "                    X_test.append(X_sequence.copy())\n", "\n", "                    y_test.append(y_sequence.copy())\n", "\n", "                    X_sequence = []\n", "                    y_sequence = []\n", "\n", "    X_train = np.array(X_train)\n", "    X_test = np.array(X_test)\n", "    y_train = np.expand_dims(np.array(y_train), axis=2)\n", "    y_test = np.expand_dims(np.array(y_test), axis=2)\n", "\n", "    if len(X_train) > len(X_test):\n", "        break" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "NcGTxO4tYIEa" }, "source": [ "We can see how many sequences of data we have for training and testing, respectively:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 52 }, "colab_type": "code", "id": "w6Xc30NSYIgK", "outputId": "dc987b91-eaab-46f3-ddd0-bb3ee20bb321" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of sequences in training data: 753\n", "Number of sequences in test data: 519\n" ] } ], "source": [ "print(f\"Number of sequences in training data: {len(X_train)}\")\n", "print(f\"Number of sequences in test data: {len(X_test)}\")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "CRsSmLuDs029" }, "source": [ "## Define and Train LSTM\n", "\n", "Now that we have processed our data into our training and test sets, we can begin to build and train our LSTM." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "r8Hqo34gs4cx" }, "source": [ "First, we need to define a custom F1 score and weighted binary crossentropy functions."
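,
"\n",
"\n",
"To see why a dedicated metric is worth defining, consider this small sketch on made-up labels (not the MMS data): a model that never selects anything scores high accuracy on imbalanced data but an F1 score of zero.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical labels: 2 positives out of 100, and a model that predicts all negatives.\n",
"y_true = np.array([1] * 2 + [0] * 98)\n",
"y_pred = np.zeros_like(y_true)\n",
"\n",
"accuracy = np.mean(y_true == y_pred)\n",
"tp = np.sum((y_true == 1) & (y_pred == 1))\n",
"precision = tp / max(np.sum(y_pred == 1), 1)  # guard against division by zero\n",
"recall = tp / max(np.sum(y_true == 1), 1)\n",
"f1_score = 0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)\n",
"\n",
"print(accuracy, f1_score)  # 0.98 0.0\n",
"```"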
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "MPl9Nicf4K_d" }, "source": [ "An F1 score is a measure of a model's accuracy, calculated as the harmonic mean of the model's precision (the number of true positives predicted by the model divided by the total number of positives predicted by the model) and recall (the number of true positives predicted by the model divided by the number of actual positives in the data):\n", "\n", "![F1 = 2 * (precision * recall) / (precision + recall)](https://wikimedia.org/api/rest_v1/media/math/render/svg/1bf179c30b00db201ce1895d88fe2915d58e6bfd)\n", "\n", "We will evaluate our model using the F1 score since we want to strike a balance between the model's precision and recall. Remember, we cannot rely on plain accuracy (the number of true positives and true negatives divided by the number of data points in the data) because of the imbalance between our classes.\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": {}, "colab_type": "code", "id": "EKhoUVlFmfUx" }, "outputs": [], "source": [ "# (Credit: Paddy and Kev1n91 from https://stackoverflow.com/a/45305384/3988976)\n", "def f1(y_true, y_pred):\n", "    def recall(y_true, y_pred):\n", "        \"\"\"Recall metric.\n", "\n", "        Only computes a batch-wise average of recall.\n", "\n", "        Computes the recall, a metric for multi-label classification of\n", "        how many relevant items are selected.\n", "        \"\"\"\n", "        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))\n", "        possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))\n", "        recall = true_positives / (possible_positives + K.epsilon())\n", "        return recall\n", "\n", "    def precision(y_true, y_pred):\n", "        \"\"\"Precision metric.\n", "\n", "        Only computes a batch-wise average of precision.\n", "\n", "        Computes the precision, a metric for multi-label classification of\n", "        how many selected items are relevant.\n", "        \"\"\"\n", "        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))\n", "        predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))\n", "        precision = true_positives / (predicted_positives + K.epsilon())\n", "        return precision\n", "    precision = precision(y_true, y_pred)\n", "    recall = recall(y_true, y_pred)\n", "    return 2*((precision*recall)/(precision+recall+K.epsilon()))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "emVoV71O4MAD" }, "source": [ "Cross-entropy is a function used to determine the loss between a set of predictions and their truth values. The larger the difference between a prediction and its true value, the larger the loss will be. In general, many machine learning architectures (including our LSTM) are designed to minimize their given loss function. A perfect model will have a loss of 0.\n", "\n", "Binary cross-entropy is used when we only have two classes (in our case, selected or not selected) and weighted binary cross-entropy allows us to assign a weight to one of the classes. This weight can effectively increase or decrease the loss of that class. In our case, we have previously defined a variable *true_weight* to be the class weight for positive (selected) datapoints. The function below reads that weight and applies it to the positive class.\n", "\n", "This cross-entropy function will be passed in to our model as our loss function.\n", "\n", "(Because the loss function of a model needs to be differentiable to perform gradient descent, we cannot use our F1 score as our loss function.)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": {}, "colab_type": "code", "id": "yBxIZXbMpJG9" }, "outputs": [], "source": [ "# (Credit: tobigue from https://stackoverflow.com/questions/42158866/neural-network-for-multi-label-classification-with-large-number-of-classes-outpu)\n", "def weighted_binary_crossentropy(target, output):\n", "    \"\"\"\n", "    Weighted binary crossentropy between an output tensor\n", "    and a target tensor. The global true_weight is used as a multiplier\n", "    for the positive targets.\n", "\n", "    Combination of the following functions:\n", "    * keras.losses.binary_crossentropy\n", "    * keras.backend.tensorflow_backend.binary_crossentropy\n", "    * tf.nn.weighted_cross_entropy_with_logits\n", "    \"\"\"\n", "    # transform back to logits\n", "    _epsilon = tfb._to_tensor(tfb.epsilon(), output.dtype.base_dtype)\n", "    output = tf.clip_by_value(output, _epsilon, 1 - _epsilon)\n", "    output = tf.log(output / (1 - output))\n", "    # compute weighted loss\n", "    loss = tf.nn.weighted_cross_entropy_with_logits(targets=target,\n", "                                                    logits=output,\n", "                                                    pos_weight=true_weight)\n", "    return tf.reduce_mean(loss, axis=-1)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "V7ScqhL295sl" }, "source": [ "Before building our LSTM, we define several hyperparameters that will define how the model is trained:\n", "\n", "EPOCHS: The number of times the model trains through our entire dataset.\n", "\n", "BATCH_SIZE: The number of sequences that our model processes in each training step.\n", "\n", "LAYER_SIZE: The number of LSTM units in each layer of the model.\n", "\n", "Choices for these hyperparameters are largely arbitrary and can be altered to tune our LSTM." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": {}, "colab_type": "code", "id": "U7HBAE5wwTDU" }, "outputs": [], "source": [ "EPOCHS = 100\n", "BATCH_SIZE = 128\n", "LAYER_SIZE = 300" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "WMHwIiE9Aec9" }, "source": [ "We now define our LSTM.\n", "\n", "For this version of the model, we use two bidirectional LSTM layers, two dropout layers, and one time distributed dense layer.\n", "\n", "Internally, an LSTM layer uses a for loop to iterate over the timesteps of a sequence, while maintaining states that encode information from those timesteps.
Using these internal states, the LSTM learns the characteristics of our data (the X_sequences we defined earlier) and how those data relate to our expected output (the y_sequences we defined earlier). Normal (unidirectional) LSTMs only encode information from previously seen timesteps. Bidirectional LSTMs can encode information from both before and after a given timestep.\n", "\n", "With the addition of a sigmoid-activated dense layer, the model outputs a value between 0 and 1 for each timestep, which corresponds to the model's certainty about whether or not that timestep was selected by the SITL.\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 303 }, "colab_type": "code", "id": "3jU7BfnJBj0V", "outputId": "9605a3ef-47b3-4477-faf8-a60049a7719d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/python/ops/init_ops.py:97: calling GlorotUniform.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Call initializer instance with the dtype argument instead of passing it to the constructor\n", "WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/python/ops/init_ops.py:97: calling Orthogonal.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Call initializer instance with the dtype argument instead of passing it to the constructor\n", "WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/python/ops/init_ops.py:97: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Call initializer instance with the dtype argument instead of passing it to the constructor\n", "WARNING:tensorflow:From
/tensorflow-1.15.2/python3.6/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "If using Keras pass *_constraint arguments to layers.\n", "WARNING:tensorflow:From :20: calling weighted_cross_entropy_with_logits (from tensorflow.python.ops.nn_impl) with targets is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "targets is deprecated, use labels instead\n" ] } ], "source": [ "model_name = f\"{SEQ_LEN}-SEQ_LEN-{BATCH_SIZE}-BATCH_SIZE-{LAYER_SIZE}-LAYER_SIZE-{int(time.time())}\"\n", "\n", "model = Sequential()\n", "\n", "model.add(Bidirectional(LSTM(LAYER_SIZE, return_sequences=True), input_shape=(None, X_train.shape[2])))\n", "\n", "model.add(Dropout(0.4))\n", "\n", "model.add(Bidirectional(LSTM(LAYER_SIZE, return_sequences=True), input_shape=(None, X_train.shape[2])))\n", "\n", "model.add(Dropout(0.4))\n", "\n", "model.add(TimeDistributed(Dense(1, activation='sigmoid')))\n", "\n", "opt = tf.keras.optimizers.Adam()\n", "\n", "model.compile(loss=weighted_binary_crossentropy,\n", " optimizer=opt,\n", " metrics=['accuracy', f1, tf.keras.metrics.Precision()])" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 336 }, "colab_type": "code", "id": "asBMdOujBrbk", "outputId": "8cecf30c-864d-485b-ce81-05d1fb1b31ed" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "bidirectional (Bidirectional (None, None, 600) 1017600 \n", "_________________________________________________________________\n", "dropout (Dropout) (None, None, 600) 0 \n", 
"_________________________________________________________________\n", "bidirectional_1 (Bidirection (None, None, 600) 2162400 \n", "_________________________________________________________________\n", "dropout_1 (Dropout) (None, None, 600) 0 \n", "_________________________________________________________________\n", "time_distributed (TimeDistri (None, None, 1) 601 \n", "=================================================================\n", "Total params: 3,180,601\n", "Trainable params: 3,180,601\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "vIaRlNZJBkNE" }, "source": [ "We set our training process to save the best versions of our model according to the previously defined F1 score. Each epoch, if a version of the model is trained with a higher F1 score than the previous best, the model saved on disk will be overwritten with the current best model." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": {}, "colab_type": "code", "id": "rFaaHkcwCP7X" }, "outputs": [], "source": [ "filepath = \"mp-dl-unh\" \n", "checkpoint = ModelCheckpoint(filepath, monitor='val_f1', verbose=1, save_best_only=True, mode='max')" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "l8TWoDyKCRrE" }, "source": [ "The following will train the model and save the training history for later visualization." 
] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "colab_type": "code", "id": "UCHQs29JCZMU", "outputId": "ab439715-f690-438e-a0d1-d0745dc11466" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/python/ops/math_grad.py:1424: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.where in 2.0, which has the same broadcast rule as np.where\n", "Train on 753 samples, validate on 519 samples\n", "Epoch 1/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.8397 - acc: 0.6974 - f1: 0.3451 - precision: 0.1849\n", "Epoch 00001: val_f1 improved from -inf to 0.42757, saving model to mp-dl-unh\n", "753/753 [==============================] - 15s 19ms/sample - loss: 0.7737 - acc: 0.7232 - f1: 0.3541 - precision: 0.1909 - val_loss: 0.6396 - val_acc: 0.8299 - val_f1: 0.4276 - val_precision: 0.2014\n", "Epoch 2/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.6136 - acc: 0.8747 - f1: 0.5461 - precision: 0.3894\n", "Epoch 00002: val_f1 did not improve from 0.42757\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.5620 - acc: 0.8849 - f1: 0.5549 - precision: 0.3958 - val_loss: 0.5994 - val_acc: 0.8189 - val_f1: 0.2793 - val_precision: 0.1862\n", "Epoch 3/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.5949 - acc: 0.8367 - f1: 0.4756 - precision: 0.3271\n", "Epoch 00003: val_f1 did not improve from 0.42757\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.5476 - acc: 0.8467 - f1: 0.4810 - precision: 0.3289 - val_loss: 0.4256 - val_acc: 0.8731 - val_f1: 0.3077 - val_precision: 0.2567\n", "Epoch 4/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.4788 - acc: 0.8978 - f1: 0.5785 - precision: 0.4482\n", "Epoch 00004: val_f1 improved from 0.42757 to 0.50001, saving model to mp-dl-unh\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.4369 - acc: 0.9070 - f1: 0.5971 - precision: 0.4580 - val_loss: 0.3668 - val_acc: 0.9128 - val_f1: 0.5000 - val_precision: 0.3469\n", "Epoch 5/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.4156 - acc: 0.8689 - f1: 0.5636 - precision: 0.3923\n", "Epoch 00005: val_f1 did not improve from 0.50001\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.3772 - acc: 0.8815 - f1: 0.5819 - precision: 0.4019 - val_loss: 0.3872 - val_acc: 0.9183 - val_f1: 0.4034 - val_precision: 0.3632\n", "Epoch 6/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.3745 - acc: 0.9192 - f1: 0.6637 - precision: 0.5168\n", "Epoch 00006: val_f1 did not improve from 0.50001\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.3420 - acc: 0.9244 - f1: 0.6662 - precision: 0.5181 - val_loss: 0.4222 - val_acc: 0.8848 - val_f1: 0.3684 - val_precision: 0.2857\n", "Epoch 7/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.3198 - acc: 0.9050 - f1: 0.6347 - precision: 0.4735\n", "Epoch 00007: val_f1 improved from 0.50001 to 0.51366, saving model to mp-dl-unh\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.2950 - acc: 0.9119 - f1: 0.6405 - precision: 0.4771 - val_loss: 0.4070 - val_acc: 0.8908 - val_f1: 0.5137 - val_precision: 0.2971\n", "Epoch 8/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.2917 - acc: 0.9302 - f1: 0.6997 - precision: 0.5554\n", "Epoch 00008: val_f1 did not improve from 0.51366\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.2700 - acc: 0.9330 - f1: 0.6947 - precision: 0.5498 - val_loss: 0.4036 - val_acc: 0.9009 - val_f1: 0.4554 - val_precision: 0.3199\n", "Epoch 9/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.2607 - acc: 0.9331 - f1: 0.7108 - precision: 0.5656\n", "Epoch 00009: val_f1 did not improve from 0.51366\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.2403 - acc: 0.9365 - f1: 0.7086 - precision: 0.5627 - val_loss: 0.4023 - val_acc: 0.9023 - val_f1: 0.4948 - val_precision: 0.3243\n", "Epoch 10/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.2405 - acc: 0.9361 - f1: 0.7232 - precision: 0.5768\n", "Epoch 00010: val_f1 did not improve from 0.51366\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.2214 - acc: 0.9394 - f1: 0.7211 - precision: 0.5741 - val_loss: 0.4330 - val_acc: 0.9020 - val_f1: 0.4683 - val_precision: 0.3236\n", "Epoch 11/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.2338 - acc: 0.9385 - f1: 0.7312 - precision: 0.5870\n", "Epoch 00011: val_f1 did not improve from 0.51366\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.2133 - acc: 0.9426 - f1: 0.7347 - precision: 0.5887 - val_loss: 0.4408 - val_acc: 0.9032 - val_f1: 0.3857 - val_precision: 0.3265\n", "Epoch 12/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.2178 - acc: 0.9344 - f1: 0.7238 - precision: 0.5687\n", "Epoch 00012: val_f1 did not improve from 0.51366\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1993 - acc: 0.9396 - f1: 0.7310 - precision: 0.5738 - val_loss: 0.4235 - val_acc: 0.9129 - val_f1: 0.4649 - val_precision: 0.3482\n", "Epoch 13/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.2071 - acc: 0.9458 - f1: 0.7570 - precision: 0.6187\n", "Epoch 00013: val_f1 did not improve from 0.51366\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1917 - acc: 0.9483 - f1: 0.7531 - precision: 0.6145 - val_loss: 0.4799 - val_acc: 0.9062 - val_f1: 0.4858 - val_precision: 0.3319\n", "Epoch 14/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.2104 - acc: 0.9504 - f1: 0.7719 - precision: 0.6420\n", "Epoch 00014: val_f1 improved from 0.51366 to 0.52297, saving model to mp-dl-unh\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1932 - acc: 0.9527 - f1: 0.7684 - precision: 0.6379 - val_loss: 0.5000 - val_acc: 0.8869 - val_f1: 0.5230 - val_precision: 0.2922\n", "Epoch 15/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.2647 - acc: 0.9105 - f1: 0.6608 - precision: 0.4893\n", "Epoch 00015: val_f1 did not improve from 0.52297\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.2380 - acc: 0.9205 - f1: 0.6866 - precision: 0.5039 - val_loss: 0.4565 - val_acc: 0.9263 - val_f1: 0.4257 - val_precision: 0.3876\n", "Epoch 16/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.2233 - acc: 0.9420 - f1: 0.7482 - precision: 0.6011\n", "Epoch 00016: val_f1 did not improve from 0.52297\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.2056 - acc: 0.9446 - f1: 0.7432 - precision: 0.5968 - val_loss: 0.4489 - val_acc: 0.8952 - val_f1: 0.3653 - val_precision: 0.3088\n", "Epoch 17/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.2032 - acc: 0.9488 - f1: 0.7627 - precision: 0.6338\n", "Epoch 00017: val_f1 did not improve from 0.52297\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1852 - acc: 0.9525 - f1: 0.7681 - precision: 0.6367 - val_loss: 0.4300 - val_acc: 0.9151 - val_f1: 0.4169 - val_precision: 0.3557\n", "Epoch 18/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1870 - acc: 0.9412 - f1: 0.7454 - precision: 0.5952\n", "Epoch 00018: val_f1 improved from 0.52297 to 0.54961, saving model to mp-dl-unh\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1700 - acc: 0.9465 - f1: 0.7572 - precision: 0.6034 - val_loss: 0.4437 - val_acc: 0.9228 - val_f1: 0.5496 - val_precision: 0.3775\n", "Epoch 19/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1757 - acc: 0.9597 - f1: 0.8024 - precision: 0.6911\n", "Epoch 00019: val_f1 did not improve from 0.54961\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1617 - acc: 0.9615 - f1: 0.7995 - precision: 0.6864 - val_loss: 0.4601 - val_acc: 0.9069 - val_f1: 0.4623 - val_precision: 0.3355\n", "Epoch 20/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1771 - acc: 0.9395 - f1: 0.7388 - precision: 0.5875\n", "Epoch 00020: val_f1 improved from 0.54961 to 0.57562, saving model to mp-dl-unh\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1675 - acc: 0.9453 - f1: 0.7526 - precision: 0.5974 - val_loss: 0.4548 - val_acc: 0.9226 - val_f1: 0.5756 - val_precision: 0.3766\n", "Epoch 21/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.1746 - acc: 0.9539 - f1: 0.7852 - precision: 0.6552\n", "Epoch 00021: val_f1 did not improve from 0.57562\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1626 - acc: 0.9557 - f1: 0.7794 - precision: 0.6496 - val_loss: 0.4955 - val_acc: 0.9261 - val_f1: 0.4233 - val_precision: 0.3881\n", "Epoch 22/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1686 - acc: 0.9592 - f1: 0.8006 - precision: 0.6878\n", "Epoch 00022: val_f1 did not improve from 0.57562\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1551 - acc: 0.9614 - f1: 0.8001 - precision: 0.6854 - val_loss: 0.5072 - val_acc: 0.9091 - val_f1: 0.4278 - val_precision: 0.3412\n", "Epoch 23/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1871 - acc: 0.9423 - f1: 0.7492 - precision: 0.6002\n", "Epoch 00023: val_f1 did not improve from 0.57562\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1701 - acc: 0.9480 - f1: 0.7638 - precision: 0.6105 - val_loss: 0.4906 - val_acc: 0.9208 - val_f1: 0.4136 - val_precision: 0.3704\n", "Epoch 24/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1611 - acc: 0.9591 - f1: 0.8048 - precision: 0.6841\n", "Epoch 00024: val_f1 improved from 0.57562 to 0.58272, saving model to mp-dl-unh\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1511 - acc: 0.9617 - f1: 0.8058 - precision: 0.6854 - val_loss: 0.4396 - val_acc: 0.9231 - val_f1: 0.5827 - val_precision: 0.3796\n", "Epoch 25/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.1908 - acc: 0.9429 - f1: 0.7511 - precision: 0.6034\n", "Epoch 00025: val_f1 did not improve from 0.58272\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1759 - acc: 0.9468 - f1: 0.7537 - precision: 0.6052 - val_loss: 0.4258 - val_acc: 0.9310 - val_f1: 0.5641 - val_precision: 0.4088\n", "Epoch 26/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1779 - acc: 0.9626 - f1: 0.8137 - precision: 0.7120\n", "Epoch 00026: val_f1 did not improve from 0.58272\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.1625 - acc: 0.9649 - f1: 0.8147 - precision: 0.7107 - val_loss: 0.4802 - val_acc: 0.9150 - val_f1: 0.5226 - val_precision: 0.3566\n", "Epoch 27/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.2158 - acc: 0.9254 - f1: 0.7089 - precision: 0.5350\n", "Epoch 00027: val_f1 did not improve from 0.58272\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1950 - acc: 0.9332 - f1: 0.7270 - precision: 0.5473 - val_loss: 0.4961 - val_acc: 0.9092 - val_f1: 0.3794 - val_precision: 0.3318\n", "Epoch 28/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1749 - acc: 0.9619 - f1: 0.8177 - precision: 0.7044\n", "Epoch 00028: val_f1 did not improve from 0.58272\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1619 - acc: 0.9623 - f1: 0.8055 - precision: 0.6914 - val_loss: 0.4533 - val_acc: 0.9053 - val_f1: 0.4927 - val_precision: 0.3323\n", "Epoch 29/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1875 - acc: 0.9393 - f1: 0.7402 - precision: 0.5881\n", "Epoch 00029: val_f1 did not improve from 0.58272\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1698 - acc: 0.9455 - f1: 0.7568 - precision: 0.5993 - val_loss: 0.4593 - val_acc: 0.9332 - val_f1: 0.5228 - val_precision: 0.4165\n", "Epoch 30/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.1464 - acc: 0.9578 - f1: 0.8005 - precision: 0.6740\n", "Epoch 00030: val_f1 improved from 0.58272 to 0.58565, saving model to mp-dl-unh\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1355 - acc: 0.9602 - f1: 0.8003 - precision: 0.6732 - val_loss: 0.4969 - val_acc: 0.9232 - val_f1: 0.5856 - val_precision: 0.3801\n", "Epoch 31/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1404 - acc: 0.9606 - f1: 0.8147 - precision: 0.6887\n", "Epoch 00031: val_f1 did not improve from 0.58565\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.1298 - acc: 0.9633 - f1: 0.8167 - precision: 0.6905 - val_loss: 0.5545 - val_acc: 0.9055 - val_f1: 0.4497 - val_precision: 0.3303\n", "Epoch 32/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1644 - acc: 0.9551 - f1: 0.7862 - precision: 0.6627\n", "Epoch 00032: val_f1 did not improve from 0.58565\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.1520 - acc: 0.9591 - f1: 0.7960 - precision: 0.6699 - val_loss: 0.5339 - val_acc: 0.9174 - val_f1: 0.3990 - val_precision: 0.3617\n", "Epoch 33/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1696 - acc: 0.9425 - f1: 0.7557 - precision: 0.6000\n", "Epoch 00033: val_f1 did not improve from 0.58565\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1558 - acc: 0.9483 - f1: 0.7703 - precision: 0.6109 - val_loss: 0.5203 - val_acc: 0.9297 - val_f1: 0.4219 - val_precision: 0.3965\n", "Epoch 34/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.2050 - acc: 0.9600 - f1: 0.8113 - precision: 0.6996\n", "Epoch 00034: val_f1 did not improve from 0.58565\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1913 - acc: 0.9600 - f1: 0.7960 - precision: 0.6826 - val_loss: 0.5136 - val_acc: 0.8894 - val_f1: 0.4170 - val_precision: 0.2989\n", "Epoch 35/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.2066 - acc: 0.9286 - f1: 0.7089 - precision: 0.5464\n", "Epoch 00035: val_f1 did not improve from 0.58565\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.1891 - acc: 0.9360 - f1: 0.7271 - precision: 0.5584 - val_loss: 0.4396 - val_acc: 0.9269 - val_f1: 0.4287 - val_precision: 0.3925\n", "Epoch 36/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1601 - acc: 0.9583 - f1: 0.7980 - precision: 0.6796\n", "Epoch 00036: val_f1 did not improve from 0.58565\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.1476 - acc: 0.9607 - f1: 0.7982 - precision: 0.6783 - val_loss: 0.4820 - val_acc: 0.9197 - val_f1: 0.4863 - val_precision: 0.3710\n", "Epoch 37/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1486 - acc: 0.9507 - f1: 0.7752 - precision: 0.6365\n", "Epoch 00037: val_f1 did not improve from 0.58565\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1382 - acc: 0.9545 - f1: 0.7814 - precision: 0.6408 - val_loss: 0.4972 - val_acc: 0.9274 - val_f1: 0.4364 - val_precision: 0.3943\n", "Epoch 38/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1395 - acc: 0.9684 - f1: 0.8379 - precision: 0.7417\n", "Epoch 00038: val_f1 did not improve from 0.58565\n", "753/753 [==============================] - 11s 15ms/sample - loss: 0.1305 - acc: 0.9692 - f1: 0.8310 - precision: 0.7329 - val_loss: 0.5471 - val_acc: 0.9167 - val_f1: 0.5373 - val_precision: 0.3608\n", "Epoch 39/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.1557 - acc: 0.9579 - f1: 0.8039 - precision: 0.6752\n", "Epoch 00039: val_f1 did not improve from 0.58565\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.1427 - acc: 0.9614 - f1: 0.8108 - precision: 0.6808 - val_loss: 0.5140 - val_acc: 0.9252 - val_f1: 0.4214 - val_precision: 0.3871\n", "Epoch 40/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1403 - acc: 0.9549 - f1: 0.7917 - precision: 0.6567\n", "Epoch 00040: val_f1 improved from 0.58565 to 0.61811, saving model to mp-dl-unh\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1295 - acc: 0.9590 - f1: 0.8019 - precision: 0.6646 - val_loss: 0.5232 - val_acc: 0.9301 - val_f1: 0.6181 - val_precision: 0.4024\n", "Epoch 41/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1241 - acc: 0.9669 - f1: 0.8349 - precision: 0.7268\n", "Epoch 00041: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1149 - acc: 0.9692 - f1: 0.8372 - precision: 0.7278 - val_loss: 0.5646 - val_acc: 0.9202 - val_f1: 0.5297 - val_precision: 0.3688\n", "Epoch 42/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1113 - acc: 0.9679 - f1: 0.8384 - precision: 0.7304\n", "Epoch 00042: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1038 - acc: 0.9704 - f1: 0.8435 - precision: 0.7343 - val_loss: 0.5828 - val_acc: 0.9263 - val_f1: 0.5464 - val_precision: 0.3863\n", "Epoch 43/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.0992 - acc: 0.9745 - f1: 0.8704 - precision: 0.7740\n", "Epoch 00043: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.0928 - acc: 0.9761 - f1: 0.8715 - precision: 0.7748 - val_loss: 0.5917 - val_acc: 0.9313 - val_f1: 0.5702 - val_precision: 0.4053\n", "Epoch 44/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0955 - acc: 0.9770 - f1: 0.8823 - precision: 0.7930\n", "Epoch 00044: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0893 - acc: 0.9782 - f1: 0.8808 - precision: 0.7913 - val_loss: 0.6317 - val_acc: 0.9255 - val_f1: 0.4839 - val_precision: 0.3868\n", "Epoch 45/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1074 - acc: 0.9722 - f1: 0.8606 - precision: 0.7602\n", "Epoch 00045: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0990 - acc: 0.9746 - f1: 0.8660 - precision: 0.7649 - val_loss: 0.6379 - val_acc: 0.9302 - val_f1: 0.5052 - val_precision: 0.4005\n", "Epoch 46/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1209 - acc: 0.9640 - f1: 0.8311 - precision: 0.7056\n", "Epoch 00046: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.1117 - acc: 0.9664 - f1: 0.8319 - precision: 0.7069 - val_loss: 0.6452 - val_acc: 0.9338 - val_f1: 0.4438 - val_precision: 0.4144\n", "Epoch 47/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1087 - acc: 0.9763 - f1: 0.8758 - precision: 0.7941\n", "Epoch 00047: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1004 - acc: 0.9774 - f1: 0.8739 - precision: 0.7907 - val_loss: 0.6740 - val_acc: 0.9174 - val_f1: 0.5592 - val_precision: 0.3589\n", "Epoch 48/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.1176 - acc: 0.9667 - f1: 0.8328 - precision: 0.7232\n", "Epoch 00048: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1074 - acc: 0.9698 - f1: 0.8419 - precision: 0.7300 - val_loss: 0.7245 - val_acc: 0.9103 - val_f1: 0.4316 - val_precision: 0.3335\n", "Epoch 49/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0990 - acc: 0.9761 - f1: 0.8772 - precision: 0.7874\n", "Epoch 00049: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0921 - acc: 0.9774 - f1: 0.8755 - precision: 0.7854 - val_loss: 0.6579 - val_acc: 0.9159 - val_f1: 0.4782 - val_precision: 0.3560\n", "Epoch 50/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1273 - acc: 0.9606 - f1: 0.8043 - precision: 0.6868\n", "Epoch 00050: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.1173 - acc: 0.9646 - f1: 0.8183 - precision: 0.6969 - val_loss: 0.6570 - val_acc: 0.9305 - val_f1: 0.4314 - val_precision: 0.3995\n", "Epoch 51/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1187 - acc: 0.9750 - f1: 0.8755 - precision: 0.7838\n", "Epoch 00051: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1125 - acc: 0.9751 - f1: 0.8645 - precision: 0.7723 - val_loss: 0.5983 - val_acc: 0.9080 - val_f1: 0.5348 - val_precision: 0.3320\n", "Epoch 52/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1176 - acc: 0.9691 - f1: 0.8425 - precision: 0.7401\n", "Epoch 00052: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.1075 - acc: 0.9715 - f1: 0.8479 - precision: 0.7436 - val_loss: 0.5963 - val_acc: 0.9281 - val_f1: 0.5786 - val_precision: 0.3919\n", "Epoch 53/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.0958 - acc: 0.9731 - f1: 0.8657 - precision: 0.7643\n", "Epoch 00053: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.0890 - acc: 0.9754 - f1: 0.8711 - precision: 0.7693 - val_loss: 0.5994 - val_acc: 0.9329 - val_f1: 0.4421 - val_precision: 0.4104\n", "Epoch 54/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0904 - acc: 0.9796 - f1: 0.8945 - precision: 0.8134\n", "Epoch 00054: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0837 - acc: 0.9807 - f1: 0.8931 - precision: 0.8118 - val_loss: 0.6191 - val_acc: 0.9318 - val_f1: 0.5887 - val_precision: 0.4080\n", "Epoch 55/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0798 - acc: 0.9788 - f1: 0.8889 - precision: 0.8035\n", "Epoch 00055: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0743 - acc: 0.9801 - f1: 0.8897 - precision: 0.8038 - val_loss: 0.6557 - val_acc: 0.9358 - val_f1: 0.5383 - val_precision: 0.4236\n", "Epoch 56/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0775 - acc: 0.9822 - f1: 0.9072 - precision: 0.8325\n", "Epoch 00056: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.0720 - acc: 0.9832 - f1: 0.9067 - precision: 0.8317 - val_loss: 0.6965 - val_acc: 0.9321 - val_f1: 0.5790 - val_precision: 0.4076\n", "Epoch 57/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0784 - acc: 0.9786 - f1: 0.8890 - precision: 0.8014\n", "Epoch 00057: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0743 - acc: 0.9804 - f1: 0.8933 - precision: 0.8066 - val_loss: 0.7530 - val_acc: 0.9334 - val_f1: 0.4333 - val_precision: 0.4096\n", "Epoch 58/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.0754 - acc: 0.9807 - f1: 0.8987 - precision: 0.8184\n", "Epoch 00058: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.0704 - acc: 0.9816 - f1: 0.8968 - precision: 0.8163 - val_loss: 0.7244 - val_acc: 0.9340 - val_f1: 0.4407 - val_precision: 0.4142\n", "Epoch 59/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0727 - acc: 0.9844 - f1: 0.9139 - precision: 0.8513\n", "Epoch 00059: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.0678 - acc: 0.9848 - f1: 0.9102 - precision: 0.8465 - val_loss: 0.7185 - val_acc: 0.9304 - val_f1: 0.4268 - val_precision: 0.4012\n", "Epoch 60/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0800 - acc: 0.9764 - f1: 0.8778 - precision: 0.7848\n", "Epoch 00060: val_f1 did not improve from 0.61811\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.0737 - acc: 0.9788 - f1: 0.8865 - precision: 0.7926 - val_loss: 0.8552 - val_acc: 0.9380 - val_f1: 0.5040 - val_precision: 0.4282\n", "Epoch 61/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0829 - acc: 0.9833 - f1: 0.9133 - precision: 0.8460\n", "Epoch 00061: val_f1 improved from 0.61811 to 0.62276, saving model to mp-dl-unh\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.0767 - acc: 0.9838 - f1: 0.9090 - precision: 0.8411 - val_loss: 0.6703 - val_acc: 0.9328 - val_f1: 0.6228 - val_precision: 0.4134\n", "Epoch 62/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.1030 - acc: 0.9752 - f1: 0.8691 - precision: 0.7839\n", "Epoch 00062: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.0927 - acc: 0.9775 - f1: 0.8774 - precision: 0.7899 - val_loss: 0.6816 - val_acc: 0.9329 - val_f1: 0.4430 - val_precision: 0.4115\n", "Epoch 63/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1066 - acc: 0.9658 - f1: 0.8387 - precision: 0.7154\n", "Epoch 00063: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.0994 - acc: 0.9680 - f1: 0.8388 - precision: 0.7161 - val_loss: 0.7887 - val_acc: 0.9395 - val_f1: 0.5539 - val_precision: 0.4356\n", "Epoch 64/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1056 - acc: 0.9796 - f1: 0.8957 - precision: 0.8205\n", "Epoch 00064: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0962 - acc: 0.9807 - f1: 0.8945 - precision: 0.8184 - val_loss: 0.6480 - val_acc: 0.9326 - val_f1: 0.5222 - val_precision: 0.4126\n", "Epoch 65/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0979 - acc: 0.9704 - f1: 0.8489 - precision: 0.7450\n", "Epoch 00065: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0896 - acc: 0.9736 - f1: 0.8617 - precision: 0.7551 - val_loss: 0.6877 - val_acc: 0.9373 - val_f1: 0.5169 - val_precision: 0.4292\n", "Epoch 66/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0807 - acc: 0.9820 - f1: 0.9010 - precision: 0.8328\n", "Epoch 00066: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.0743 - acc: 0.9834 - f1: 0.9041 - precision: 0.8348 - val_loss: 0.7158 - val_acc: 0.9290 - val_f1: 0.5774 - val_precision: 0.3962\n", "Epoch 67/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.0812 - acc: 0.9760 - f1: 0.8776 - precision: 0.7820\n", "Epoch 00067: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.0742 - acc: 0.9783 - f1: 0.8851 - precision: 0.7887 - val_loss: 0.7923 - val_acc: 0.9326 - val_f1: 0.5752 - val_precision: 0.4064\n", "Epoch 68/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0761 - acc: 0.9855 - f1: 0.9219 - precision: 0.8645\n", "Epoch 00068: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.0700 - acc: 0.9862 - f1: 0.9206 - precision: 0.8623 - val_loss: 0.7038 - val_acc: 0.9285 - val_f1: 0.5257 - val_precision: 0.3945\n", "Epoch 69/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0764 - acc: 0.9787 - f1: 0.8876 - precision: 0.8028\n", "Epoch 00069: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0700 - acc: 0.9805 - f1: 0.8923 - precision: 0.8067 - val_loss: 0.7089 - val_acc: 0.9354 - val_f1: 0.5950 - val_precision: 0.4220\n", "Epoch 70/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0776 - acc: 0.9852 - f1: 0.9193 - precision: 0.8628\n", "Epoch 00070: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.0709 - acc: 0.9862 - f1: 0.9208 - precision: 0.8634 - val_loss: 0.7348 - val_acc: 0.9339 - val_f1: 0.5063 - val_precision: 0.4143\n", "Epoch 71/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0956 - acc: 0.9716 - f1: 0.8585 - precision: 0.7520\n", "Epoch 00071: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.0872 - acc: 0.9743 - f1: 0.8664 - precision: 0.7587 - val_loss: 0.7796 - val_acc: 0.9365 - val_f1: 0.4359 - val_precision: 0.4240\n", "Epoch 72/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.0798 - acc: 0.9852 - f1: 0.9184 - precision: 0.8621\n", "Epoch 00072: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0739 - acc: 0.9860 - f1: 0.9177 - precision: 0.8603 - val_loss: 0.6781 - val_acc: 0.9315 - val_f1: 0.5643 - val_precision: 0.4053\n", "Epoch 73/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0780 - acc: 0.9763 - f1: 0.8782 - precision: 0.7841\n", "Epoch 00073: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0716 - acc: 0.9783 - f1: 0.8835 - precision: 0.7886 - val_loss: 0.7094 - val_acc: 0.9371 - val_f1: 0.4491 - val_precision: 0.4293\n", "Epoch 74/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0702 - acc: 0.9856 - f1: 0.9217 - precision: 0.8623\n", "Epoch 00074: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0648 - acc: 0.9867 - f1: 0.9233 - precision: 0.8637 - val_loss: 0.7048 - val_acc: 0.9354 - val_f1: 0.5295 - val_precision: 0.4215\n", "Epoch 75/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0825 - acc: 0.9770 - f1: 0.8841 - precision: 0.7894\n", "Epoch 00075: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0766 - acc: 0.9787 - f1: 0.8862 - precision: 0.7919 - val_loss: 0.7872 - val_acc: 0.9431 - val_f1: 0.4503 - val_precision: 0.4560\n", "Epoch 76/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0949 - acc: 0.9846 - f1: 0.9168 - precision: 0.8601\n", "Epoch 00076: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0857 - acc: 0.9857 - f1: 0.9188 - precision: 0.8612 - val_loss: 0.7124 - val_acc: 0.9257 - val_f1: 0.5371 - val_precision: 0.3839\n", "Epoch 77/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.1075 - acc: 0.9680 - f1: 0.8385 - precision: 0.7302\n", "Epoch 00077: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.0967 - acc: 0.9715 - f1: 0.8521 - precision: 0.7404 - val_loss: 0.6554 - val_acc: 0.9293 - val_f1: 0.5662 - val_precision: 0.3928\n", "Epoch 78/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1066 - acc: 0.9816 - f1: 0.9014 - precision: 0.8366\n", "Epoch 00078: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.0962 - acc: 0.9827 - f1: 0.9015 - precision: 0.8352 - val_loss: 0.5857 - val_acc: 0.9272 - val_f1: 0.5520 - val_precision: 0.3899\n", "Epoch 79/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.1019 - acc: 0.9699 - f1: 0.8456 - precision: 0.7424\n", "Epoch 00079: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0923 - acc: 0.9731 - f1: 0.8578 - precision: 0.7519 - val_loss: 0.6039 - val_acc: 0.9381 - val_f1: 0.5612 - val_precision: 0.4322\n", "Epoch 80/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0848 - acc: 0.9832 - f1: 0.9094 - precision: 0.8446\n", "Epoch 00080: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.0779 - acc: 0.9842 - f1: 0.9101 - precision: 0.8446 - val_loss: 0.6343 - val_acc: 0.9332 - val_f1: 0.5017 - val_precision: 0.4113\n", "Epoch 81/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0787 - acc: 0.9779 - f1: 0.8874 - precision: 0.7957\n", "Epoch 00081: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0718 - acc: 0.9799 - f1: 0.8936 - precision: 0.8015 - val_loss: 0.7117 - val_acc: 0.9346 - val_f1: 0.4091 - val_precision: 0.4157\n", "Epoch 82/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.0771 - acc: 0.9854 - f1: 0.9218 - precision: 0.8642\n", "Epoch 00082: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0702 - acc: 0.9862 - f1: 0.9215 - precision: 0.8631 - val_loss: 0.7005 - val_acc: 0.9317 - val_f1: 0.4330 - val_precision: 0.4045\n", "Epoch 83/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0696 - acc: 0.9811 - f1: 0.9012 - precision: 0.8209\n", "Epoch 00083: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0638 - acc: 0.9831 - f1: 0.9082 - precision: 0.8276 - val_loss: 0.7277 - val_acc: 0.9353 - val_f1: 0.4932 - val_precision: 0.4181\n", "Epoch 84/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0619 - acc: 0.9858 - f1: 0.9236 - precision: 0.8598\n", "Epoch 00084: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0568 - acc: 0.9868 - f1: 0.9251 - precision: 0.8612 - val_loss: 0.7442 - val_acc: 0.9352 - val_f1: 0.5954 - val_precision: 0.4190\n", "Epoch 85/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0591 - acc: 0.9845 - f1: 0.9155 - precision: 0.8484\n", "Epoch 00085: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.0544 - acc: 0.9858 - f1: 0.9188 - precision: 0.8514 - val_loss: 0.8300 - val_acc: 0.9384 - val_f1: 0.5778 - val_precision: 0.4317\n", "Epoch 86/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0568 - acc: 0.9879 - f1: 0.9340 - precision: 0.8792\n", "Epoch 00086: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0524 - acc: 0.9886 - f1: 0.9341 - precision: 0.8790 - val_loss: 0.8088 - val_acc: 0.9378 - val_f1: 0.5847 - val_precision: 0.4296\n", "Epoch 87/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.0531 - acc: 0.9865 - f1: 0.9272 - precision: 0.8653\n", "Epoch 00087: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0491 - acc: 0.9878 - f1: 0.9312 - precision: 0.8693 - val_loss: 0.8416 - val_acc: 0.9402 - val_f1: 0.5885 - val_precision: 0.4411\n", "Epoch 88/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0506 - acc: 0.9886 - f1: 0.9374 - precision: 0.8845\n", "Epoch 00088: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0467 - acc: 0.9893 - f1: 0.9378 - precision: 0.8847 - val_loss: 0.8522 - val_acc: 0.9400 - val_f1: 0.4461 - val_precision: 0.4407\n", "Epoch 89/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0477 - acc: 0.9893 - f1: 0.9402 - precision: 0.8904\n", "Epoch 00089: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0441 - acc: 0.9900 - f1: 0.9415 - precision: 0.8914 - val_loss: 0.8778 - val_acc: 0.9400 - val_f1: 0.5881 - val_precision: 0.4405\n", "Epoch 90/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0468 - acc: 0.9888 - f1: 0.9382 - precision: 0.8857\n", "Epoch 00090: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0432 - acc: 0.9898 - f1: 0.9409 - precision: 0.8883 - val_loss: 0.9254 - val_acc: 0.9417 - val_f1: 0.5656 - val_precision: 0.4487\n", "Epoch 91/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0440 - acc: 0.9908 - f1: 0.9482 - precision: 0.9043\n", "Epoch 00091: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0409 - acc: 0.9913 - f1: 0.9474 - precision: 0.9032 - val_loss: 0.9377 - val_acc: 0.9405 - val_f1: 0.5700 - val_precision: 0.4427\n", "Epoch 92/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.0436 - acc: 0.9906 - f1: 0.9463 - precision: 0.9025\n", "Epoch 00092: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0403 - acc: 0.9912 - f1: 0.9471 - precision: 0.9030 - val_loss: 0.9483 - val_acc: 0.9404 - val_f1: 0.4586 - val_precision: 0.4416\n", "Epoch 93/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0461 - acc: 0.9889 - f1: 0.9404 - precision: 0.8862\n", "Epoch 00093: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0426 - acc: 0.9898 - f1: 0.9423 - precision: 0.8885 - val_loss: 0.9636 - val_acc: 0.9421 - val_f1: 0.4537 - val_precision: 0.4496\n", "Epoch 94/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0441 - acc: 0.9914 - f1: 0.9523 - precision: 0.9107\n", "Epoch 00094: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0413 - acc: 0.9916 - f1: 0.9497 - precision: 0.9080 - val_loss: 0.8937 - val_acc: 0.9405 - val_f1: 0.5775 - val_precision: 0.4432\n", "Epoch 95/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0471 - acc: 0.9886 - f1: 0.9361 - precision: 0.8842\n", "Epoch 00095: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0435 - acc: 0.9896 - f1: 0.9392 - precision: 0.8874 - val_loss: 0.9624 - val_acc: 0.9423 - val_f1: 0.4566 - val_precision: 0.4509\n", "Epoch 96/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0444 - acc: 0.9905 - f1: 0.9482 - precision: 0.9015\n", "Epoch 00096: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0414 - acc: 0.9909 - f1: 0.9465 - precision: 0.8997 - val_loss: 0.9201 - val_acc: 0.9415 - val_f1: 0.4490 - val_precision: 0.4480\n", "Epoch 97/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.0445 - acc: 0.9914 - f1: 0.9511 - precision: 0.9121\n", "Epoch 00097: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 11s 14ms/sample - loss: 0.0410 - acc: 0.9919 - f1: 0.9505 - precision: 0.9109 - val_loss: 0.9170 - val_acc: 0.9384 - val_f1: 0.4426 - val_precision: 0.4335\n", "Epoch 98/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0584 - acc: 0.9850 - f1: 0.9210 - precision: 0.8519\n", "Epoch 00098: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0533 - acc: 0.9865 - f1: 0.9264 - precision: 0.8576 - val_loss: 1.0123 - val_acc: 0.9432 - val_f1: 0.5754 - val_precision: 0.4548\n", "Epoch 99/100\n", "640/753 [========================>.....] - ETA: 1s - loss: 0.0556 - acc: 0.9910 - f1: 0.9505 - precision: 0.9138\n", "Epoch 00099: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0513 - acc: 0.9912 - f1: 0.9472 - precision: 0.9093 - val_loss: 0.8208 - val_acc: 0.9389 - val_f1: 0.4494 - val_precision: 0.4374\n", "Epoch 100/100\n", "640/753 [========================>.....] 
- ETA: 1s - loss: 0.0715 - acc: 0.9814 - f1: 0.9007 - precision: 0.8240\n", "Epoch 00100: val_f1 did not improve from 0.62276\n", "753/753 [==============================] - 10s 14ms/sample - loss: 0.0645 - acc: 0.9832 - f1: 0.9072 - precision: 0.8296 - val_loss: 0.7622 - val_acc: 0.9396 - val_f1: 0.4736 - val_precision: 0.4397\n" ] } ], "source": [ "history = model.fit(\n", " x=X_train, y=y_train,\n", " batch_size=BATCH_SIZE,\n", " epochs=EPOCHS,\n", " validation_data=(X_test, y_test),\n", " callbacks=[checkpoint],\n", " verbose=1,\n", " shuffle=False\n", ")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "VzBPJKF0Q5sW" }, "source": [ "## Performance Visualization" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "mWCIq-enCh3U" }, "source": [ "To evaluate the training of our model over time, we visualize the model's loss on its training and testing data." ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 316 }, "colab_type": "code", "id": "ZhneNTcYChdt", "outputId": "84529fcd-1875-493f-f03d-fb36e69487a3" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "plt.plot(history.history['loss'])\n", "plt.plot(history.history['val_loss'])\n", "plt.title('Model Training Loss vs. Testing Loss by Epoch')\n", "plt.ylabel('Loss')\n", "plt.xlabel('Epoch')\n", "plt.legend(['train', 'testing'], loc='upper right')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 316 }, "colab_type": "code", "id": "DlTr3Cs4VNOB", "outputId": "c4f34a1a-1372-40f8-85e9-14c40913d08c" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "plt.plot(history.history['f1'])\n", "plt.plot(history.history['val_f1'])\n", "plt.title('Model Training F1 vs. Testing F1 by Epoch')\n", "plt.ylabel('F1')\n", "plt.xlabel('Epoch')\n", "plt.legend(['train', 'testing'], loc='upper right')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 316 }, "colab_type": "code", "id": "NQnqV-Kaw6fm", "outputId": "9840871d-7ab2-4892-e126-458ec21106e1" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "plt.plot(history.history['precision'])\n", "plt.plot(history.history['val_precision'])\n", "plt.title('Model Training Precision vs. Testing Precision by Epoch')\n", "plt.ylabel('Precision')\n", "plt.xlabel('Epoch')\n", "plt.legend(['train', 'testing'], loc='upper right')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "D8cXg6gpsZbN" }, "source": [ "(We can see that the model performs much better on its training data. This is expected, as the model learns to recreate the selections of the training data. We can also see that the performance of the model on the testing data decreases over time. This is evidence of the model overfitting. At some point, the model begins to naively recreate the selections of the training data rather than truly learning how to make selections. In practice, we effectively ignore this as we have already saved the version of the model with the best performance on the testing data - mitigating any overfitting.)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "colab": {}, "colab_type": "code", "id": "EVdN-Urh-mdX" }, "outputs": [], "source": [ "model = tf.keras.models.load_model('/content/mp-dl-unh', {'weighted_binary_crossentropy':weighted_binary_crossentropy, 'f1':f1})" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "K2Hd8FlY2Xty" }, "source": [ "## Model Performance Visualization\n", "\n", "Now that we have trained the model, we will visualize its selection-making ability compared to the SITLs." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "jwbEhnbi2Xt0" }, "source": [ "Since we've already preprocessed the testing/training data into a format suitable for model training, we reload that data to preprocess it into a format suitable for evaluation." 
] }, { "cell_type": "code", "execution_count": 26, "metadata": { "colab": {}, "colab_type": "code", "id": "loaf_Rpz2Xt0" }, "outputs": [], "source": [ "validation_data = pd.read_csv('training_data.csv', index_col=0, infer_datetime_format=True,\n", "\t\t\t\t\t\t parse_dates=[0])" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "oujploNp2Xt4" }, "source": [ "We apply the same preprocessing steps to this data as we did for the original training and testing data." ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "colab": {}, "colab_type": "code", "id": "UarjtN552Xt4" }, "outputs": [], "source": [ "index = validation_data.index\n", "selections = validation_data.pop(\"selected\")\n", "column_names = validation_data.columns" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "colab": {}, "colab_type": "code", "id": "PB666aem2Xt7" }, "outputs": [], "source": [ "validation_data = validation_data.replace([np.inf, -np.inf], np.nan)\n", "validation_data = validation_data.interpolate(method='time', limit_area='inside')" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "colab": {}, "colab_type": "code", "id": "DtkJHVze2Xt-" }, "outputs": [], "source": [ "validation_data = scaler.transform(validation_data)\n", "validation_data = pd.DataFrame(validation_data, index, column_names)\n", "validation_data = validation_data.join(selections)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "colab": {}, "colab_type": "code", "id": "k84Veyug2XuC" }, "outputs": [], "source": [ "validation_X = validation_data.values[:,:-1]\n", "validation_y = validation_data.values[:,-1]" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "zux-H0AMLRF0" }, "source": [ "Using the model we trained earlier, we make test predictions on this data."
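,
"\n",
"\n",
"As a hedged aside on the `np.expand_dims` call used for prediction: Keras recurrent layers such as LSTM expect 3-D input of shape `(batch, timesteps, features)`, so the 2-D feature table is wrapped in a batch holding a single sequence, and `squeeze()` later drops the singleton axes again. The sketch below uses hypothetical zero-filled arrays in place of the real features and predictions; the sizes and the one-value-per-timestep output layout are assumptions, not values taken from this notebook.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical stand-in for validation_X: 1000 datapoints with 12 features each\n",
"features = np.zeros((1000, 12))\n",
"\n",
"# Add a leading batch axis so the whole series is treated as one sequence\n",
"batch = np.expand_dims(features, axis=0)\n",
"print(batch.shape)  # (1, 1000, 12)\n",
"\n",
"# Assumed output layout of one value per timestep: (1, 1000, 1)\n",
"fake_predictions = np.zeros((1, 1000, 1))\n",
"print(fake_predictions.squeeze().shape)  # (1000,)\n",
"```"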
] }, { "cell_type": "code", "execution_count": 31, "metadata": { "colab": {}, "colab_type": "code", "id": "QWHQ9x8x2XuH" }, "outputs": [], "source": [ "test_predictions = model.predict(np.expand_dims(validation_X, axis=0))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "jTUqi-6C2XuJ" }, "source": [ "We visualize the true SITL selections made over the validation data by plotting the ground truth values for each datapoint in the data (where a 1 denotes that an individual datapoint was selected and a 0 denotes that it wasn't).\n" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 372 }, "colab_type": "code", "id": "izyEw0C32XuJ", "outputId": "091dd4bf-6170-4765-b927-c438e4f4b742" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(28, 5))\n", "plt.plot(validation_y.astype(int))\n", "plt.title(\"Ground Truth (SITL) Selections by Datapoint\")\n", "plt.ylabel('Selected (1) or not (0)')\n", "plt.xlabel('Datapoint')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "C_f_rD7T2XuM" }, "source": [ "...and we do the same for the model's predictions." ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 372 }, "colab_type": "code", "id": "QRFwAeCF2XuN", "outputId": "f91a11ea-9ed9-4644-b9f8-bef79b7a92a5" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(28, 5))\n", "plt.plot(test_predictions.squeeze())\n", "plt.title(\"Model Predicted Selections by Datapoint\")\n", "plt.ylabel('Selection confidence (continous)')\n", "plt.xlabel('Datapoint')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "nzomETjp3tUa" }, "source": [ "From this plot, we can see the continuous nature of the model's predictions. As mentioned earlier, the model outputs a continuous value between 0 and 1 for each datapoint that(very roughly) corresponds to its confidence in the selection of a point (i.e. an outputted value of 0.95 for a datapoint roughly means that the model is 95% certain that that point should be selected)." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "KRuZkV0t2XuP" }, "source": [ "With this in mind, we filter the model's predictions so that only those predictions with a >= 50% probability of being a magnetopause crossing are kept. This choice of probability/certainty is known as the threshold. \n", "\n", "This choice of threshold is chosen to optimize between over-selecting datapoints (resulting in more false-positives) and under-selecting them (resulting in more false-negatives).\n", "\n", "As an example, consider an email server's spam-detection system. Such a system might have a fairly high threshold (>99%), as you don't want to accidentally send a user's non-spam email to their spam inbox. At the same time, it's okay if a handful of spam emails make it through their regular inbox.\n", "\n", "In our case, we can afford to over-select datapoints as we do not want to miss out on any potential magnetopause crossings." 
] }, { "cell_type": "code", "execution_count": 34, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 372 }, "colab_type": "code", "id": "Qq6_n8gP2XuP", "outputId": "ecab9483-0719-4a43-8b17-5a636d3b7914" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "t_output = [0 if x < 0.5 else 1 for x in test_predictions.squeeze()]\n", "plt.figure(figsize=(28, 5))\n", "plt.plot(t_output)\n", "plt.title(\"Filtered Model Predictions by Datapoint\")\n", "plt.ylabel('Selected (1) or not (0)')\n", "plt.xlabel('Datapoint')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "--HNA2LF-AE4" }, "source": [ "## Model Validation\n", "\n", "Although we have already validated our model on data it has not seen (the testing set), we need to make sure that its ability to select magnetopause crossings is transferable to another range of data." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "tfv571-4MdaP" }, "source": [ "We load a third set of data, the validation set, which serves as an independent check on the model. " ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "colab": {}, "colab_type": "code", "id": "1YUj-vy80I2l" }, "outputs": [], "source": [ "validation_data = pd.read_csv('validation_data.csv', index_col=0, infer_datetime_format=True,\n", "\t\t\t\t\t\t parse_dates=[0])" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "tsg-zH1Y0I1p" }, "source": [ "We apply the same preprocessing steps to the validation data as we did for the training and testing data." 
] }, { "cell_type": "code", "execution_count": 36, "metadata": { "colab": {}, "colab_type": "code", "id": "Z3AsLjh-xsYd" }, "outputs": [], "source": [ "index = validation_data.index\n", "selections = validation_data.pop(\"selected\")\n", "column_names = validation_data.columns" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "colab": {}, "colab_type": "code", "id": "FMXckJVX0QMs" }, "outputs": [], "source": [ "validation_data = validation_data.replace([np.inf, -np.inf], np.nan)\n", "validation_data = validation_data.interpolate(method='time', limit_area='inside')" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "dcgfN3Y20bs8" }, "source": [ "Note that we standardize the validation data with the scaler already fitted on the training/testing data, rather than fitting a new one, so that all datasets share the same scale." ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "colab": {}, "colab_type": "code", "id": "E2pTbLeH0Ryk" }, "outputs": [], "source": [ "validation_data = scaler.transform(validation_data)\n", "validation_data = pd.DataFrame(validation_data, index, column_names)\n", "validation_data = validation_data.join(selections)" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "colab": {}, "colab_type": "code", "id": "HmvIfkOQ0S1N" }, "outputs": [], "source": [ "validation_X = validation_data.values[:,:-1]\n", "validation_y = validation_data.values[:,-1]" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "8Iuoi7WQ53gk" }, "source": [ "Using the model we trained earlier, we make test predictions on our validation data."
] }, { "cell_type": "code", "execution_count": 40, "metadata": { "colab": {}, "colab_type": "code", "id": "AlOFmA_R0Wm1" }, "outputs": [], "source": [ "test_predictions = model.predict(np.expand_dims(validation_X, axis=0))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "YZZPb5EjM9EB" }, "source": [ "We visualize the true SITL selections made over the validation data in the same way we did above.\n" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 372 }, "colab_type": "code", "id": "42AfNry2FO7V", "outputId": "2f281410-11c5-40f5-9740-3a071b488e3e" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(28, 5))\n", "plt.plot(validation_y.astype(int))\n", "plt.title(\"Ground Truth (SITL) Selections by Datapoint\")\n", "plt.ylabel('Selected (1) or not (0)')\n", "plt.xlabel('Datapoints')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "oSdniM9HND_c" }, "source": [ "...and we do the same for the model's predictions." ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 372 }, "colab_type": "code", "id": "IQvn147CFVgq", "outputId": "7ffd419d-50f4-457a-8ba5-785d96c152bb" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(28, 5))\n", "plt.plot(test_predictions.squeeze())\n", "plt.title(\"Model Predicted Selections by Datapoint\")\n", "plt.ylabel('Selection confidence (continous)')\n", "plt.xlabel('Datapoints')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Fbr_DnYhNJOJ" }, "source": [ "Once again, we filter the model's predictions so that only those predictions with a >= 50% probability of being a magnetopause crossing are kept." ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 372 }, "colab_type": "code", "id": "kKbbKKOAFZ7y", "outputId": "e1a0b94c-2dce-4281-ed91-e755af68917f" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "t_output = [0 if x < 0.5 else 1 for x in test_predictions.squeeze()]\n", "plt.figure(figsize=(28, 5))\n", "plt.plot(t_output)\n", "plt.title(\"Filtered Model Predictions by Datapoint\")\n", "plt.ylabel('Selected (1) or not (0)')\n", "plt.xlabel('Datapoints')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "7Z42aU4b4ieu" }, "source": [ "We now plot a receiver operating characteristic (ROC) curve based on the model's performance over the evaluation data. \n", "\n", "An ROC curve will plot a model's true-positive vs. false positive rates of predictions for varying choices of thresholds. As the threshold approaches 1, the false positive rate and the true positive rates approach 0, as every prediction made is over the threshold and is thus considered a selection. As the threshold approaches 1, the false positive rate and the true positive rates approach 0, as no prediction made surpasses the threshold of 1.\n", "\n", "While we can use the plot to determine where we want to set our threshold (considering the importance of under-selecting or over-selecting points), it is more often used to get a sense of the performance of our model.\n", "\n", "To do so, we calculate the total area under the ROC curve. This area is equal to the probability that the model will output a higher prediction value for a randomly chosen datapoint whose ground truth was \"selected\" than for a randomly chosen datapoint whose ground truth value was \"not selected\"." ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 316 }, "colab_type": "code", "id": "62howqH4NP2h", "outputId": "0846195d-924d-47e5-8a84-d92e34ceb729" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "fpr, tpr, thresholds = roc_curve(validation_y.astype(int), test_predictions.squeeze())\n", "lw = 2\n", "plt.plot(fpr, tpr, color='darkorange',\n", " lw=lw, label='ROC curve')\n", "plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')\n", "plt.xlim([0.0, 1.0])\n", "plt.ylim([0.0, 1.0])\n", "plt.xlabel('False Positive Rate')\n", "plt.ylabel('True Positive Rate')\n", "plt.title('ROC curve - AUC = {:.2f}'.format(auc(fpr, tpr)))\n", "plt.legend(loc=\"lower right\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "J2X1ovuuNyE6" }, "source": [ "Finally, we generate a list of predicted selection windows. The following code groups contiguous selected datapoints into windows and list the start and dates of those windows.\n" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "colab": {}, "colab_type": "code", "id": "XC-65ktvIKUc" }, "outputs": [], "source": [ "predicts_df = pd.DataFrame()\n", "predicts_df.insert(0, \"time\", validation_data.index)\n", "predicts_df.insert(1, \"prediction\", t_output)\n", "predicts_df['group'] = (predicts_df.prediction != predicts_df.prediction.shift()).cumsum()\n", "predicts_df = predicts_df.loc[predicts_df['prediction'] == True]\n", "selections = pd.DataFrame({'BeginDate' : predicts_df.groupby('group').time.first(), \n", " 'EndDate' : predicts_df.groupby('group').time.last()})\n", "selections = selections.set_index('BeginDate')" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 436 }, "colab_type": "code", "id": "db_RET1gfmLf", "outputId": "6be953ad-02e1-4886-d1a9-20a1269b66a2" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EndDate
BeginDate
2017-02-03 22:27:28.7848412017-02-03 22:29:03.285435
2017-02-03 22:31:13.7862492017-02-03 22:33:55.787267
2017-02-03 23:35:52.8113782017-02-03 23:56:48.319683
2017-02-04 00:09:19.8246712017-02-04 00:16:04.827367
2017-02-04 00:38:52.8365032017-02-04 00:41:03.337379
......
2017-02-09 08:00:15.4353802017-02-09 08:00:24.435452
2017-02-09 08:00:51.4356662017-02-09 08:01:27.435952
2017-02-09 08:02:52.9366242017-02-09 08:07:00.438590
2017-02-09 08:15:01.9423992017-02-09 08:46:04.957144
2017-02-09 09:43:13.9842732017-02-09 09:46:58.986053
\n", "

97 rows × 1 columns

\n", "
" ], "text/plain": [ " EndDate\n", "BeginDate \n", "2017-02-03 22:27:28.784841 2017-02-03 22:29:03.285435\n", "2017-02-03 22:31:13.786249 2017-02-03 22:33:55.787267\n", "2017-02-03 23:35:52.811378 2017-02-03 23:56:48.319683\n", "2017-02-04 00:09:19.824671 2017-02-04 00:16:04.827367\n", "2017-02-04 00:38:52.836503 2017-02-04 00:41:03.337379\n", "... ...\n", "2017-02-09 08:00:15.435380 2017-02-09 08:00:24.435452\n", "2017-02-09 08:00:51.435666 2017-02-09 08:01:27.435952\n", "2017-02-09 08:02:52.936624 2017-02-09 08:07:00.438590\n", "2017-02-09 08:15:01.942399 2017-02-09 08:46:04.957144\n", "2017-02-09 09:43:13.984273 2017-02-09 09:46:58.986053\n", "\n", "[97 rows x 1 columns]" ] }, "execution_count": 46, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "selections" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "O5H3KGodtKKR" }, "source": [ "## Conclusion\n", "\n", "The above steps have walked you through the development of the GLS-MP model currently deployed at NASA to assist SITLs with data selection. \n", "\n", "Since being implemented into the near real-time data stream, the GLS-MP model has selected 78% of SITL-identified MP crossings in the outbound leg of its orbit, 44% more than the existing MP-crossing selection algorithm onboard MMS spacecraft (ABS). \n", "\n", "The model and its associated paper represent the first attempt to introduce machine learning into critical mission operations. \n", "\n", "Additionally, the nature of the model and its training make it easily adoptable for use in other phenomena-detection tasks, such as identifying reconnection jets or Kelvin-Helmholtz waves in the magntopause. By expanding GLS-MP into a hierarchy of machine learning models, MMS progresses toward full autonomy in its burst management system, thereby reducing operations costs and transferring information and resources back to answering fundamental science questions." 
] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "machine_shape": "hm", "name": "notebook.ipynb", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 4 }