{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Basics with string sequence" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Find this notebook at `EpyNN/epynnlive/dummy_string/train.ipynb`.\n", "* Regular python code at `EpyNN/epynnlive/dummy_string/train.py`.\n", "\n", "Run the notebook online with [Google Colab](https://colab.research.google.com/github/Synthaze/EpyNN/blob/main/epynnlive/dummy_string/train.ipynb).\n", "\n", "**Level: Beginner**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook we will review:\n", "\n", "* Handling sequential string data.\n", "* Training of Feed-Forward (FF) and recurrent networks (RNN, LSTM, GRU).\n", "* Differences between decisions and probabilities and related functions.\n", "\n", "**It is assumed that all *basics* notebooks were already reviewed:**\n", "\n", "* [Basics with Perceptron (P)](../dummy_boolean/train.ipynb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**This notebook does not enhance, extend or replace EpyNN's documentation.**\n", "\n", "**Relevant documentation pages for the current notebook:**\n", "\n", "* [Fully Connected (Dense)](https://epynn.net/Dense.html)\n", "* [Recurrent Neural Network (RNN)](https://epynn.net/RNN.html)\n", "* [Long Short-Term Memory (LSTM)](https://epynn.net/LSTM.html)\n", "* [Gated Recurrent Unit (GRU)](https://epynn.net/GRU.html)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Environment and data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Follow [this link](prepare_dataset.ipynb) for details about data preparation." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Briefly, these dummy string data consist of sequences of characters. 
Each sample is represented by one sequence and is associated with either a positive or a negative label.\n", "\n", "A sequence is positive when its first element is equal to its last element, and negative otherwise." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# EpyNN/epynnlive/dummy_string/train.ipynb\n", "# Install dependencies\n", "!pip3 install --upgrade-strategy only-if-needed epynn\n", "\n", "# Standard library imports\n", "import random\n", "\n", "# Related third party imports\n", "import numpy as np\n", "\n", "# Local application/library specific imports\n", "import epynn.initialize\n", "from epynn.commons.io import one_hot_decode_sequence\n", "from epynn.commons.maths import relu, softmax\n", "from epynn.commons.library import (\n", " configure_directory,\n", " read_model,\n", ")\n", "from epynn.network.models import EpyNN\n", "from epynn.embedding.models import Embedding\n", "from epynn.flatten.models import Flatten\n", "from epynn.rnn.models import RNN\n", "from epynn.gru.models import GRU\n", "from epynn.lstm.models import LSTM\n", "from epynn.dense.models import Dense\n", "from epynnlive.dummy_string.prepare_dataset import prepare_dataset\n", "from epynnlive.dummy_string.settings import se_hPars\n", "\n", "\n", "########################## CONFIGURE ##########################\n", "random.seed(1)\n", "\n", "np.set_printoptions(threshold=10)\n", "\n", "np.seterr(all='warn')\n", "\n", "configure_directory()\n", "\n", "\n", "############################ DATASET ##########################\n", "X_features, Y_label = prepare_dataset(N_SAMPLES=480)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check what we retrieved." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "480\n", "12\n", "['G', 'A', 'C', 'T', 'T', 'G', 'G', 'C', 'C', 'A', 'T', 'C']\n", "1\n" ] } ], "source": [ "print(len(X_features))\n", "print(len(X_features[0]))\n", "print(X_features[0])\n", "print(Y_label[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We retrieved a set of sample features describing ``480`` samples.\n", "\n", "Each sample is described by ``12`` string features.\n", "\n", "Here the label is ``1`` because the first and last elements differ, which makes this a negative sequence." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Feed-Forward (FF)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To compare Feed-Forward and recurrent networks, we are going to train a simple Perceptron first." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Embedding" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The principle of [One-hot encoding of string features](prepare_dataset.ipynb#One-hot-encoding-of-string-features) was detailed before.\n", "\n", "Briefly, we cannot do math on string data. The one-hot encoding process may be summarized as follows:\n", "\n", "* List all distinct elements found in the data; the size of this list is ``vocab_size``. This answers: how many distinct elements can we find in the data?\n", "* Each element is associated with one index in ``range(0, vocab_size)``. This provides an ``element_to_idx`` encoder.\n", "* For one sample and for each element in its list of features, a zero array of size ``vocab_size`` is initialized. This array is set to one at the index that the ``element_to_idx`` encoder assigns to the element.\n", "\n", "This is achieved during instantiation of the *embedding* layer by setting ``X_encode=True``." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "embedding = Embedding(X_data=X_features,\n", " Y_data=Y_label,\n", " X_encode=True,\n", " Y_encode=True,\n", " relative_size=(2, 1, 0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's inspect some properties." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'A': 0, 'C': 1, 'G': 2, 'T': 3}\n", "[[0. 0. 1. 0.]\n", " [1. 0. 0. 0.]\n", " [0. 1. 0. 0.]\n", " ...\n", " [1. 0. 0. 0.]\n", " [0. 0. 0. 1.]\n", " [0. 1. 0. 0.]]\n", "{0: 'A', 1: 'C', 2: 'G', 3: 'T'}\n", "['G', 'A', 'C', 'T', 'T', 'G', 'G', 'C', 'C', 'A', 'T', 'C']\n" ] } ], "source": [ "print(embedding.e2i) # element_to_idx\n", "print(embedding.dtrain.X[0])\n", "\n", "print(embedding.i2e) # idx_to_element\n", "print(one_hot_decode_sequence(embedding.dtrain.X[0], embedding.i2e))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Encoded sequences may be decoded as shown above." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Flatten-Dense - Perceptron" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's inspect the shape of the data." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(320, 12, 4)\n" ] } ], "source": [ "print(embedding.dtrain.X.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It contains 320 samples (m), each described by a sequence of 12 features (s) containing 4 elements (v).\n", "\n", "12 features is the length of the sequences and 4 elements is the size of the vocabulary. Remember one-hot encoding makes a zero array of this size and sets 1 at the index corresponding to the element being encoded.\n", "\n", "Still, the fully-connected or *dense* layer can only process bi-dimensional input arrays. 
That is why we need a *flatten* layer between the *embedding* and *dense* layers." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "tags": [] }, "outputs": [], "source": [ "name = 'Flatten_Dense-2-softmax'\n", "\n", "se_hPars['learning_rate'] = 0.001\n", "\n", "flatten = Flatten()\n", "\n", "dense = Dense(2, softmax)\n", "\n", "layers = [embedding, flatten, dense]\n", "\n", "model = EpyNN(layers=layers, name=name)" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "Initialize the model, most classically with an *MSE* or *Binary Cross Entropy* loss function." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m--- EpyNN Check OK! --- \u001b[0m\r" ] } ], "source": [ "model.initialize(loss='BCE', seed=1, se_hPars=se_hPars.copy(), end='\\r')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Train for 50 epochs." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[37mEpoch 49 - Batch 0/0 - Accuracy: 0.738 Cost: 0.55582 - TIME: 2.14s RATE: 2.34e+01e/s TTC: 0s \u001b[0m\n", "\n", "+-------+----------+----------+-------+--------+-------+------------------------------------+\n", "| \u001b[1m\u001b[37mepoch\u001b[0m | \u001b[1m\u001b[37mlrate\u001b[0m | \u001b[1m\u001b[32maccuracy\u001b[0m | | \u001b[1m\u001b[31mBCE\u001b[0m | | \u001b[37mExperiment\u001b[0m |\n", "| | \u001b[37mDense\u001b[0m | \u001b[1m\u001b[32mdtrain\u001b[0m | \u001b[1m\u001b[32mdval\u001b[0m | \u001b[1m\u001b[31mdtrain\u001b[0m | \u001b[1m\u001b[31mdval\u001b[0m | |\n", "+-------+----------+----------+-------+--------+-------+------------------------------------+\n", "| \u001b[1m\u001b[37m0\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.688\u001b[0m | \u001b[1m\u001b[32m0.694\u001b[0m | 
\u001b[1m\u001b[31m0.631\u001b[0m | \u001b[1m\u001b[31m0.640\u001b[0m | \u001b[37m1635014418_Flatten_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m5\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.716\u001b[0m | \u001b[1m\u001b[32m0.731\u001b[0m | \u001b[1m\u001b[31m0.609\u001b[0m | \u001b[1m\u001b[31m0.620\u001b[0m | \u001b[37m1635014418_Flatten_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m10\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.719\u001b[0m | \u001b[1m\u001b[32m0.750\u001b[0m | \u001b[1m\u001b[31m0.596\u001b[0m | \u001b[1m\u001b[31m0.609\u001b[0m | \u001b[37m1635014418_Flatten_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m15\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.728\u001b[0m | \u001b[1m\u001b[32m0.750\u001b[0m | \u001b[1m\u001b[31m0.585\u001b[0m | \u001b[1m\u001b[31m0.601\u001b[0m | \u001b[37m1635014418_Flatten_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m20\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.728\u001b[0m | \u001b[1m\u001b[32m0.750\u001b[0m | \u001b[1m\u001b[31m0.577\u001b[0m | \u001b[1m\u001b[31m0.595\u001b[0m | \u001b[37m1635014418_Flatten_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m25\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.756\u001b[0m | \u001b[1m\u001b[31m0.571\u001b[0m | \u001b[1m\u001b[31m0.590\u001b[0m | \u001b[37m1635014418_Flatten_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m30\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.756\u001b[0m | \u001b[1m\u001b[31m0.566\u001b[0m | \u001b[1m\u001b[31m0.587\u001b[0m | \u001b[37m1635014418_Flatten_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m35\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.562\u001b[0m | 
\u001b[1m\u001b[31m0.585\u001b[0m | \u001b[37m1635014418_Flatten_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m40\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.559\u001b[0m | \u001b[1m\u001b[31m0.583\u001b[0m | \u001b[37m1635014418_Flatten_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m45\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.756\u001b[0m | \u001b[1m\u001b[31m0.557\u001b[0m | \u001b[1m\u001b[31m0.582\u001b[0m | \u001b[37m1635014418_Flatten_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m49\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.741\u001b[0m | \u001b[1m\u001b[32m0.756\u001b[0m | \u001b[1m\u001b[31m0.555\u001b[0m | \u001b[1m\u001b[31m0.581\u001b[0m | \u001b[37m1635014418_Flatten_Dense-2-softmax\u001b[0m |\n", "+-------+----------+----------+-------+--------+-------+------------------------------------+\n" ] } ], "source": [ "model.train(epochs=50, init_logs=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot the results." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "tags": [] }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "model.plot(path=False)" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "Strictly speaking, the Perceptron seems to have converged in the right direction." ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "By reducing the learning rate, all other things being equal, we obtained greater accuracy, lower cost and smoother curves on the plot.\n", "\n", "You may have observed something possibly counter-intuitive: \n", "\n", "* The cost, which describes the mean difference between **output probabilities** and labels, is higher for the validation set than for the training set.\n", "* The accuracy, which describes the fraction of **output decisions** that match the labels, is also higher for the validation set than for the training set.\n", "\n", "While the *cost* says the error is higher when evaluating on the validation set, the *accuracy* says the opposite.\n", "\n", "That’s because accuracy compares **decisions** and labels, whereas the cost from the loss function compares **probabilities** and labels.\n", "\n", "For code, maths and pictures behind the *Flatten* and *Dense* layers, follow these links:\n", "\n", "* [Flatten - Adapter](https://epynn.net/Flatten.html)\n", "* [Fully Connected (Dense)](https://epynn.net/Dense.html)\n", "\n", "Let’s take a break to understand this difference and to clarify some semantics." ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "### Difference between accuracy and cost" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "The question is: Can we expect identical costs for the training and validation set **if** the accuracy is identical for the training and validation set?\n", "\n", "Let’s compare the output of the *dense* layer (A) with the set of sample labels (Y)." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0.24990897 0.75009103]\n", " [0.15080179 0.84919821]\n", " [0.28627148 0.71372852]\n", " ...\n", " [0.09204908 0.90795092]\n", " [0.15813641 0.84186359]\n", " [0.41724774 0.58275226]]\n", "(320, 2)\n", "[[0. 1.]\n", " [0. 1.]\n", " [1. 0.]\n", " ...\n", " [1. 0.]\n", " [0. 1.]\n", " [0. 1.]]\n", "(320, 2)\n" ] } ], "source": [ "# This is probability distributions for each sample. \n", "print(model.embedding.dtrain.A)\n", "print(model.embedding.dtrain.A.shape)\n", "\n", "# These are the labels we target\n", "print(model.embedding.dtrain.Y) \n", "print(model.embedding.dtrain.Y.shape) " ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "We have probabilities (A) versus binary values (Y).\n", "\n", "To compute the accuracy, one needs to convert probabilities to decisions, as well as to retrieve single-digit labels." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1 1 1 ... 1 1 1]\n", "[1 1 0 ... 0 1 1]\n" ] } ], "source": [ "print(np.argmax(model.embedding.dtrain.A, axis=1))\n", "\n", "# Equivalent to calling model.embedding.dtrain.y directly\n", "print(np.argmax(model.embedding.dtrain.Y, axis=1))" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "Then, accuracy is computed such as:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ True True False ... 
False True True]\n", "0.740625\n" ] } ], "source": [ "print(np.argmax(model.embedding.dtrain.A, axis=1) == model.embedding.dtrain.y)\n", "print((np.argmax(model.embedding.dtrain.A, axis=1) == model.embedding.dtrain.y).mean())" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "The cost is computed from **probabilities, not from decisions**. Beyond that, accuracy and cost are simply two different functions.\n", "\n", "To compute a cost, we first need to compute the loss, which provides one value per sample, computed from the probabilities in the array (A)." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(320,)\n" ] } ], "source": [ "# This is the loss, one value per sample. It compares \"true\" labels against probabilities\n", "\n", "loss = model.training_loss(model.embedding.dtrain.Y, model.embedding.dtrain.A)\n", "\n", "print(loss.shape)" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "The cost is an average of the loss. The loss is an array holding one value per sample, itself the average of the element-wise comparison between probabilities and labels for that sample. The cost is a scalar: the average of the loss over all samples." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(320,)\n", "()\n", "0.5554707080208504\n" ] } ], "source": [ "print(loss.shape) # Averaged for each sample\n", "print(loss.mean().shape) # Average of above - scalar\n", "\n", "cost = loss.mean()\n", "\n", "print(cost)" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "Note that what is fed back into the network during the backward propagation phase is not the loss. It is the **derivative** of the loss." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.2875607 0.16346266 1.2508147 ... 2.38543339 0.17213728 0.53999313]\n", "[[ 0.66658576 -0.66658576]\n", " [ 0.58879069 -0.58879069]\n", " [-1.74659385 1.74659385]\n", " ...\n", " [-5.43188494 5.43188494]\n", " [ 0.59392045 -0.59392045]\n", " [ 0.85799753 -0.85799753]]\n" ] } ], "source": [ "dloss = model.training_loss(model.embedding.dtrain.Y, model.embedding.dtrain.A, deriv=True)\n", "\n", "print(loss)\n", "print(dloss) # dloss is referred to as dA" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "The loss functions and their derivatives natively provided with EpyNN can be found in `EpyNN/epynn/commons/loss.py`.\n", "\n", "The metrics natively provided with EpyNN can be found in `EpyNN/epynn/commons/metrics.py`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Recurrent Architectures" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Herein, we are going to chain simple schemes based on recurrent architectures.\n", "\n", "The three most commonly cited recurrent layers are:\n", "\n", "* **Recurrent Neural Network (RNN)**: This is the simplest recurrent layer. It is composed of one or more recurrent units. Each cell performs a single activation which outputs the *hidden cell state* or simply *hidden state*.\n", "* **Long Short-Term Memory (LSTM)**: By contrast with the RNN cell, the LSTM cell requires four activations which correspond to three different gates: forget, input (two activations), and output. To compute the hidden cell state, it then requires a fifth activation. Note that in addition to the hidden cell state, there is another so-called cell *memory* state.\n", "* **Gated Recurrent Unit (GRU)**: Compared to the LSTM cell, the GRU cell has only two gates: reset and update. 
In practice, GRU trains faster than LSTM and is reported to perform better on small datasets or shorter sequences. Both GRU and LSTM, however, are state-of-the-art architectures for dealing with sequential data.\n", "\n", "See [here](epynnlive/dummy_time/train.html#Recurrent-Neural-Network-(RNN)) for more detailed practical descriptions, or follow the pages linked at the top of this notebook." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Embedding" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, we use the same setup as for Feed-Forward networks." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "embedding = Embedding(X_data=X_features,\n", " Y_data=Y_label,\n", " X_encode=True,\n", " Y_encode=True,\n", " relative_size=(2, 1, 0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now chain the simplest schemes to train binary classifiers based on recurrent layers." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### RNN-Dense" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The number of RNN units in the RNN layer is set to 1." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "tags": [] }, "outputs": [], "source": [ "name = 'RNN-1_Dense-2-softmax'\n", "\n", "se_hPars['learning_rate'] = 0.001\n", "\n", "rnn = RNN(1)\n", "\n", "dense = Dense(2, softmax)\n", "\n", "layers = [embedding, rnn, dense]\n", "\n", "model = EpyNN(layers=layers, name=name)" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "Initialize the model." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m--- EpyNN Check OK! --- \u001b[0m\r" ] } ], "source": [ "model.initialize(loss='BCE', seed=1, se_hPars=se_hPars.copy(), end='\\r')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Train for 10 epochs." 
] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[37mEpoch 9 - Batch 0/0 - Accuracy: 0.734 Cost: 0.58678 - TIME: 0.94s RATE: 1.06e+01e/s TTC: 0s \u001b[0m\n", "\n", "+-------+----------+----------+----------+-------+--------+-------+----------------------------------+\n", "| \u001b[1m\u001b[37mepoch\u001b[0m | \u001b[1m\u001b[37mlrate\u001b[0m | \u001b[1m\u001b[37mlrate\u001b[0m | \u001b[1m\u001b[32maccuracy\u001b[0m | | \u001b[1m\u001b[31mBCE\u001b[0m | | \u001b[37mExperiment\u001b[0m |\n", "| | \u001b[37mRNN\u001b[0m | \u001b[37mDense\u001b[0m | \u001b[1m\u001b[32mdtrain\u001b[0m | \u001b[1m\u001b[32mdval\u001b[0m | \u001b[1m\u001b[31mdtrain\u001b[0m | \u001b[1m\u001b[31mdval\u001b[0m | |\n", "+-------+----------+----------+----------+-------+--------+-------+----------------------------------+\n", "| \u001b[1m\u001b[37m0\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.544\u001b[0m | \u001b[1m\u001b[32m0.581\u001b[0m | \u001b[1m\u001b[31m0.676\u001b[0m | \u001b[1m\u001b[31m0.659\u001b[0m | \u001b[37m1635014422_RNN-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m1\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.648\u001b[0m | \u001b[1m\u001b[31m0.630\u001b[0m | \u001b[37m1635014422_RNN-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m2\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.629\u001b[0m | \u001b[1m\u001b[31m0.610\u001b[0m | \u001b[37m1635014422_RNN-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m3\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | 
\u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.615\u001b[0m | \u001b[1m\u001b[31m0.595\u001b[0m | \u001b[37m1635014422_RNN-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m4\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.605\u001b[0m | \u001b[1m\u001b[31m0.585\u001b[0m | \u001b[37m1635014422_RNN-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m5\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.598\u001b[0m | \u001b[1m\u001b[31m0.577\u001b[0m | \u001b[37m1635014422_RNN-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m6\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.593\u001b[0m | \u001b[1m\u001b[31m0.571\u001b[0m | \u001b[37m1635014422_RNN-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m7\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.589\u001b[0m | \u001b[1m\u001b[31m0.566\u001b[0m | \u001b[37m1635014422_RNN-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m8\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.587\u001b[0m | \u001b[1m\u001b[31m0.563\u001b[0m | \u001b[37m1635014422_RNN-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m9\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.585\u001b[0m | 
\u001b[1m\u001b[31m0.560\u001b[0m | \u001b[37m1635014422_RNN-1_Dense-2-softmax\u001b[0m |\n", "+-------+----------+----------+----------+-------+--------+-------+----------------------------------+\n" ] } ], "source": [ "model.train(epochs=10, init_logs=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You may already note that the results are virtually identical to those of the basic Perceptron, with slightly better metrics on the validation set." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "tags": [] }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "model.plot(path=False)" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "While the y-scale on the plot is a bit misleading when looking at the accuracy, there is no sign of overfitting here because the BCE cost for the validation set is not higher than for the training set at the end of training." ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "For code, maths and pictures behind the *RNN* layer, follow this link:\n", "\n", "* [Recurrent Neural Network (RNN)](https://epynn.net/RNN.html)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### LSTM-Dense" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's now proceed with an *LSTM* layer composed of 1 unit. Apart from the learning rate, raised to ``0.005``, all other things are equal." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "tags": [] }, "outputs": [], "source": [ "name = 'LSTM-1_Dense-2-softmax'\n", "\n", "se_hPars['learning_rate'] = 0.005\n", "\n", "lstm = LSTM(1)\n", "\n", "dense = Dense(2, softmax)\n", "\n", "layers = [embedding, lstm, dense]\n", "\n", "model = EpyNN(layers=layers, name=name)" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "Initialize the model." ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m--- EpyNN Check OK! --- \u001b[0m\r" ] } ], "source": [ "model.initialize(loss='BCE', seed=1, se_hPars=se_hPars.copy(), end='\\r')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Train for 10 epochs." 
] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[37mEpoch 9 - Batch 0/0 - Accuracy: 0.734 Cost: 0.57893 - TIME: 0.97s RATE: 1.03e+01e/s TTC: 0s \u001b[0m\n", "\n", "+-------+----------+----------+----------+-------+--------+-------+-----------------------------------+\n", "| \u001b[1m\u001b[37mepoch\u001b[0m | \u001b[1m\u001b[37mlrate\u001b[0m | \u001b[1m\u001b[37mlrate\u001b[0m | \u001b[1m\u001b[32maccuracy\u001b[0m | | \u001b[1m\u001b[31mBCE\u001b[0m | | \u001b[37mExperiment\u001b[0m |\n", "| | \u001b[37mLSTM\u001b[0m | \u001b[37mDense\u001b[0m | \u001b[1m\u001b[32mdtrain\u001b[0m | \u001b[1m\u001b[32mdval\u001b[0m | \u001b[1m\u001b[31mdtrain\u001b[0m | \u001b[1m\u001b[31mdval\u001b[0m | |\n", "+-------+----------+----------+----------+-------+--------+-------+-----------------------------------+\n", "| \u001b[1m\u001b[37m0\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.586\u001b[0m | \u001b[1m\u001b[31m0.563\u001b[0m | \u001b[37m1635014423_LSTM-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m1\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.580\u001b[0m | \u001b[1m\u001b[31m0.552\u001b[0m | \u001b[37m1635014423_LSTM-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m2\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.579\u001b[0m | \u001b[1m\u001b[31m0.550\u001b[0m | \u001b[37m1635014423_LSTM-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m3\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | 
\u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.579\u001b[0m | \u001b[1m\u001b[31m0.549\u001b[0m | \u001b[37m1635014423_LSTM-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m4\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.579\u001b[0m | \u001b[1m\u001b[31m0.549\u001b[0m | \u001b[37m1635014423_LSTM-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m5\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.579\u001b[0m | \u001b[1m\u001b[31m0.549\u001b[0m | \u001b[37m1635014423_LSTM-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m6\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.579\u001b[0m | \u001b[1m\u001b[31m0.549\u001b[0m | \u001b[37m1635014423_LSTM-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m7\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.579\u001b[0m | \u001b[1m\u001b[31m0.549\u001b[0m | \u001b[37m1635014423_LSTM-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m8\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.579\u001b[0m | \u001b[1m\u001b[31m0.549\u001b[0m | \u001b[37m1635014423_LSTM-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m9\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | 
\u001b[1m\u001b[31m0.579\u001b[0m | \u001b[1m\u001b[31m0.549\u001b[0m | \u001b[37m1635014423_LSTM-1_Dense-2-softmax\u001b[0m |\n", "+-------+----------+----------+----------+-------+--------+-------+-----------------------------------+\n" ] } ], "source": [ "model.train(epochs=10, init_logs=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The accuracy metrics are identical to those obtained with a simple *RNN*, which is much faster to compute. They are also not significantly better than those obtained with a simple Perceptron, itself much faster to compute than the *RNN*-based network." ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "tags": [] }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "model.plot(path=False)" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "In the logs above, the cost on the training dataset (0.579) stays higher than on the validation dataset (0.549). Overfitting would show the opposite pattern, with a training cost lower than the validation cost, so this run gives no sign of it.\n", "\n", "For code, maths and pictures behind the *LSTM* layer, follow this link:\n", "\n", "* [Long Short-Term Memory (LSTM)](https://epynn.net/LSTM.html)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### GRU-Dense" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's now proceed with a GRU layer, all other things being equal." ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "tags": [] }, "outputs": [], "source": [ "name = 'GRU-1_Dense-2-softmax'\n", "\n", "se_hPars['learning_rate'] = 0.005\n", "\n", "gru = GRU(1)\n", "\n", "flatten = Flatten()\n", "\n", "dense = Dense(2, softmax)\n", "\n", "layers = [embedding, gru, dense]\n", "\n", "model = EpyNN(layers=layers, name=name)" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "Initialize the network." ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m--- EpyNN Check OK! --- \u001b[0m\r" ] } ], "source": [ "model.initialize(loss='BCE', seed=1, se_hPars=se_hPars.copy(), end='\\r')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Train for 10 epochs." 
] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[37mEpoch 9 - Batch 0/0 - Accuracy: 0.734 Cost: 0.57872 - TIME: 1.12s RATE: 8.93e+00e/s TTC: 0s \u001b[0m\n", "\n", "+-------+----------+----------+----------+-------+--------+-------+----------------------------------+\n", "| \u001b[1m\u001b[37mepoch\u001b[0m | \u001b[1m\u001b[37mlrate\u001b[0m | \u001b[1m\u001b[37mlrate\u001b[0m | \u001b[1m\u001b[32maccuracy\u001b[0m | | \u001b[1m\u001b[31mBCE\u001b[0m | | \u001b[37mExperiment\u001b[0m |\n", "| | \u001b[37mGRU\u001b[0m | \u001b[37mDense\u001b[0m | \u001b[1m\u001b[32mdtrain\u001b[0m | \u001b[1m\u001b[32mdval\u001b[0m | \u001b[1m\u001b[31mdtrain\u001b[0m | \u001b[1m\u001b[31mdval\u001b[0m | |\n", "+-------+----------+----------+----------+-------+--------+-------+----------------------------------+\n", "| \u001b[1m\u001b[37m0\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.583\u001b[0m | \u001b[1m\u001b[31m0.559\u001b[0m | \u001b[37m1635014424_GRU-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m1\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.579\u001b[0m | \u001b[1m\u001b[31m0.551\u001b[0m | \u001b[37m1635014424_GRU-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m2\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.579\u001b[0m | \u001b[1m\u001b[31m0.550\u001b[0m | \u001b[37m1635014424_GRU-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m3\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | 
\u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.579\u001b[0m | \u001b[1m\u001b[31m0.549\u001b[0m | \u001b[37m1635014424_GRU-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m4\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.579\u001b[0m | \u001b[1m\u001b[31m0.549\u001b[0m | \u001b[37m1635014424_GRU-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m5\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.579\u001b[0m | \u001b[1m\u001b[31m0.549\u001b[0m | \u001b[37m1635014424_GRU-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m6\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.579\u001b[0m | \u001b[1m\u001b[31m0.549\u001b[0m | \u001b[37m1635014424_GRU-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m7\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.579\u001b[0m | \u001b[1m\u001b[31m0.549\u001b[0m | \u001b[37m1635014424_GRU-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m8\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.579\u001b[0m | \u001b[1m\u001b[31m0.549\u001b[0m | \u001b[37m1635014424_GRU-1_Dense-2-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m9\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.734\u001b[0m | \u001b[1m\u001b[32m0.762\u001b[0m | \u001b[1m\u001b[31m0.579\u001b[0m | 
\u001b[1m\u001b[31m0.549\u001b[0m | \u001b[37m1635014424_GRU-1_Dense-2-softmax\u001b[0m |\n", "+-------+----------+----------+----------+-------+--------+-------+----------------------------------+\n" ] } ], "source": [ "model.train(epochs=10, init_logs=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot the results." ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "tags": [] }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "model.plot(path=False)" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "Overall, on this dummy dataset made of string features, there is no significant difference in metrics or cost between the simple Perceptron and the recurrent RNN, GRU and LSTM networks. In this situation, one would favor the simple Perceptron because it computes faster. More generally, the best architecture is not the fanciest, but simply the one that suits your needs and resources." ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "For code, maths and pictures behind the *GRU* layer, follow this link:\n", "\n", "* [Gated Recurrent Unit (GRU)](https://epynn.net/GRU.html)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Write, read & Predict" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A trained model can be written to disk as follows:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[32mMake: /media/synthase/beta/EpyNN/epynnlive/dummy_string/models/1635014424_GRU-1_Dense-2-softmax.pickle\u001b[0m\n" ] } ], "source": [ "model.write()\n", "\n", "# model.write(path='/your/custom/path')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A model can be read from disk as follows:" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "model = read_model()\n", "\n", "# model = read_model(path='/your/custom/path')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can retrieve new features and predict on them." 
] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "X_features, _ = prepare_dataset(N_SAMPLES=10)\n", "\n", "dset = model.predict(X_features, X_encode=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Results can be extracted as follows, with the sample identifier, the predicted label and the output probabilities printed for each sample:" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 1 [0.27207978 0.72792022]\n", "1 1 [0.26511927 0.73488073]\n", "2 1 [0.26579407 0.73420593]\n", "3 1 [0.26865469 0.73134531]\n", "4 1 [0.26554721 0.73445279]\n", "5 1 [0.2699316 0.7300684]\n", "6 1 [0.26448517 0.73551483]\n", "7 1 [0.26460705 0.73539295]\n", "8 1 [0.26701639 0.73298361]\n", "9 1 [0.26556502 0.73443498]\n" ] } ], "source": [ "for n, pred, probs in zip(dset.ids, dset.P, dset.A):\n", "    print(n, pred, probs)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" } }, "nbformat": 4, "nbformat_minor": 4 }