{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Multiclass Classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Find this notebook at `EpyNN/epynnlive/captcha_mnist/train.ipynb`.\n", "* Regular Python code at `EpyNN/epynnlive/captcha_mnist/train.py`.\n", "\n", "Run the notebook online with [Google Colab](https://colab.research.google.com/github/Synthaze/EpyNN/blob/main/epynnlive/captcha_mnist/train.ipynb).\n", "\n", "**Level: Advanced**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook we will review:\n", "\n", "* Handling image data as numerical features for Neural Network regression.\n", "* Training Feed-Forward (FF) and Convolutional Neural Networks (CNN) for multiclass classification tasks.\n", "* The MNIST database as a benchmark for neural networks.\n", "\n", "**It is assumed that all *basics* notebooks were already reviewed:**\n", "\n", "* [Basics with Perceptron (P)](../dummy_boolean/train.ipynb)\n", "* [Basics with string sequence](../dummy_string/train.ipynb)\n", "* [Basics with numerical time-series](../dummy_time/train.ipynb)\n", "* [Basics with images](../dummy_image/train.ipynb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**This notebook does not enhance, extend or replace EpyNN's documentation.**\n", "\n", "**Relevant documentation pages for the current notebook:**\n", "\n", "* [Fully Connected (Dense)](https://epynn.net/Dense.html)\n", "* [Convolution (CNN)](https://epynn.net/Convolution.html)\n", "* [Pooling (CNN)](https://epynn.net/Pooling.html)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Please follow this [external link](http://yann.lecun.com/exdb/mnist/) for detailed contextual information about the MNIST database."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Environment and data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Follow [this link](prepare_dataset.ipynb) for details about data preparation." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Briefly, the MNIST database contains 70 000 images (60 000 for training and 10 000 for testing).\n", "\n", "Each image contains a single handwritten digit. In total, 500 different human writers contributed to the database.\n", "\n", "The MNIST database is curated in that images were size normalized and centered." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[32mMake: /media/synthase/beta/EpyNN/epynnlive/captcha_mnist/datasets\u001b[0m\n", "\u001b[1m\u001b[32mMake: /media/synthase/beta/EpyNN/epynnlive/captcha_mnist/models\u001b[0m\n", "\u001b[1m\u001b[32mMake: /media/synthase/beta/EpyNN/epynnlive/captcha_mnist/plots\u001b[0m\n", "\u001b[1m\u001b[32mMake: mnist_database.tar\u001b[0m\n" ] } ], "source": [ "# EpyNN/epynnlive/captcha_mnist/train.ipynb\n", "# Install dependencies\n", "!pip3 install --upgrade-strategy only-if-needed epynn\n", "\n", "# Standard library imports\n", "import random\n", "\n", "# Related third party imports\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "# Local application/library specific imports\n", "import epynn.initialize\n", "from epynn.commons.maths import relu, softmax\n", "from epynn.commons.library import (\n", " configure_directory,\n", " read_model,\n", ")\n", "from epynn.network.models import EpyNN\n", "from epynn.embedding.models import Embedding\n", "from epynn.convolution.models import Convolution\n", "from epynn.pooling.models import Pooling\n", "from epynn.flatten.models import Flatten\n", "from epynn.dropout.models import Dropout\n", "from epynn.dense.models import Dense\n", "from epynnlive.captcha_mnist.prepare_dataset import (\n", "
prepare_dataset,\n", " download_mnist,\n", ")\n", "from epynnlive.captcha_mnist.settings import se_hPars\n", "\n", "\n", "########################## CONFIGURE ##########################\n", "random.seed(1)\n", "np.random.seed(1)\n", "\n", "np.set_printoptions(threshold=10)\n", "\n", "np.seterr(all='warn')\n", "\n", "configure_directory()\n", "\n", "\n", "############################ DATASET ##########################\n", "download_mnist()\n", "\n", "X_features, Y_label = prepare_dataset(N_SAMPLES=750)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's inspect what we retrieved." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "750\n", "(28, 28, 1)\n", "{0.0, 3.0, 6.0, 11.0, 12.0, 13.0, 15.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 27.0, 28.0, 32.0, 39.0, 40.0, 42.0, 43.0, 44.0, 45.0, 46.0, 48.0, 55.0, 67.0, 82.0, 86.0, 87.0, 90.0, 91.0, 94.0, 97.0, 100.0, 101.0, 104.0, 107.0, 110.0, 112.0, 115.0, 116.0, 125.0, 127.0, 129.0, 143.0, 147.0, 150.0, 151.0, 155.0, 158.0, 162.0, 163.0, 166.0, 167.0, 169.0, 170.0, 172.0, 181.0, 183.0, 184.0, 186.0, 189.0, 202.0, 205.0, 206.0, 210.0, 211.0, 214.0, 215.0, 219.0, 220.0, 221.0, 223.0, 225.0, 227.0, 228.0, 230.0, 233.0, 234.0, 237.0, 238.0, 239.0, 240.0, 241.0, 242.0, 243.0, 247.0, 250.0, 252.0, 253.0, 255.0}\n" ] } ], "source": [ "print(len(X_features))\n", "print(X_features[0].shape)\n", "print(set(X_features[0].flatten().tolist())) # Get all unique tones in image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We retrieved sample features describing ``750`` samples.\n", "\n", "For each sample, we retrieved features as a three-dimensional array of shape ``(width, height, depth)``. Each image has dimensions of 28x28 and so contains 784 pixels.\n", "\n", "In the context, remember that the ``depth`` dimension represents the number of channels which encode the image. 
While the depth of any RGB image would be equal to 3, the depth of a grayscale image is equal to one. We namely have a regular grayscale image in that shades of gray range from 0 to 255. \n", "\n", "Let's recall how this looks." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUIAAAEYCAYAAAApuP8NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAdAElEQVR4nO3de7AV1Zk28OcR5TMR5CIGURETZSBookwMXhNRMBIqBrU0ijeMGPKZGIlkiMBMajIxkzGTlBVFo8GCEvKJShQDaH0iOWV0TNQSFRS5yKUgoAcoRATFiOg7f+x2sbrDPuyzL929ez2/qlPn7cve/R7OWS+91u5eTTODiEjI9ss6ARGRrKkQikjwVAhFJHgqhCISPBVCEQmeCqGIBK+QhZDkWpJDK9zXSB5b5XGqfq1Io6kdVK6QhTCPSP4fktNIbie5keS4rHMSSRvJb5H8K8mdJP+cdT6f2D/rBALyUwB9AfQBcBiAJ0kuNbPHM81KJF1bAfwGQH8AZ2ebyh6FPyMkOYjksyS3kWwleQfJjondhpNcQ3ILyV+R3M97/TUkl5F8m+R8kn2qTGUUgJvN7G0zWwbgHgBXV/leIu2Sl3ZgZn8ys1kA3qzl56m3whdCAB8BuBFADwCnAhgC4HuJfS4AcBKAfwYwAsA1AEByBIBJAC4EcCiA/wFw/94OQvIykq+U2dYNQC8Ai73ViwEcV9VPJNJ+mbeDPCt8ITSzF83sOTPbbWZrAfwOwJmJ3X5pZlvN7G8onbaPjNb/XwD/ZWbLzGw3gF8AOHFv/xua2Uwz+2KZNDpF39/x1r0DoHNVP5RIO+WkHeRW4QshyX8i+Wj0AcV2lH6JPRK7rffidQAOj+I+AG6LuhPbUBrfIIAj2pnGu9H3g711BwPY0c73EalKTtpBbhW+EAK4C8ByAH3N7GCUTvGZ2Ke3Fx+FPeMX6wF818y6el+fMrO/ticBM3sbQCuAE7zVJwB4rT3vI1KDzNtBnoVQCDsD2A7gXZL9AVy3l33Gk+xGsjeAsQAejNbfDWAiyeMAgGQXkhdXmccMAP8WHac/gO8AuLfK9xJpr1y0A5IdSB6I0hUr+5E8kOQB1bxXPYVQCP8FwGUodUPvwZ5frm8OgBcBLALwGICpAGBmjwD4JYAHou7EEgBf39tBSF5Osq0zvH8HsBqlLsdTAH6lS2ckRXlpB1cCeB+lM9SvRPE97f9x6ouamFVEQhfCGaGISJtUCEUkeCqEIhK8mgohyWEkV5BcRXJCvZISaSZqB82v6g9LSHYA8DqAcwBsAPACgJFmtrSN1+iTmexsMbNDs06iaNQOmouZJa+dBFDbGeEgAKvMbI2Z7QLwAEr3J0o+rcs6gYJSOyiAWgrhEYjfkrMBe7nlhuQYkgtJLqzhWCJ5pXZQAA2fj9DMpgCYAqhLIOFSO8i3Ws4I30D83sQjo3UiIVE7KIBaCuELAPqS/Gw0weOlAObWJy2RpqF2UABVd43NbDfJ6wHMB9ABwDQz02wqEhS1g2JI9V5jjY1k6kUzOynrJETtIEuNuHxGRKQQVAhFJHgqhCISPBVCEQmeCqGIBE+FUESC1/Bb7IrijDPOiC0fe+yxZfddtGjRXmMRySedEYpI8FQIRSR46hondO3a1cX33LPnKYOnnHJKbL/DDz/cxR9//H
Fs27p1e6b+u+SSS1z81ltvxfZbu3ZtLamK5MqAAQNiy/Pnz3ex316SOnTo0LCcKqUzQhEJngqhiARPXeOEsWPHuvj888+v6j369Onj4ueee67sfhdffLGL//jHP1Z1LJG8GDNmTGz5sMMOc3Fy+ChvdEYoIsFTIRSR4KkQikjwghwj9Mcupk2bFts2aNCgvb5m+fLlseXNmze7+IQTToht69KlS0V5zJgxw8Uvv/xybNuZZ55Z0XuIZGnkyJEuvuKKK8rut23bNhdv2bKlkSlVRWeEIhI8FUIRCV4wzyy59tprXXzVVVe5+NRTTy37Gr87nLyUZvXq1S4+99xzY9v8ywi++c1vln3//fbb8//Qjh07Yttmz57t4vHjx7s4eXdKO+iZJTnRzM8sGTFiRGzZ/ztt6xKZyZMnu3jcuHH1T6xCemaJiEgZKoQiEjwVQhEJXmEvnxk8eHBs+Xe/+52L/bG55LiGP5Hq1772NRe3NTbnz7KRXO7fv7+L586dG9uvb9++Lu7cuXNs26hRo1zc0tLi4vvuu69sHiKN5o+1F4nOCEUkePsshCSnkdxMcom3rjvJBSRXRt+7NTZNkWypHRTbPi+fIflVAO8CmGFmx0fr/hvAVjO7heQEAN3M7KZ9HqzBlw34k6rOmjUrtu2ss85ysd813rVrV2y/m27a82Pcfvvtdc2vX79+seWnn37axd27dy/7uqeeesrFQ4cOrfbwunymBs3UDurt6KOPdnGyXX35y192cXKYaeXKlS72h4iyVPXlM2b2NICtidUjAEyP4ukAzq8lOZG8Uzsotmo/LOlpZq1RvBFAz3I7khwDYEy57SJNTO2gIGr+1NjMrK1TfTObAmAKUP8ugd8VBoCHH37YxV/96lfLvs7/ZHj69OmxbfXuDvtWrFgRW/bvSHnooYdi2/zJXf2uyYknnhjbT48LzYcs20GjnXfeeS4eOHBgbJvfHU52je++++7GJlZH1X5qvIlkLwCIvm/ex/4iRaR2UBDVFsK5AD650G0UgDn1SUekqagdFEQll8/cD+BZAP1IbiA5GsAtAM4huRLA0GhZpLDUDoptn2OEZjayzKYhdc6l3ZLPGm5rXNC3ePFiFzdyTHBf/PE9//nHQPyhT/544eOPPx7bz59kVhonz+2gETp27OjiHj16VPSan//857HlO++8s645NZLuLBGR4KkQikjwmnrShR//+MdVvW7q1Kl1zqR2lU64esghhzQ4ExHgqKOOcvGkSZMqek3yWSQffvhhXXNqJJ0RikjwVAhFJHhN1zW++uqrXdyeR15ec801Lv7LX/5Sz5Qawp8YwvfMM8+knImE6Gc/+1lF+y1cuNDF8+bNa1Q6DaczQhEJngqhiARPhVBEgtd0Y4T+JTNtPUfVv3sE+Mc7MvImefV+uZ8tj5f+SPMbMGBAbPkrX/mKi8uNVwPAySef3LCc0qQzQhEJngqhiASv6brG/nM/2uoaJydc3bRpU8NyqpY/yeqDDz5Ydr+1a9e6ONnlF6mHMWPik2f7k3n47ayZJlttD50RikjwVAhFJHgqhCISvKYbI6xU8kFJefTEE0+4uFu38s8GX7dunYs1Rij14t+iesUVV1T0mkcffbRR6WRKZ4QiEjwVQhEJXmG7xslJW/1uaJr69+/v4mRO/iSryUuB5szZ80C0KVOmNCg7CZn/nJwuXbqU3c8fjvGHaYpEZ4QiEjwVQhEJHs0svYORdT3Yk08+GVuu9HGeo0ePdnFyMoaNGze2O49OnTrFlmfMmOHiESNGlH2dfzP7mjVrYtuOOeaYduexDy+a2Un1flNpv3q3g2p99NFHLm7rLq3Jkye7eNy4cQ3NqdHMjHtbrzNCEQmeCqGIBG+fhZBkb5JPklxK8jWSY6P13UkuILky+l7+imCRJqd2UGz7HCMk2QtALzN7iWRnAC8COB/A1QC2mtktJCcA6GZmN+3jveo6NnLuuefGliu96t0fm1u0aFFs2zvvvFPRe5
B7hho6duwY2zZo0KC9vib53NedO3e6+Kc//Wls2+9///uK8mgHjRHWIM/toFp+20+OEc6cOdPFV155ZWo5NVrVY4Rm1mpmL0XxDgDLABwBYASAT+a6mo7SH4VIIakdFFu7LqgmeTSAgQCeB9DTzFqjTRsB9CzzmjEAxuxtm0gzUjsonoovnyHZCcBTAP7TzGaT3GZmXb3tb5tZm+Mj9e4SdO3aNbZ8ww037DUG4lfO+13jti4baEul73Hbbbe5ONndTXkCBXWN6yCP7aA9rrvuOhffcccdLk7+Dft3RK1evbrxiaWkpstnSB4A4GEA95nZ7Gj1pmjc5JPxk831SFQkr9QOiquST40JYCqAZWZ2q7dpLoBRUTwKwJzka0WKQu2g2CoZIzwdwJUAXiW5KFo3CcAtAGaRHA1gHYBvNSRDkXxQOyiwpr7Fri1DhgyJLU+cONHPw8XV/vxtvYc/c8xdd93l4t27d1d1rDrRGGFOZDlGOG/ePBcPHz7cxRojFBEJnAqhiASvsBOztrS0tLksIvIJnRGKSPBUCEUkeCqEIhI8FUIRCZ4KoYgET4VQRIJX2DtL5B/ozpKcUDvIju4sEREpQ4VQRIKnQigiwVMhFJHgqRCKSPBUCEUkeCqEIhI8FUIRCZ4KoYgEL+2JWbeg9ICbHlGcpTzkAKSXR58UjiGV2QLgPYT197cvaeRRtg2keoudOyi5MOvbvfKQQ57ykHTl5feuPErUNRaR4KkQikjwsiqEUzI6ri8POQD5yUPSlZffu/JARmOEIiJ5oq6xiARPhVBEgpdqISQ5jOQKkqtITkjxuNNIbia5xFvXneQCkiuj791SyKM3ySdJLiX5GsmxWeUi2Qm5HeS1DaRWCEl2AHAngK8DGABgJMkBKR3+XgDDEusmAGgxs74AWqLlRtsN4EdmNgDAKQC+H/0bZJGLZEDtIJ9tIM0zwkEAVpnZGjPbBeABACPSOLCZPQ1ga2L1CADTo3g6gPNTyKPVzF6K4h0AlgE4IotcJDNBt4O8toE0C+ERANZ7yxuidVnpaWatUbwRQM80D07yaAADATyfdS6SKrWDSJ7agD4sAWCla4hSu46IZCcADwP4oZltzzIXkU+k+beXtzaQZiF8A0Bvb/nIaF1WNpHsBQDR981pHJTkASj9AdxnZrOzzEUyEXw7yGMbSLMQvgCgL8nPkuwI4FIAc1M8ftJcAKOieBSAOY0+IEkCmApgmZndmmUukpmg20Fu24CZpfYFYDiA1wGsBvCvKR73fgCtAD5EaUxmNIBDUPp0aiWAPwHonkIeZ6B0yv8KgEXR1/AsctFXdl8ht4O8tgHdYiciwdOHJSISPBVCEQmeCqGIBE+FUESCp0IoIsFTIRSR4KkQikjwVAhFJHgqhCISPBVCEQmeCqGIBE+FUESCV8hCSHItyaEV7mskj63yOFW/VqTR1A4qV8hCmEckfx09oWsHyeUkr8o6J5G0kfxvkutJbie5juSkrHMCVAjT9B6A8wB0QWniydtInpZtSiKpmwqgv5kdDOA0AJeTvDDjnIpfCEkOIvksyW0kW0neEc0M7BtOcg3JLSR/RXI/7/XXkFxG8m2S80n2qSYPM/t3M1tuZh+b2fMA/gfAqTX8aCIVy1E7WGFm73mrPgaQebe68IUQwEcAbgTQA6XCMwTA9xL7XADgJAD/jNJjBa8BAJIjAEwCcCGAQ1EqXvfv7SAkLyP5SiUJkfwUgC8DeK2dP4tItXLTDkhOIPkuSrNkHwRgZnU/Uh1lPW15g6YDXwtgaJltPwTwiLdsAIZ5y99D6UHTAPD/AYz2tu0HYCeAPt5rj60iv+kAHgdKM4TrS1+N+MpzOwBAlB7l+R8AOmf9b1X4M0KS/0TyUZIbSW4H8AuU/lf0+c+ZXQfg8Cjug9JY3jaS21B6ODZRw3NoSf4KwPEAvmXRX4RIo+WtHVjJywDeR6kYZqrwhRDAXQCWA+hrpQHaSSj9En3+4xWPAvBmFK8H8F
0z6+p9fcrM/lpNIiT/A8DXAXzNEs9yFWmw3LSDhP0BHFOH96lJCIWwM4DtAN4l2R/AdXvZZzzJbiR7AxgL4MFo/d0AJpI8DgBIdiF5cTVJkJwI4DKUuipvVfMeIjXIvB2Q3I/kd6NjkOQgAN9H6el1mQqhEP4LSgVoB4B7sOeX65sD4EWUHi34GEof8cPMHgHwSwAPRN2JJSid0f0DkpeTbOvDj1+g9L/sKpLvRl+5uIZKgpCXdnABSo8x3QHg/wGYHH1lSo/zFJHghXBGKCLSJhVCEQleTYWQ5DCSK0iuIjmhXkmJNBO1g+ZX9RghyQ4AXgdwDkpXiL8AYKSZLa1feiL5pnZQDPvX8NpBAFaZ2RoAIPkASrfllP0DIKlPZrKzxcwOzTqJAlI7aCJmlrx2EkBtXeMjEL8SfQNquNJcGm5d1gkUlNpBAdRyRlgRkmMAjGn0cUTyTO0g32ophG8gfkvOkdG6GDObAmAKoC6BFJLaQQHU0jV+AUBfkp+N5jW7FMDc+qQl0jTUDgqg6jNCM9tN8noA8wF0ADDNzDS/ngRF7aAYUr3FTl2CTL1oZidlnYSoHWSpEZ8ai4gUggqhiARPhVBEgqdCKCLBUyEUkeCpEIpI8FQIRSR4Db/XWETyqVOnTi4+/vjjY9suuugiF2/fvueBiwMHDozt16tXLxfffffdsW0zZsxw8ccff1xbsg2mM0IRCZ4KoYgET7fYhUO32OVEmu3gmGPiz06/+eabXTxs2DAXd+3aNbbf3//+dxfv3r3bxQcddFBsvw8++MDFBx54YGzbOeec4+KWlswfXQxAt9iJiJSlQigiwQuya3zccce52D99T+rXr5+LL7vssorf3/+0bPXq1WX3e/zxx128fPnyit+/Suoa50Sa7cD/GwPin96uWrXKxW+99VZsv2effdbF/t/mwQcfHNvP70LPmzcvts1/3QUXXNCetBtGXWMRkTJUCEUkeCqEIhK8wo4RXnvttbHl0047zcX+VfPJywHStGHDBhf/5Cc/iW3zxxnrRGOEOZFmOzjqqKNiy3/7298adqyXX345tuyPsR922GEu9u9USZvGCEVEylAhFJHgNXXXONl9POuss1zs3wweHbvd7//II4+4+O2332736wFg6NChLk52U3zbtm2LLZ9++ukurtOlNeoa50ReLiOrh5NPPtnFzzzzTGzb7NmzXXzppZe6OM2ak6SusYhIGSqEIhI8FUIRCV5TjxEmc/dvH/rCF74Q2+ZPPPmNb3zDxbfeemtsP3+s7s0333Txrl27qsrRv4TgD3/4Q2ybf6vfpk2byr5ux44dVR07QWOEOdHMY4TJy80WLlzo4m7dusW2+W1uy5YtjU2sQlWPEZKcRnIzySXeuu4kF5BcGX3v1tZ7iDQ7tYNiq6RrfC+AYYl1EwC0mFlfAC3RskiR3Qu1g8La5zNLzOxpkkcnVo8AMDiKpwP4M4Cb6plYJZYsWRJbHjBggIuTV9AvXbrUxbNmzWpoXt27d3fxeeed52K/K5y0ePHi2HKdusNSJ3luB43Wo0cPFyfbjj/x69lnnx3blpfucCWq/bCkp5m1RvFGAD3rlI9IM1E7KIian2JnZtbW4C/JMQDG1HockTxTO2hu1RbCTSR7mVkryV4ANpfb0cymAJgC1P/TsiFDhsSWP/OZz7j4/fffr+eh2nTIIYfElv07XvznQiTt3LnTxb/+9a/rn5g0Wi7aQT34kyIAwFVXXeVi/66QE088MbaffzXFhRdeGNvm7ztz5kwXb926tZZUG6LarvFcAKOieBSAOfVJR6SpqB0URCWXz9wP4FkA/UhuIDkawC0AziG5EsDQaFmksNQOiq2ST41Hltk0pMx6kcJROyi2pr6zJEsjR+5pF8lJVf27QnyLFi2KLft3tdx33331S27vdGdJTuSlHZx55pkuvvfee2Pb+vTpU9djvfrqqy4+4YQT6vre7aHZZ0REylAhFJHg1XwdYSj8rjAA3HXXXS
7u3Llz2dctWLDAxZdffnlsW/JZsiJpeu+991ycfN6IfwnYmjVrXDxnTuUfjPtt5je/+Y2Lk0NJN998c8Xv2Sg6IxSR4KkQikjw9Klxgn9F/ahRo1z8pS99Kbaf3x1OTpAwfvx4F/vPPcn4JnR9apwTzdAO6m3evHkuPuOMM2LbkvMYNpI+NRYRKUOFUESCp0IoIsEL8vKZQw891MW33XZbbJs/keqnP/1pF7/zzjux/fwJKv1LAwDg+eefr0eaIoUxdepUFyfHCPNAZ4QiEjwVQhEJXjBd429/+9su/sEPfuDitm4Af/311108adKk2Db/sphGOPLII13sPwOlLa+88kqj0hGpm/33j5cd/5koWV1ipjNCEQmeCqGIBE+FUESCV6gxQv8hSslLWi666CIXd+zY0cXbt2+P7efPFvOd73zHxcnLZw466CAX9+7du6L8Lrnkktiy/xxmMn7nz8CBA138uc99rux7+mMqPXvqaZKST/444O7du2Pb8vD8Y50RikjwVAhFJHhNPftMctaK+fPnuzg5W0w5N9xwQ2z52GOPdbF/Op/kd0OTz1euxkcffRRbXr9+vYv939ETTzwR2++3v/2ti5csWdLWITT7TE6EOPvM5s17Hvl8wAEHxLZp9hkRkRxQIRSR4DX1p8bXX399bLnS7rDv9ttvrzmP5LNHPvjgg4pe99BDD7n4ueeei2178MEHa85LwuTfuXHLLfFnzk+cONHFH374YV2P26FDh9jy5MmTXewPM+XhGSVJOiMUkeDtsxCS7E3ySZJLSb5Gcmy0vjvJBSRXRt/TG/EUSZnaQbFVcka4G8CPzGwAgFMAfJ/kAAATALSYWV8ALdGySFGpHRTYPscIzawVQGsU7yC5DMARAEYAGBztNh3AnwHc1JAsy9i5c2dsua3nBB944IEu9u8Kacu2bdtcfMcdd8S2tba2uth/MA0AvPHGGxW9vzSPPLeDJH/i03HjxsW2ff7zn3fxjTfe6GJ/pqX28O96mjJlSmzb2Wef7eJXX33Vxf7YYV6068MSkkcDGAjgeQA9oz8OANgIYK/3d5EcA2BMDTmK5IraQfFU/GEJyU4AHgbwQzOL3aBrpSt+93qRqJlNMbOTdDGvFIHaQTFVdGcJyQMAPApgvpndGq1bAWCwmbWS7AXgz2bWbx/vk9kV9V/84hddPHjw4IpeM3PmTBfn4cbwGunOkho1Szvo1KmTi5cuXRrb5k8QsnbtWhf7l9UA8b93v6udnGDk4osv3utxAWDx4sUuHjZsmIs3bdrUZv6NVPWdJSxNizIVwLJPfvmRuQA+eQL6KABzak1SJK/UDoqtkjHC0wFcCeBVkouidZMA3AJgFsnRANYB+FZDMhTJB7WDAqvkU+NnAOz1dBJA7bMNiDQBtYNia+rZZ6RdNEaYE2m2g+OPPz627I97J7dVIjmBsF8/WlpaYtvGjx/v4kWLFrX7WI2g2WdERMpQIRSR4KlrHA51jXMiy3bQv39/F48cOdLFyQmK/bu2XnrpJRc/8MADsf0ee+wxF+/YsSO2LTnZcB6oaywiUoYKoYgET13jcKhrnBNqB9lR11hEpAwVQhEJngqhiARPhVBEgqdCKCLBUyEUkeCpEIpI8FQIRSR4KoQiEjwVQhEJngqhiARPhVBEgqdCKCLBq+QpdvW0BaUnffWI4izlIQcgvTz6pHAMqcwWAO8hrL+/fUkjj7JtINVpuNxByYVZTwmVhxzylIekKy+/d+VRoq6xiARPhVBEgpdVIZyS0XF9ecgByE8ekq68/N6VBzIaIxQRyRN1jUUkeCqEIhK8VAshyWEkV5BcRXJCisedRnIzySXeuu4kF5BcGX3vlkIevUk+SXIpyddIjs0qF8lOyO0gr20gtUJIsgOAOwF8HcAAACNJDkjp8PcCGJZYNwFAi5n1BdASLTfabgA/MrMBAE4B8P3o3yCLXCQDagf5bANpnhEOArDKzNaY2S4ADwAYkcaBzexpAFsTq0
cAmB7F0wGcn0IerWb2UhTvALAMwBFZ5CKZCbod5LUNpFkIjwCw3lveEK3LSk8za43ijQB6pnlwkkcDGAjg+axzkVSpHUTy1Ab0YQkAK11DlNp1RCQ7AXgYwA/NbHuWuYh8Is2/vby1gTQL4RsAenvLR0brsrKJZC8AiL5vTuOgJA9A6Q/gPjObnWUukong20Ee20CahfAFAH1JfpZkRwCXApib4vGT5gIYFcWjAMxp9AFJEsBUAMvM7NYsc5HMBN0OctsGzCy1LwDDAbwOYDWAf03xuPcDaAXwIUpjMqMBHILSp1MrAfwJQPcU8jgDpVP+VwAsir6GZ5GLvrL7Crkd5LUN6BY7EQmePiwRkeCpEIpI8FQIRSR4KoQiEjwVQhEJngqhiARPhVBEgve/+L8hLEPzf8YAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2)\n", "\n", "ax0.imshow(X_features[Y_label.index(0)], cmap='gray')\n", "ax0.set_title('label: 0')\n", "\n", "ax1.imshow(X_features[Y_label.index(1)], cmap='gray')\n", "ax1.set_title('label: 1')\n", "\n", "ax2.imshow(X_features[Y_label.index(2)], cmap='gray')\n", "ax2.set_title('label: 2')\n", "\n", "ax3.imshow(X_features[Y_label.index(3)], cmap='gray')\n", "ax3.set_title('label: 3')\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "According to [comments from the authors](http://yann.lecun.com/exdb/mnist/) of the MNIST database: *\"The original black and white (bilevel) images from NIST were size normalized \\[...\\]. The resulting images contain grey levels as a result of the anti-aliasing technique used by the normalization algorithm\"*." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Feed-Forward (FF)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As explained in other examples of network training, we like to begin by a Feed-Forward network each time we start to investigate a new problem. This is because:\n", "\n", "* Feed-Forward (Dense-based) networks are computationally faster to train than other architectures.\n", "* Such networks may be less sensitive to training hyperparameters and thus they are easier to train.\n", "* There are no reasonable day/night results depending on the architecture: one Feed-Forward - even if inappropriate - architecture may result in metrics that will not be made two times better by using an appropriate architecture. This mostly depends on hyperparameters and data processing." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Regarding this last item, one may look at the above-linked page, which reports classifier error rates with respect to method or architecture.\n", "\n", "In summary:\n", "\n", "* Single-layer FF (Perceptron) - 12.0% error.\n", "* Single-layer FF (Perceptron) - 7.6% error with deskewing (rotating a scanned image to compensate for skewing).\n", "* Two-layer FF with 800 hidden nodes - 1.6% error.\n", "* Committee of 35 Convolutional Neural Networks (CNN) - 0.23% error (best of all reported).\n", "\n", "First, we can see that from the very basic perceptron to the most evolved CNN there is an improvement from 88% to 99.77% accuracy. This is a huge gap, but both architectures succeed at recognizing patterns and building a decent model.\n", "\n", "Note that the gap is significantly reduced by simply applying *deskewing* to the image data, which highlights the critical importance of data pre-processing in the ability of the network to produce an accurate model.\n", "\n", "Also note that perceptrons are only capable of linear regression, while multi-layer Feed-Forward networks, CNNs and other architectures can achieve non-linear regression.\n", "\n", "When considering the two-layer Feed-Forward with 800 hidden nodes versus the state-of-the-art CNN, the accuracy goes from 98.4% to 99.77%. Overall, this illustrates that classical Feed-Forward networks may be a good alternative to more appropriate architectures when balancing accuracy, difficulty of training and the need for more computational time."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Embedding" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We start from sample features containing values ranging from 0 to 255.\n", "\n", "The most important thing is to **normalize** sample features within \\[0, 1\\].\n", "\n", "The second thing is to set a batch size which, in this case, will make the regression converge faster.\n", "\n", "The last - but not least - thing is to one-hot encode the set of sample labels. While one may not do this with binary classification problems and thus use a single node in the output dense layer - from which 0 or 1 decisions can be extracted -, the same cannot reasonably be done with multiclass classification. Because we have ten distinct labels, the reasonable choice is to use 10 output nodes and to one-hot encode the set of sample labels." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "embedding = Embedding(X_data=X_features,\n", " Y_data=Y_label,\n", " X_scale=True,\n", " Y_encode=True,\n", " batch_size=32,\n", " relative_size=(2, 1, 0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "About data normalization:\n", "\n", "* It diminishes the complexity of the problem.\n", "* It accelerates training and makes it easier.\n", "* If no normalization is done in this example, the loss function and gradients will both readily explode and generate floating-point errors." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "About this last point, remember that exponentials are used in the context of neural networks. If you begin feeding exponential functions with large numbers such as 255, you can be sure you will get floating-point errors pretty quickly: the computer typically deals with 64-bit floating-point numbers, which have a maximal value of ``1.8 * 10 ** 308``. 
\n", "\n", "Note that with 32-bit (single-precision) floats, this maximal value is down to ``3.4 * 10 ** 38``. Compare with ``exp(255)``, which outputs ``5.56 * 10 ** 110``, and conclude that un-normalized data in the context of this notebook would overflow straight away in single precision. Keep in mind that this limit depends on the floating-point type in use and not on the CPU word size: NumPy defaults to 64-bit floats, including on devices with 32-bit CPUs such as some [Raspberry Pi](https://www.raspberrypi.org/) boards." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Flatten-(Dense)n with Dropout" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We already explained why we need to use a *flatten* layer when attempting to forward arrays of more than 2 dimensions to the *dense* layer: image data have three dimensions, and the array of sample features - the array of sample images - has four dimensions. Therefore, we need to flatten this array to give it a two-dimensional shape."
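] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a minimal sketch of this reshaping, assuming plain NumPy rather than EpyNN's own *flatten* layer, the four-dimensional batch below (random values standing in for scaled images, with hypothetical names ``batch`` and ``flat``) is given a two-dimensional shape:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Hypothetical batch of 32 grayscale images: (samples, width, height, depth)\n", "batch = np.random.default_rng(1).random((32, 28, 28, 1))\n", "\n", "# Keep the sample axis and merge the three image axes into a single one\n", "flat = batch.reshape(batch.shape[0], -1)\n", "\n", "print(batch.shape)  # (32, 28, 28, 1)\n", "print(flat.shape)  # (32, 784)"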
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "tags": [] }, "outputs": [], "source": [ "name = 'Flatten_Dropout02_Dense-64-relu_Dropout05_Dense-10-softmax'\n", "\n", "se_hPars['learning_rate'] = 0.001\n", "se_hPars['softmax_temperature'] = 5\n", "\n", "flatten = Flatten()\n", "\n", "dropout1 = Dropout(drop_prob=0.2)\n", "\n", "hidden_dense = Dense(64, relu)\n", "\n", "dropout2 = Dropout(drop_prob=0.5)\n", "\n", "dense = Dense(10, softmax)\n", "\n", "layers = [embedding, flatten, dropout1, hidden_dense, dropout2, dense]\n", "\n", "model = EpyNN(layers=layers, name=name)" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "We have set a custom learning rate and 10 units in the output layer, because labels are one-hot encoded over ten distinct digits.\n", "\n", "Dropout layers were set to reduce overfitting, and the *softmax temperature* was set higher than one to smooth the output probability distribution, which in practice results in more numerically stable training." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m--- EpyNN Check OK! --- \u001b[0m\r" ] } ], "source": [ "model.initialize(loss='CCE', seed=1, se_hPars=se_hPars.copy(), end='\\r')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have set the loss function to use *Categorical Cross Entropy* instead of *Binary Cross Entropy*. This is because we have ten distinct labels instead of two."
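] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the effect of the *softmax temperature* and of the *Categorical Cross Entropy* loss more concrete, here is a minimal NumPy sketch. The helpers ``softmax_sketch`` and ``cce_sketch`` are hypothetical names written for illustration with the standard textbook definitions; they are not EpyNN's implementations, and the logits are made-up values." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "\n", "def softmax_sketch(x, temperature=1):\n", "    # Shift by the maximum for numerical stability, then scale by temperature\n", "    z = (x - np.max(x)) / temperature\n", "    e = np.exp(z)\n", "    return e / np.sum(e)\n", "\n", "\n", "def cce_sketch(y_true, y_pred):\n", "    # Categorical cross entropy for a single one-hot encoded sample\n", "    return -np.sum(y_true * np.log(y_pred))\n", "\n", "\n", "logits = np.array([2.0, 1.0, 0.1])  # Made-up output values\n", "y_true = np.array([1.0, 0.0, 0.0])  # One-hot encoded label\n", "\n", "# A higher temperature yields a smoother probability distribution\n", "for temperature in (1, 5):\n", "    probs = softmax_sketch(logits, temperature=temperature)\n", "    print(temperature, probs.round(3), round(float(cce_sketch(y_true, probs)), 3))"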
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[37mEpoch 99 - Batch 14/14 - Accuracy: 0.906 Cost: 0.36511 - TIME: 15.75s RATE: 6.35e+00e/s TTC: 0s \u001b[0m\n", "\n", "+-------+----------+----------+----------+-------+--------+-------+-----------------------------------------------------------------------+\n", "| \u001b[1m\u001b[37mepoch\u001b[0m | \u001b[1m\u001b[37mlrate\u001b[0m | \u001b[1m\u001b[37mlrate\u001b[0m | \u001b[1m\u001b[32maccuracy\u001b[0m | | \u001b[1m\u001b[31mCCE\u001b[0m | | \u001b[37mExperiment\u001b[0m |\n", "| | \u001b[37mDense\u001b[0m | \u001b[37mDense\u001b[0m | \u001b[1m\u001b[32mdtrain\u001b[0m | \u001b[1m\u001b[32mdval\u001b[0m | \u001b[1m\u001b[31mdtrain\u001b[0m | \u001b[1m\u001b[31mdval\u001b[0m | |\n", "+-------+----------+----------+----------+-------+--------+-------+-----------------------------------------------------------------------+\n", "| \u001b[1m\u001b[37m0\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.240\u001b[0m | \u001b[1m\u001b[32m0.220\u001b[0m | \u001b[1m\u001b[31m2.235\u001b[0m | \u001b[1m\u001b[31m2.239\u001b[0m | \u001b[37m1635012330_Flatten_Dropout02_Dense-64-relu_Dropout05_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m10\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.700\u001b[0m | \u001b[1m\u001b[32m0.664\u001b[0m | \u001b[1m\u001b[31m1.155\u001b[0m | \u001b[1m\u001b[31m1.258\u001b[0m | \u001b[37m1635012330_Flatten_Dropout02_Dense-64-relu_Dropout05_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m20\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.798\u001b[0m | \u001b[1m\u001b[32m0.740\u001b[0m | \u001b[1m\u001b[31m0.741\u001b[0m | \u001b[1m\u001b[31m0.910\u001b[0m | 
\u001b[37m1635012330_Flatten_Dropout02_Dense-64-relu_Dropout05_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m30\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.830\u001b[0m | \u001b[1m\u001b[32m0.764\u001b[0m | \u001b[1m\u001b[31m0.618\u001b[0m | \u001b[1m\u001b[31m0.819\u001b[0m | \u001b[37m1635012330_Flatten_Dropout02_Dense-64-relu_Dropout05_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m40\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.864\u001b[0m | \u001b[1m\u001b[32m0.808\u001b[0m | \u001b[1m\u001b[31m0.474\u001b[0m | \u001b[1m\u001b[31m0.714\u001b[0m | \u001b[37m1635012330_Flatten_Dropout02_Dense-64-relu_Dropout05_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m50\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.872\u001b[0m | \u001b[1m\u001b[32m0.788\u001b[0m | \u001b[1m\u001b[31m0.435\u001b[0m | \u001b[1m\u001b[31m0.722\u001b[0m | \u001b[37m1635012330_Flatten_Dropout02_Dense-64-relu_Dropout05_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m60\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.872\u001b[0m | \u001b[1m\u001b[32m0.820\u001b[0m | \u001b[1m\u001b[31m0.425\u001b[0m | \u001b[1m\u001b[31m0.627\u001b[0m | \u001b[37m1635012330_Flatten_Dropout02_Dense-64-relu_Dropout05_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m70\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.900\u001b[0m | \u001b[1m\u001b[32m0.824\u001b[0m | \u001b[1m\u001b[31m0.351\u001b[0m | \u001b[1m\u001b[31m0.607\u001b[0m | \u001b[37m1635012330_Flatten_Dropout02_Dense-64-relu_Dropout05_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m80\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | 
\u001b[1m\u001b[32m0.926\u001b[0m | \u001b[1m\u001b[32m0.824\u001b[0m | \u001b[1m\u001b[31m0.292\u001b[0m | \u001b[1m\u001b[31m0.729\u001b[0m | \u001b[37m1635012330_Flatten_Dropout02_Dense-64-relu_Dropout05_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m90\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.930\u001b[0m | \u001b[1m\u001b[32m0.848\u001b[0m | \u001b[1m\u001b[31m0.269\u001b[0m | \u001b[1m\u001b[31m0.616\u001b[0m | \u001b[37m1635012330_Flatten_Dropout02_Dense-64-relu_Dropout05_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m99\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[37m1.00e-03\u001b[0m | \u001b[1m\u001b[32m0.942\u001b[0m | \u001b[1m\u001b[32m0.828\u001b[0m | \u001b[1m\u001b[31m0.232\u001b[0m | \u001b[1m\u001b[31m0.625\u001b[0m | \u001b[37m1635012330_Flatten_Dropout02_Dense-64-relu_Dropout05_Dense-10-softmax\u001b[0m |\n", "+-------+----------+----------+----------+-------+--------+-------+-----------------------------------------------------------------------+\n" ] } ], "source": [ "model.train(epochs=100, init_logs=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Again, the Feed-Forward network reproduces the training data quite well, but it is less accurate on the validation data. Overfitting is a general problem in the field." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "tags": [] }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "model.plot(path=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will not optimize this Feed-Forward model and instead we will go ahead with Convolutional networks." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Convolutional Neural Network (CNN)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As detailed in [Basics with images](epynnlive/dummy_image/train.html#Convolutional-Neural-Network-(CNN)), CNNs are preferred when it comes to considering *a priori* spatial relationships between data points." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Embedding" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use the same embedding configuration as above." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "embedding = Embedding(X_data=X_features,\n", " Y_data=Y_label,\n", " X_scale=True,\n", " Y_encode=True,\n", " batch_size=32,\n", " relative_size=(2, 1, 0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that we use only *500 + 250* sample images herein. Benchmarks using the MNIST Database generally employ all 60 000 images of the training set, but doing so here is likely to overload the RAM or fail to run at all on many system configurations." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Conv-MaxPool-Flatten-Dense" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use a single *Convolution-Pooling* block with ``16`` filters and a *2x2* filter window for the *Convolution* layer."
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "tags": [] }, "outputs": [], "source": [ "name = 'Convolution-16-2_Pooling-3-Max_Flatten_Dense-10-softmax'\n", "\n", "se_hPars['learning_rate'] = 0.005\n", "se_hPars['softmax_temperature'] = 5\n", "\n", "layers = [\n", " embedding,\n", " Convolution(unit_filters=16, filter_size=(2, 2), activate=relu),\n", " Pooling(pool_size=(2, 2)),\n", " Flatten(),\n", " Dense(10, softmax)\n", "]\n", "\n", "model = EpyNN(layers=layers, name=name)" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "Initialize the model with *Categorical Cross Entropy* loss." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m--- EpyNN Check OK! --- \u001b[0m\r" ] } ], "source": [ "model.initialize(loss='CCE', seed=1, se_hPars=se_hPars.copy(), end='\\r')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Train for 10 epochs."
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[37mEpoch 9 - Batch 14/14 - Accuracy: 0.75 Cost: 0.52838 - TIME: 7.37s RATE: 1.36e+00e/s TTC: 1s \u001b[0m\n", "\n", "+-------+-------------+----------+----------+-------+--------+-------+--------------------------------------------------------------------+\n", "| \u001b[1m\u001b[37mepoch\u001b[0m | \u001b[1m\u001b[37mlrate\u001b[0m | \u001b[1m\u001b[37mlrate\u001b[0m | \u001b[1m\u001b[32maccuracy\u001b[0m | | \u001b[1m\u001b[31mCCE\u001b[0m | | \u001b[37mExperiment\u001b[0m |\n", "| | \u001b[37mConvolution\u001b[0m | \u001b[37mDense\u001b[0m | \u001b[1m\u001b[32mdtrain\u001b[0m | \u001b[1m\u001b[32mdval\u001b[0m | \u001b[1m\u001b[31mdtrain\u001b[0m | \u001b[1m\u001b[31mdval\u001b[0m | |\n", "+-------+-------------+----------+----------+-------+--------+-------+--------------------------------------------------------------------+\n", "| \u001b[1m\u001b[37m0\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.594\u001b[0m | \u001b[1m\u001b[32m0.564\u001b[0m | \u001b[1m\u001b[31m1.880\u001b[0m | \u001b[1m\u001b[31m1.916\u001b[0m | \u001b[37m1635012346_Convolution-16-2_Pooling-3-Max_Flatten_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m1\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.728\u001b[0m | \u001b[1m\u001b[32m0.716\u001b[0m | \u001b[1m\u001b[31m1.083\u001b[0m | \u001b[1m\u001b[31m1.149\u001b[0m | \u001b[37m1635012346_Convolution-16-2_Pooling-3-Max_Flatten_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m2\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.770\u001b[0m | \u001b[1m\u001b[32m0.728\u001b[0m | \u001b[1m\u001b[31m0.767\u001b[0m | \u001b[1m\u001b[31m0.824\u001b[0m | 
\u001b[37m1635012346_Convolution-16-2_Pooling-3-Max_Flatten_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m3\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.782\u001b[0m | \u001b[1m\u001b[32m0.792\u001b[0m | \u001b[1m\u001b[31m0.661\u001b[0m | \u001b[1m\u001b[31m0.748\u001b[0m | \u001b[37m1635012346_Convolution-16-2_Pooling-3-Max_Flatten_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m4\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.814\u001b[0m | \u001b[1m\u001b[32m0.788\u001b[0m | \u001b[1m\u001b[31m0.573\u001b[0m | \u001b[1m\u001b[31m0.668\u001b[0m | \u001b[37m1635012346_Convolution-16-2_Pooling-3-Max_Flatten_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m5\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.820\u001b[0m | \u001b[1m\u001b[32m0.784\u001b[0m | \u001b[1m\u001b[31m0.585\u001b[0m | \u001b[1m\u001b[31m0.707\u001b[0m | \u001b[37m1635012346_Convolution-16-2_Pooling-3-Max_Flatten_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m6\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.866\u001b[0m | \u001b[1m\u001b[32m0.816\u001b[0m | \u001b[1m\u001b[31m0.434\u001b[0m | \u001b[1m\u001b[31m0.613\u001b[0m | \u001b[37m1635012346_Convolution-16-2_Pooling-3-Max_Flatten_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m7\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.774\u001b[0m | \u001b[1m\u001b[32m0.744\u001b[0m | \u001b[1m\u001b[31m0.610\u001b[0m | \u001b[1m\u001b[31m0.773\u001b[0m | \u001b[37m1635012346_Convolution-16-2_Pooling-3-Max_Flatten_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m8\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.842\u001b[0m | 
\u001b[1m\u001b[32m0.816\u001b[0m | \u001b[1m\u001b[31m0.426\u001b[0m | \u001b[1m\u001b[31m0.639\u001b[0m | \u001b[37m1635012346_Convolution-16-2_Pooling-3-Max_Flatten_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m9\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.878\u001b[0m | \u001b[1m\u001b[32m0.856\u001b[0m | \u001b[1m\u001b[31m0.389\u001b[0m | \u001b[1m\u001b[31m0.580\u001b[0m | \u001b[37m1635012346_Convolution-16-2_Pooling-3-Max_Flatten_Dense-10-softmax\u001b[0m |\n", "+-------+-------------+----------+----------+-------+--------+-------+--------------------------------------------------------------------+\n" ] } ], "source": [ "model.train(epochs=10, init_logs=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The overfitting using this CNN architecture is reduced compared with the pure Feed-Forward design. However, accuracy remains slightly lower for both training and validation data.\n", "\n", "Note the ``learning_rate`` may be a bit too high here: there are jumps up and down for both accuracy and loss, particularly visible on the validation set." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "tags": [] }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "model.plot(path=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Looking at the plot, the model seems to have converged.\n", "\n", "We will try to optimize it further." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Conv-MaxPool-Flatten-(Dense)n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below we added one ``Dense(64, relu)`` hidden layer. This is the same design as used for the Feed-Forward network excepted, of course, for the *Convolution-Pooling* block. The thing to pay attention to here is the absence of *Dropout* layer compared to the Feed-Forward network." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[37mEpoch 9 - Batch 14/14 - Accuracy: 0.906 Cost: 0.5155 - TIME: 9.21s RATE: 1.09e+00e/s TTC: 2s \u001b[0m\n", "\n", "+-------+-------------+----------+----------+----------+-------+--------+-------+----------------------------------------------------------------------------------+\n", "| \u001b[1m\u001b[37mepoch\u001b[0m | \u001b[1m\u001b[37mlrate\u001b[0m | \u001b[1m\u001b[37mlrate\u001b[0m | \u001b[1m\u001b[37mlrate\u001b[0m | \u001b[1m\u001b[32maccuracy\u001b[0m | | \u001b[1m\u001b[31mCCE\u001b[0m | | \u001b[37mExperiment\u001b[0m |\n", "| | \u001b[37mConvolution\u001b[0m | \u001b[37mDense\u001b[0m | \u001b[37mDense\u001b[0m | \u001b[1m\u001b[32mdtrain\u001b[0m | \u001b[1m\u001b[32mdval\u001b[0m | \u001b[1m\u001b[31mdtrain\u001b[0m | \u001b[1m\u001b[31mdval\u001b[0m | |\n", "+-------+-------------+----------+----------+----------+-------+--------+-------+----------------------------------------------------------------------------------+\n", "| \u001b[1m\u001b[37m0\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | 
\u001b[1m\u001b[32m0.570\u001b[0m | \u001b[1m\u001b[32m0.520\u001b[0m | \u001b[1m\u001b[31m1.822\u001b[0m | \u001b[1m\u001b[31m1.863\u001b[0m | \u001b[37m1635012353_Convolution-16-2_Pooling-3-Max_Flatten_Dense-64-relu_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m1\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.628\u001b[0m | \u001b[1m\u001b[32m0.564\u001b[0m | \u001b[1m\u001b[31m1.172\u001b[0m | \u001b[1m\u001b[31m1.293\u001b[0m | \u001b[37m1635012353_Convolution-16-2_Pooling-3-Max_Flatten_Dense-64-relu_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m2\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.630\u001b[0m | \u001b[1m\u001b[32m0.632\u001b[0m | \u001b[1m\u001b[31m1.011\u001b[0m | \u001b[1m\u001b[31m1.081\u001b[0m | \u001b[37m1635012353_Convolution-16-2_Pooling-3-Max_Flatten_Dense-64-relu_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m3\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.776\u001b[0m | \u001b[1m\u001b[32m0.740\u001b[0m | \u001b[1m\u001b[31m0.680\u001b[0m | \u001b[1m\u001b[31m0.798\u001b[0m | \u001b[37m1635012353_Convolution-16-2_Pooling-3-Max_Flatten_Dense-64-relu_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m4\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.852\u001b[0m | \u001b[1m\u001b[32m0.812\u001b[0m | \u001b[1m\u001b[31m0.514\u001b[0m | \u001b[1m\u001b[31m0.623\u001b[0m | \u001b[37m1635012353_Convolution-16-2_Pooling-3-Max_Flatten_Dense-64-relu_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m5\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | 
\u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.770\u001b[0m | \u001b[1m\u001b[32m0.736\u001b[0m | \u001b[1m\u001b[31m0.621\u001b[0m | \u001b[1m\u001b[31m0.773\u001b[0m | \u001b[37m1635012353_Convolution-16-2_Pooling-3-Max_Flatten_Dense-64-relu_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m6\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.912\u001b[0m | \u001b[1m\u001b[32m0.820\u001b[0m | \u001b[1m\u001b[31m0.348\u001b[0m | \u001b[1m\u001b[31m0.587\u001b[0m | \u001b[37m1635012353_Convolution-16-2_Pooling-3-Max_Flatten_Dense-64-relu_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m7\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.830\u001b[0m | \u001b[1m\u001b[32m0.780\u001b[0m | \u001b[1m\u001b[31m0.520\u001b[0m | \u001b[1m\u001b[31m0.676\u001b[0m | \u001b[37m1635012353_Convolution-16-2_Pooling-3-Max_Flatten_Dense-64-relu_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m8\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.900\u001b[0m | \u001b[1m\u001b[32m0.856\u001b[0m | \u001b[1m\u001b[31m0.315\u001b[0m | \u001b[1m\u001b[31m0.536\u001b[0m | \u001b[37m1635012353_Convolution-16-2_Pooling-3-Max_Flatten_Dense-64-relu_Dense-10-softmax\u001b[0m |\n", "| \u001b[1m\u001b[37m9\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[37m5.00e-03\u001b[0m | \u001b[1m\u001b[32m0.902\u001b[0m | \u001b[1m\u001b[32m0.860\u001b[0m | \u001b[1m\u001b[31m0.290\u001b[0m | \u001b[1m\u001b[31m0.506\u001b[0m | \u001b[37m1635012353_Convolution-16-2_Pooling-3-Max_Flatten_Dense-64-relu_Dense-10-softmax\u001b[0m |\n", 
"+-------+-------------+----------+----------+----------+-------+--------+-------+----------------------------------------------------------------------------------+\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "name = 'Convolution-16-2_Pooling-3-Max_Flatten_Dense-64-relu_Dense-10-softmax'\n", "\n", "se_hPars['learning_rate'] = 0.005\n", "se_hPars['softmax_temperature'] = 5\n", "\n", "layers = [\n", " embedding,\n", " Convolution(unit_filters=16, filter_size=(2, 2), activate=relu),\n", " Pooling(pool_size=(2, 2)),\n", " Flatten(),\n", " Dense(64, relu),\n", " Dense(10, softmax)\n", "]\n", "\n", "model = EpyNN(layers=layers, name=name)\n", "\n", "model.initialize(loss='CCE', seed=1, se_hPars=se_hPars.copy(), end='\\r')\n", "\n", "model.train(epochs=10, init_logs=False)\n", "\n", "model.plot(path=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This result is better than the one obtained with the Feed-Forward network. Although we did not use *Dropout* regularization, there is less overfitting and higher accuracy on the validation set." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For code, maths and pictures behind the convolution and pooling layers, follow these links:\n", "\n", "* [Convolution (CNN)](https://epynn.net/Convolution.html)\n", "* [Pooling (CNN)](https://epynn.net/Pooling.html)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Write, read & Predict" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A trained model can be written on disk such as:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m\u001b[32mMake: /media/synthase/beta/EpyNN/epynnlive/captcha_mnist/models/1635012353_Convolution-16-2_Pooling-3-Max_Flatten_Dense-64-relu_Dense-10-softmax.pickle\u001b[0m\n" ] } ], "source": [ "model.write()\n", "\n", "# model.write(path=/your/custom/path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A model can be read from disk such as:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": 
{}, "outputs": [], "source": [ "model = read_model()\n", "\n", "# model = read_model(path=/your/custom/path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can retrieve new features and predict on them." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "X_features, _ = prepare_dataset(N_SAMPLES=10)\n", "\n", "dset = model.predict(X_features, X_scale=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Results can be extracted such as:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 4 [1.31587431e-03 4.08317943e-05 1.30351360e-03 1.47063124e-03\n", " 9.78023255e-01 1.93182534e-03 1.40558825e-03 3.30961357e-04\n", " 4.94257777e-04 1.36832613e-02]\n", "1 1 [7.28977838e-05 9.29707175e-01 3.32038257e-02 1.98144195e-03\n", " 2.40938005e-04 1.34165640e-02 1.50138537e-03 1.67654867e-03\n", " 1.77799721e-02 4.19250987e-04]\n", "2 3 [4.78064492e-03 3.29830971e-06 1.58072929e-03 9.54198963e-01\n", " 6.41479062e-05 1.35982207e-02 5.87131031e-06 3.48727032e-05\n", " 2.55519849e-02 1.81267435e-04]\n", "3 3 [4.14994350e-03 2.29267766e-03 1.15714086e-02 5.79869691e-01\n", " 3.52961208e-04 6.38732938e-02 1.03465443e-04 9.56096045e-03\n", " 3.23308022e-01 4.91757634e-03]\n", "4 1 [4.89436878e-06 9.80078656e-01 1.78383500e-03 7.31766268e-06\n", " 7.22962122e-05 1.67800975e-04 1.39517610e-04 3.73043330e-04\n", " 1.73534589e-02 1.91798796e-05]\n", "5 4 [1.27447904e-03 3.62544371e-03 1.60817813e-02 5.22806853e-04\n", " 5.99689165e-01 3.52024299e-01 1.56190299e-02 1.34865787e-04\n", " 8.51033165e-03 2.51779700e-03]\n", "6 2 [4.34791238e-02 5.76774386e-03 8.07714111e-01 1.08754447e-02\n", " 7.84182950e-04 2.48067747e-02 1.02054993e-02 2.28036464e-02\n", " 6.73872013e-02 6.17627146e-03]\n", "7 0 [9.99507456e-01 5.13513994e-09 3.37688225e-05 3.65112950e-05\n", " 4.74996217e-08 2.58860168e-04 1.39608623e-04 7.52209081e-06\n", 
" 1.61164137e-05 1.04078231e-07]\n", "8 1 [3.17634507e-05 9.73326939e-01 7.03134936e-03 4.32744881e-04\n", " 1.55381335e-04 2.48030926e-03 6.37703894e-04 2.73099384e-03\n", " 1.27339923e-02 4.38822977e-04]\n", "9 9 [2.59605004e-05 2.29246440e-05 1.49105703e-05 2.63219182e-04\n", " 4.59102905e-02 8.82117671e-04 5.73083834e-05 7.31917762e-02\n", " 1.00839461e-03 8.78623098e-01]\n" ] } ], "source": [ "for n, pred, probs in zip(dset.ids, dset.P, dset.A):\n", " print(n, pred, probs)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" } }, "nbformat": 4, "nbformat_minor": 4 }