mirror of
https://github.com/kuhyx/ARAI.git
synced 2026-07-04 12:03:16 +02:00
3201 lines
1.7 MiB
Plaintext
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"colab": {
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
},
"accelerator": "GPU",
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuhyx/ARAI/blob/main/TWM_KerasIntro.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "c9tFwxET8YY5"
},
"source": [
"%matplotlib inline"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "M9PoJm628YY_"
},
"source": [
"# Introduction to Deep Learning with Keras and TensorFlow\n",
"\n",
"Based on excellent work by **[Daniel Moser (UT Southwestern Medical Center)](https://github.com/AviatorMoser/keras-mnist-tutorial)**. Resources: **[Xavier Snelgrove](https://github.com/wxs/keras-mnist-tutorial), [Yash Katariya](https://github.com/yashk2810/MNIST-Keras)**"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lh0dsQg08YZA"
},
"source": [
"To help you understand the fundamentals of deep learning, this demo walks through the basic steps of building two toy models that classify handwritten digits with accuracies above 95%. The first model is a basic fully connected neural network; the second is a deeper network that introduces the concepts of convolution and pooling."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LJVXVJ6_8YZB"
},
"source": [
"## The Task for the AI\n",
"\n",
"Our goal is to construct and train an artificial neural network on thousands of images of handwritten digits so that it can correctly identify new digits it has not seen before. We will use the MNIST database, which contains 60,000 training images and 10,000 test images, and the Keras Python API with TensorFlow as the backend."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3eVu55U98YZC"
},
"source": [
"<img src=\"https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/mnist.png\" >"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ms_a4hzY8YZC"
},
"source": [
"## Prerequisite Python Modules\n",
"\n",
"First, some software needs to be loaded into the Python environment."
]
},
{
"cell_type": "code",
"metadata": {
"id": "D7Bns2Lo8YZD",
"outputId": "4037ddf7-5ac4-4f77-f482-f0a02108c74c",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"import numpy as np # advanced math library\n",
"import matplotlib.pyplot as plt # MATLAB-like plotting routines\n",
"import random # for generating random numbers\n",
"\n",
"import tensorflow as tf\n",
"\n",
"from keras.datasets import mnist # MNIST dataset is included in Keras\n",
"from keras.models import Sequential # Model type to be used\n",
"\n",
"from keras.layers import Dense, Dropout, Activation # Types of layers to be used in our model\n",
"from keras.utils import to_categorical # converts integer labels to one-hot vectors\n",
"\n",
"from keras import optimizers\n",
"\n",
"\n",
"from sklearn.metrics import confusion_matrix\n",
"import itertools\n",
"\n",
"print(tf.__version__)"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2.15.0\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QygDQ7Ch8YZG"
},
"source": [
"## Loading Training Data\n",
"\n",
"The MNIST dataset is conveniently bundled within Keras, and we can easily analyze some of its features in Python."
]
},
{
"cell_type": "code",
"metadata": {
"id": "SdZZph6i8YZG",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "ee9e11b8-6b44-4fa0-f974-e83d7999e930"
},
"source": [
"# The MNIST data is split between 60,000 28 x 28 pixel training images and 10,000 28 x 28 pixel test images\n",
"(X_train, y_train), (X_test, y_test) = mnist.load_data()\n",
"\n",
"print(\"X_train shape\", X_train.shape)\n",
"print(\"y_train shape\", y_train.shape)\n",
"print(\"X_test shape\", X_test.shape)\n",
"print(\"y_test shape\", y_test.shape)"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
"11490434/11490434 [==============================] - 2s 0us/step\n",
"X_train shape (60000, 28, 28)\n",
"y_train shape (60000,)\n",
"X_test shape (10000, 28, 28)\n",
"y_test shape (10000,)\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9v5sjP468YZJ"
},
"source": [
"Using matplotlib, we can plot some sample images from the training set directly into this Jupyter Notebook. This lets us examine the intra-class variability: how many different ways there are of writing the same digit!"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Ov1T_A6X8YZJ",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 811
},
"outputId": "51c0509e-a421-4161-99c3-a44b346222bd"
},
"source": [
"plt.rcParams['figure.figsize'] = (9,9) # Make the figures a bit bigger\n",
"\n",
"def visualize_classes(X, y):\n",
"    # Stack the first 10 examples of each digit vertically, then place the 10 digit columns side by side\n",
"    for i in range(0, 10):\n",
"        img_batch = X[y == i][0:10]\n",
"        img_batch = np.reshape(img_batch, (img_batch.shape[0]*img_batch.shape[1], img_batch.shape[2]))\n",
"        if i > 0:\n",
"            img = np.concatenate([img, img_batch], axis = 1)\n",
"        else:\n",
"            img = img_batch\n",
"    plt.figure(figsize=(10,20))\n",
"    plt.axis('off')\n",
"    plt.imshow(img, cmap='gray')\n",
"\n",
"\n",
"visualize_classes(X_train, y_train)\n"
],
"execution_count": 4,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x2000 with 1 Axes>"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uGjX5ZD68YZM"
},
"source": [
"Let's examine a single digit a little more closely by printing out the array that represents the last training image."
]
},
{
"cell_type": "code",
"metadata": {
"id": "wW_rrJfj8YZN",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "5b74571c-eab3-455f-fa2b-aa7d6f1f28f3"
},
"source": [
"# just a little function for pretty printing a matrix\n",
"def matprint(mat, fmt=\"g\"):\n",
"    col_maxes = [max([len((\"{:\"+fmt+\"}\").format(x)) for x in col]) for col in mat.T]\n",
"    for x in mat:\n",
"        for i, y in enumerate(x):\n",
"            print((\"{:\"+str(col_maxes[i])+fmt+\"}\").format(y), end=\" \")\n",
"        print(\"\")\n",
"\n",
"# now print!\n",
"matprint(X_train[-1])"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 38 48 48 22 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 62 97 198 243 254 254 212 27 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 67 172 254 254 225 218 218 237 248 40 0 21 164 187 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 89 219 254 97 67 14 0 0 92 231 122 23 203 236 59 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 25 217 242 92 4 0 0 0 0 4 147 253 240 232 92 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 101 255 92 0 0 0 0 0 0 105 254 254 177 11 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 167 244 41 0 0 0 7 76 199 238 239 94 10 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 192 121 0 0 2 63 180 254 233 126 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 190 196 14 2 97 254 252 146 52 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 130 225 71 180 232 181 60 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 130 254 254 230 46 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 6 77 244 254 162 4 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 110 254 218 254 116 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 131 254 154 28 213 86 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 66 209 153 19 19 233 60 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 142 254 165 0 14 216 167 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 90 254 175 0 18 229 92 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 26 229 249 176 222 244 44 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 73 193 197 134 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "s3koPAgMdzoG"
},
"source": [],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "vQPPE2Bq8YZP"
},
"source": [
"Each pixel is an 8-bit unsigned integer from 0 to 255, where 0 is full black and 255 is full white. Because each pixel holds a single value (one channel), these images are called grayscale, or monochrome.\n",
"\n",
"*Fun fact! Your computer screen has three channels for each pixel: red, green and blue. Each of these channels typically also holds an 8-bit integer, so 3 channels x 8 bits = 24 bits in total, or 16,777,216 possible colors!*"
]
},
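{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check of the arithmetic above, here is a minimal sketch (our own illustration, not part of the original tutorial): 8 bits give 2^8 = 256 levels per channel, and three such channels give 256^3 = 2^24 = 16,777,216 combinations. The `uint8` limits are read from NumPy's `np.iinfo`."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"bits_per_channel = 8\n",
"levels = 2 ** bits_per_channel  # 256 intensity levels per channel (0..255)\n",
"colors = levels ** 3  # red, green and blue channels combined\n",
"print(levels, colors)  # 256 16777216\n",
"\n",
"# MNIST pixels are stored as unsigned 8-bit integers, so they span exactly this range\n",
"print(np.iinfo(np.uint8).min, np.iinfo(np.uint8).max)  # 0 255"
],
"execution_count": null,
"outputs": []
},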
{
"cell_type": "markdown",
"metadata": {
"id": "z09IKy7PR9M9"
},
"source": [
"## More input data analysis\n",
"\n",
"\n",
"Visualizing high-dimensional data by projecting it into a low-dimensional space is a classic operation that almost anyone working with data has done at least once. There is a huge variety of dimensionality-reduction methods, but one very popular method is t-SNE (t-distributed Stochastic Neighbor Embedding), proposed by Laurens van der Maaten and Geoffrey Hinton in 2008. [ _more..._ ](https://mlexplained.com/2018/09/14/paper-dissected-visualizing-data-using-t-sne-explained/)\n",
"\n",
"Code by [Zaid Alyafeai](https://github.com/zaidalyafeai/Notebooks) [MIT license]"
]
},
{
"cell_type": "code",
"metadata": {
"id": "59m4-pjCSBgM"
},
"source": [
"from sklearn.manifold import TSNE\n",
"import matplotlib.patheffects as PathEffects\n",
"import seaborn as sns\n",
"\n",
"RS=19238\n",
"\n",
"class_names = [ str(clid) for clid in range(10) ]\n",
"\n",
"def scatter(x, colors):\n",
"    # We choose a color palette with seaborn.\n",
"    palette = np.array(sns.color_palette(\"hls\", 10))\n",
"\n",
"    # We create a scatter plot.\n",
"    f = plt.figure(figsize=(8, 8))\n",
"    ax = plt.subplot(aspect='equal')\n",
"    sc = ax.scatter(x[:,0], x[:,1], lw=0, s=40,\n",
"                    c=palette[colors.astype(\"int\")])\n",
"    plt.xlim(-25, 25)\n",
"    plt.ylim(-25, 25)\n",
"    ax.axis('off')\n",
"    ax.axis('tight')\n",
"\n",
"    # We add the labels for each digit.\n",
"    txts = []\n",
"    for i in range(10):\n",
"        # Position of each label.\n",
"        xtext, ytext = np.median(x[colors == i, :], axis=0)\n",
"        txt = ax.text(xtext, ytext, class_names[i], fontsize=15)\n",
"        txt.set_path_effects([\n",
"            PathEffects.Stroke(linewidth=5, foreground=\"w\"),\n",
"            PathEffects.Normal()])\n",
"        txts.append(txt)\n",
"\n",
"def plot_tsne(X, y):\n",
"    print('calculating tsne ...')\n",
"    proj = TSNE(random_state=RS, learning_rate=200, init='random').fit_transform(X)\n",
"    scatter(proj, y)"
],
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "1_gbsFQ3Srw_",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 670
},
"outputId": "f20b5dea-b04d-4f61-857b-5785251c616e"
},
"source": [
"X = np.reshape(X_train, (X_train.shape[0], 28 * 28))[0:2000]\n",
"y = y_train[0:2000]\n",
"plot_tsne(X, y)"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"calculating tsne ...\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 800x800 with 1 Axes>"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UFWr5nQo8YZQ"
},
"source": [
"## Formatting the input data layer\n",
"\n",
"Instead of a 28 x 28 matrix, we build our network to accept a 784-length vector.\n",
"\n",
"Each image therefore needs to be reshaped (or flattened) into a vector. We'll also normalize the inputs to the range [0, 1] rather than [0, 255]. Normalizing inputs is generally recommended so that all input dimensions (including any additional ones used by other network architectures) are on the same scale."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B-b0XGKd8YZQ"
},
"source": [
"<img src='https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/flatten.png' >"
]
},
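{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the flattening concrete, here is a minimal sketch on a toy array (an illustrative stand-in, not MNIST data): reshaping a 28 x 28 array in NumPy's default row-major order places pixel `(row, col)` at index `row * 28 + col` of the 784-length vector, and dividing by 255 maps the 8-bit range onto [0, 1]."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"# Toy 28 x 28 image whose pixel values equal their own flat position (illustration only)\n",
"img = np.arange(28 * 28).reshape(28, 28)\n",
"vec = img.reshape(784)  # row-major flattening, the same operation applied to X_train\n",
"\n",
"row, col = 5, 17\n",
"print(vec[row * 28 + col] == img[row, col])  # True\n",
"\n",
"# Scaling 8-bit values into [0, 1]\n",
"pixels = np.array([0, 51, 255], dtype=np.uint8)\n",
"print(pixels.astype('float32') / 255)  # values 0.0, 0.2 and 1.0"
],
"execution_count": null,
"outputs": []
},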
{
"cell_type": "code",
"metadata": {
"id": "0ge5s-pT8YZR",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "2e430092-404d-4abf-f2c0-5cbf97adb29a"
},
"source": [
"X_train = X_train.reshape(60000, 784) # reshape 60,000 28 x 28 matrices into 60,000 784-length vectors.\n",
"X_test = X_test.reshape(10000, 784) # reshape 10,000 28 x 28 matrices into 10,000 784-length vectors.\n",
"\n",
"X_train = X_train.astype('float32') # change integers to 32-bit floating point numbers\n",
"X_test = X_test.astype('float32')\n",
"\n",
"X_train /= 255 # scale every pixel value from [0, 255] to [0, 1]\n",
"X_test /= 255\n",
"\n",
"print(\"Training matrix shape\", X_train.shape)\n",
"print(\"Testing matrix shape\", X_test.shape)"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Training matrix shape (60000, 784)\n",
"Testing matrix shape (10000, 784)\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sS8jHU6V8YZT"
},
"source": [
"We then convert our classes (the unique digits) to the one-hot (categorical) format, i.e.\n",
"\n",
"```\n",
"0 -> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
"1 -> [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]\n",
"2 -> [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]\n",
"etc.\n",
"```\n",
"\n",
"If the final output of our network is very close to one of these vectors, then the image most likely belongs to that class. For example, if the final output is:\n",
"\n",
"```\n",
"[0, 0.94, 0, 0, 0, 0, 0.06, 0, 0, 0]\n",
"```\n",
"then it is most probable that the image is that of the digit `1`."
]
},
{
"cell_type": "code",
"metadata": {
"id": "AFvPkCFw8YZU"
},
"source": [
"nb_classes = 10 # number of unique digits\n",
"\n",
"Y_train = to_categorical(y_train, nb_classes)\n",
"Y_test = to_categorical(y_test, nb_classes)"
],
"execution_count": 9,
"outputs": []
},
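{
"cell_type": "markdown",
"metadata": {},
"source": [
"The one-hot encoding and its reverse can be sketched with plain NumPy (a minimal illustration, assuming integer labels 0-9 as in `y_train`): indexing the 10 x 10 identity matrix by label gives the same layout as `to_categorical(labels, 10)`, and `np.argmax` turns a probability vector, such as the example output above, back into a digit."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"labels = np.array([0, 1, 2, 9])\n",
"one_hot = np.eye(10, dtype='float32')[labels]  # one row per label, with a single 1 in column `label`\n",
"print(one_hot.shape)  # (4, 10)\n",
"print(np.argmax(one_hot, axis=1))  # [0 1 2 9]\n",
"\n",
"# Decoding a network output: pick the class with the highest probability\n",
"output = np.array([0, 0.94, 0, 0, 0, 0, 0.06, 0, 0, 0])\n",
"print(np.argmax(output))  # 1"
],
"execution_count": null,
"outputs": []
},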
{
"cell_type": "markdown",
"metadata": {
"id": "q_B8ZH5G8YZX"
},
"source": [
"# Building a 3-layer fully connected network\n",
"\n",
"<img src=\"https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/figure.png\" />"
]
},
{
"cell_type": "code",
"metadata": {
"id": "M-uOWlRv8YZX"
},
"source": [
"# The Sequential model is a linear stack of layers and is very common.\n",
"\n",
"model = Sequential()"
],
"execution_count": 10,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "iyoUd9WE8YZa"
},
"source": [
"## The first hidden layer"
]
},
{
"cell_type": "code",
"metadata": {
"id": "IdgP4lai8YZb"
},
"source": [
"# The first hidden layer is a set of 512 nodes (artificial neurons).\n",
"# Each node receives every element of the input vector, multiplies each one by its own weight, sums the results and adds a bias.\n",
"\n",
"model.add(Dense(512, input_shape=(784,))) #(784,) is not a typo -- that represents a 784 length vector!"
],
"execution_count": 11,
"outputs": []
},
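{
"cell_type": "markdown",
"metadata": {},
"source": [
"What `Dense(512)` computes can be sketched in NumPy as a matrix product plus a bias, `output = input @ W + b`. The random weights below are a stand-in for illustration only (Keras initializes and trains its own); the point is the shapes: a 784-length input times a 784 x 512 weight matrix gives 512 outputs, one per node, which is where this layer's 784 * 512 + 512 = 401,920 parameters come from."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x = rng.random((1, 784))  # one flattened, normalized image (random stand-in values)\n",
"W = rng.standard_normal((784, 512)) * 0.01  # one weight per (input element, node) pair\n",
"b = np.zeros(512)  # one bias per node\n",
"\n",
"y = x @ W + b  # each node sums all 784 weighted inputs and adds its bias\n",
"print(y.shape)  # (1, 512)\n",
"print(W.size + b.size)  # 401920"
],
"execution_count": null,
"outputs": []
},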
{
"cell_type": "code",
"metadata": {
"id": "c1MiY3u68YZd"
},
"source": [
"# An \"activation\" is a non-linear function applied to the output of the layer above.\n",
"# It transforms the new value of each node and, loosely speaking, decides whether that artificial neuron has \"fired\".\n",
"# The Rectified Linear Unit (ReLU) sets all negative values to zero before they are passed to the next layer.\n",
"# Those nodes are then considered not to have fired.\n",
"# Positive values of a node are passed through unchanged.\n",
"\n",
"model.add(Activation('relu'))"
],
"execution_count": 12,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "tQCRmiMI8YZg"
},
"source": [
"$$f(x) = \\max(0, x)$$\n",
"<img src = 'https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/relu.jpg' >"
]
},
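{
"cell_type": "markdown",
"metadata": {},
"source": [
"The formula above is a one-liner in NumPy. This minimal sketch applies ReLU element-wise to a few sample values chosen for illustration."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"def relu(x):\n",
"    # Element-wise max(0, x): negative values become 0, non-negative values pass through\n",
"    return np.maximum(0, x)\n",
"\n",
"print(relu(np.array([-2.0, -0.5, 0.0, 1.5, 3.0])).tolist())  # [0.0, 0.0, 0.0, 1.5, 3.0]"
],
"execution_count": null,
"outputs": []
},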
{
"cell_type": "code",
"metadata": {
"id": "kudeX_kN8YZg"
},
"source": [
"# During training, Dropout randomly zeroes a fraction (here 20%) of the outputs (i.e., disables their activation)\n",
"# Dropout helps protect the model from memorizing or \"overfitting\" the training data.\n",
"model.add(Dropout(0.2))"
],
"execution_count": 13,
"outputs": []
},
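{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch of what dropout does. This is our own illustration of the standard \"inverted dropout\" scheme that Keras' `Dropout` layer follows, not its actual code: during training each value is zeroed with probability `rate` and the survivors are scaled by `1 / (1 - rate)` so the expected sum is unchanged; at inference time the input passes through untouched."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"def dropout(x, rate, training):\n",
"    if not training:\n",
"        return x  # inference: no units are dropped\n",
"    keep = rng.random(x.shape) >= rate  # each unit survives with probability 1 - rate\n",
"    return np.where(keep, x / (1.0 - rate), 0.0)\n",
"\n",
"x = np.ones(10)\n",
"print(dropout(x, 0.2, training=True))   # zeros where dropped, 1.25 elsewhere\n",
"print(dropout(x, 0.2, training=False))  # unchanged"
],
"execution_count": null,
"outputs": []
},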
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "PY5XiXYc8YZj"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## Adding the second hidden layer"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "Tnz-mpIS8YZj"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# The second hidden layer appears identical to our first layer.\n",
|
|||
|
|
"# However, instead of each of the 512-node receiving 784-inputs from the input image data,\n",
|
|||
|
|
"# they receive 512 inputs from the output of the first 512-node layer.\n",
|
|||
|
|
"\n",
|
|||
|
|
"model.add(Dense(512))\n",
|
|||
|
|
"model.add(Activation('relu'))\n",
|
|||
|
|
"model.add(Dropout(0.2))"
|
|||
|
|
],
|
|||
|
|
"execution_count": 14,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "vkeHSS0h8YZm"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## The Final Output Layer"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "cndhIgBK8YZm"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# The final layer of 10 neurons in fully-connected to the previous 512-node layer.\n",
|
|||
|
|
"# The final layer of a FCN should be equal to the number of desired classes (10 in this case).\n",
|
|||
|
|
"model.add(Dense(10))"
|
|||
|
|
],
|
|||
|
|
"execution_count": 15,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "02Zk2VWL8YZp"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# The \"softmax\" activation represents a probability distribution over K different possible outcomes.\n",
|
|||
|
|
"# Its values are all non-negative and sum to 1.\n",
|
|||
|
|
"\n",
|
|||
|
|
"model.add(Activation('softmax'))"
|
|||
|
|
],
|
|||
|
|
"execution_count": 16,
|
|||
|
|
"outputs": []
|
|||
|
|
},
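{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch of the softmax function, applied to three hypothetical scores (logits). Subtracting the maximum before exponentiating avoids overflow and does not change the result. Keras' built-in softmax may be implemented differently internally, so treat this only as an illustration of the formula."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"def softmax(z):\n",
"    # Shift by the max for numerical stability (the output is unchanged)\n",
"    e = np.exp(z - np.max(z))\n",
"    # Normalize so the values are non-negative and sum to 1\n",
"    return e / e.sum()\n",
"\n",
"probs = softmax(np.array([2.0, 1.0, 0.1]))  # hypothetical logits\n",
"print(np.round(probs, 3))  # approximately [0.659 0.242 0.099]\n",
"print(probs.sum())         # 1.0, up to floating-point rounding"
],
"execution_count": null,
"outputs": []
},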
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "YReJLqWL8YZr",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"outputId": "acf9807b-a564-41d7-fbb7-a212242469e5"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# Summarize the built model\n",
|
|||
|
|
"\n",
|
|||
|
|
"model.summary()"
|
|||
|
|
],
|
|||
|
|
"execution_count": 17,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"Model: \"sequential\"\n",
|
|||
|
|
"_________________________________________________________________\n",
|
|||
|
|
" Layer (type) Output Shape Param # \n",
|
|||
|
|
"=================================================================\n",
|
|||
|
|
" dense (Dense) (None, 512) 401920 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation (Activation) (None, 512) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" dropout (Dropout) (None, 512) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" dense_1 (Dense) (None, 512) 262656 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_1 (Activation) (None, 512) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" dropout_1 (Dropout) (None, 512) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" dense_2 (Dense) (None, 10) 5130 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_2 (Activation) (None, 10) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
"=================================================================\n",
|
|||
|
|
"Total params: 669706 (2.55 MB)\n",
|
|||
|
|
"Trainable params: 669706 (2.55 MB)\n",
|
|||
|
|
"Non-trainable params: 0 (0.00 Byte)\n",
|
|||
|
|
"_________________________________________________________________\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
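{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Param #` column can be checked by hand. A `Dense` layer with `n_in` inputs and `n_out` neurons has `n_in * n_out` weights plus `n_out` biases; `Activation` and `Dropout` layers have no parameters. The sketch below reproduces the counts for the 784-512-512-10 network. The helper name `dense_params` is hypothetical."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def dense_params(n_in, n_out):\n",
"    # One weight per input-output pair plus one bias per output neuron\n",
"    return n_in * n_out + n_out\n",
"\n",
"layer_sizes = [784, 512, 512, 10]\n",
"counts = [dense_params(a, b) for a, b in zip(layer_sizes[:-1], layer_sizes[1:])]\n",
"print(counts)       # [401920, 262656, 5130]\n",
"print(sum(counts))  # 669706, matching 'Total params' in the summary"
],
"execution_count": null,
"outputs": []
},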
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "UcndsYyk8YZv"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## Compiling the model\n",
|
|||
|
|
"\n",
|
|||
|
|
"In this notebook, Keras runs on top of TensorFlow (older versions of Keras could also use Theano as a backend). These frameworks let you define a *computation graph* in Python, which is then compiled and run efficiently on the CPU or GPU without the overhead of the Python interpreter.\n",
|
|||
|
|
"\n",
|
|||
|
|
"When compiling a model, Keras asks you to specify your **loss function** and your **optimizer**. The loss function we'll use here is called *categorical cross-entropy*; it is well suited to comparing two probability distributions.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Our predictions are probability distributions across the ten different digits (e.g. \"we're 80% confident this image is a 3, 10% sure it's an 8, 5% it's a 2, etc.\"), and the target is a probability distribution with 100% for the correct category and 0% for everything else. The cross-entropy measures how different the predicted distribution is from the target distribution. [More detail at Wikipedia](https://en.wikipedia.org/wiki/Cross_entropy).\n",
|
|||
|
|
"\n",
|
|||
|
|
"The optimizer determines how the model's weights are updated by **gradient descent**. The size of each step taken down the gradient is controlled by the **learning rate**."
|
|||
|
|
]
|
|||
|
|
},
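{
"cell_type": "markdown",
"metadata": {},
"source": [
"A worked example of categorical cross-entropy for the prediction described above (80% for a 3, 10% for an 8, 5% for a 2). The remaining 5% is assigned to the digit 7 purely for illustration. Because the target is one-hot, the loss reduces to -log of the probability assigned to the correct class, here -log(0.8), approximately 0.2231. This NumPy version is a sketch of the formula, not Keras' own implementation."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"y_true = np.zeros(10)\n",
"y_true[3] = 1.0  # one-hot target: the image is a 3\n",
"\n",
"y_pred = np.zeros(10)\n",
"y_pred[[3, 8, 2, 7]] = [0.80, 0.10, 0.05, 0.05]  # hypothetical prediction\n",
"\n",
"# Clip to avoid log(0) for classes with zero predicted probability\n",
"eps = 1e-7\n",
"ce_loss = -np.sum(y_true * np.log(np.clip(y_pred, eps, 1.0)))\n",
"print(round(float(ce_loss), 4))  # 0.2231, i.e. -log(0.8)"
],
"execution_count": null,
"outputs": []
},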
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "Wsv1Z9LC8YZw"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"<img src = \"https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/gradient_descent.png\" >"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "vZWeBx-58YZx"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"<img src = \"https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/learning_rate.png\" >"
|
|||
|
|
]
|
|||
|
|
},
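{
"cell_type": "markdown",
"metadata": {},
"source": [
"The effect of the learning rate can be seen on a hypothetical one-parameter problem (not the MNIST loss): minimizing f(w) = (w - 3)^2, whose gradient is 2(w - 3). In this sketch, a very small learning rate creeps slowly toward w = 3, a moderate one converges quickly, a learning rate of 0.9 oscillates around 3 while still converging, and one that is too large (1.1) overshoots further on every step and diverges."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def gradient_descent(lr, steps=20, w=0.0):\n",
"    # Minimize f(w) = (w - 3)**2 starting from w = 0\n",
"    for _ in range(steps):\n",
"        grad = 2 * (w - 3)  # derivative of f at the current w\n",
"        w = w - lr * grad   # step downhill, scaled by the learning rate\n",
"    return w\n",
"\n",
"for lr in (0.01, 0.1, 0.9, 1.1):\n",
"    print(f'learning_rate={lr}: w after 20 steps = {gradient_descent(lr):.3f}')"
],
"execution_count": null,
"outputs": []
},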
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "ptVA6aSg8YZy"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"So are smaller learning rates better? Not quite! It's important for an optimizer not to get stuck in local minima while neglecting the global minimum of the loss function. Sometimes that means trying a larger learning rate to jump out of a local minimum."
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "zSGiajLT8YZ1"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"<img src = 'https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/complicated_loss_function.png' >"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "7RUmQxfp8YZ2"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# Let's use the Adam optimizer for learning\n",
|
|||
|
|
"adam = tf.optimizers.Adam(learning_rate=0.001)\n",
|
|||
|
|
"model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])"
|
|||
|
|
],
|
|||
|
|
"execution_count": 18,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "d-u8EqW48YZ3"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## Train the model!\n",
|
|||
|
|
"This is the fun part!"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "rTynqX-H8YZ4"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"The batch size determines how many training examples are used in each step to compute the loss, its gradients (via backpropagation), and the resulting weight update. Larger batch sizes let the network finish each epoch faster; however, there are other factors beyond training speed to consider.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Too large a batch size smooths out the noise in the loss landscape, so the optimizer can settle into a local minimum, believing it has found the global minimum.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Too small a batch size makes the loss (and its gradients) very noisy, and the optimizer may never settle near the global minimum.\n",
|
|||
|
|
"\n",
|
|||
|
|
"So a good batch size may take some trial and error to find!"
|
|||
|
|
]
|
|||
|
|
},
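{
"cell_type": "markdown",
"metadata": {},
"source": [
"One concrete effect of the batch size is the number of weight updates per epoch, ceil(training examples / batch size). A short sketch for the 60,000 MNIST training images: with `batch_size=128` it gives 469 steps, matching the `469/469` counter in the training log below. The other two values are the batch sizes suggested in the experiments at the end of this part."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import math\n",
"\n",
"n_train = 60000  # size of the MNIST training set\n",
"for batch_size in (32, 128, 10000):\n",
"    steps = math.ceil(n_train / batch_size)  # the last batch may be smaller\n",
"    print(f'batch_size={batch_size}: {steps} steps per epoch')\n",
"# batch_size=32: 1875, batch_size=128: 469, batch_size=10000: 6"
],
"execution_count": null,
"outputs": []
},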
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "MxhTSlBI8YZ5",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"outputId": "88c04566-8028-4fd6-fea8-0a2c728eaa64"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"model.fit(X_train, Y_train,\n",
|
|||
|
|
" batch_size=128, epochs=5,\n",
|
|||
|
|
" verbose=1)"
|
|||
|
|
],
|
|||
|
|
"execution_count": 19,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"Epoch 1/5\n",
|
|||
|
|
"469/469 [==============================] - 5s 4ms/step - loss: 0.2438 - accuracy: 0.9265\n",
|
|||
|
|
"Epoch 2/5\n",
|
|||
|
|
"469/469 [==============================] - 2s 3ms/step - loss: 0.1011 - accuracy: 0.9684\n",
|
|||
|
|
"Epoch 3/5\n",
|
|||
|
|
"469/469 [==============================] - 2s 5ms/step - loss: 0.0716 - accuracy: 0.9775\n",
|
|||
|
|
"Epoch 4/5\n",
|
|||
|
|
"469/469 [==============================] - 2s 5ms/step - loss: 0.0564 - accuracy: 0.9816\n",
|
|||
|
|
"Epoch 5/5\n",
|
|||
|
|
"469/469 [==============================] - 2s 3ms/step - loss: 0.0450 - accuracy: 0.9852\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "execute_result",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<keras.src.callbacks.History at 0x7f0d88148e20>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"execution_count": 19
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "YufwyCC38YZ6"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"The two numbers reported for each epoch are, in order, the network's loss and its accuracy on the training set. But how does it do on data it did not train on?"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "wQORnbe08YZ7"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## Evaluate Model's Accuracy on Test Data"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "SX3n6p108YZ7",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"outputId": "989db8c3-3559-47bd-b32d-a84a3855f339"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"score = model.evaluate(X_test, Y_test)\n",
|
|||
|
|
"print('Test score:', score[0])\n",
|
|||
|
|
"print('Test accuracy:', score[1])"
|
|||
|
|
],
|
|||
|
|
"execution_count": 20,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"313/313 [==============================] - 1s 3ms/step - loss: 0.0637 - accuracy: 0.9812\n",
|
|||
|
|
"Test score: 0.0636596530675888\n",
|
|||
|
|
"Test accuracy: 0.9811999797821045\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "NYD0_NjusHRL"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"def plot_confusion_matrix(cm, classes,\n",
|
|||
|
|
" normalize=False,\n",
|
|||
|
|
" title='Confusion matrix',\n",
|
|||
|
|
" cmap=plt.cm.Blues):\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" This function prints and plots the confusion matrix.\n",
|
|||
|
|
" Normalization can be applied by setting `normalize=True`.\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" if normalize:\n",
|
|||
|
|
" cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
|
|||
|
|
" print(\"Normalized confusion matrix\")\n",
|
|||
|
|
" else:\n",
|
|||
|
|
" print('Confusion matrix, without normalization')\n",
|
|||
|
|
"\n",
|
|||
|
|
" print(cm)\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
" plt.imshow(cm, interpolation='nearest', cmap=cmap)\n",
|
|||
|
|
" plt.title(title)\n",
|
|||
|
|
" plt.colorbar()\n",
|
|||
|
|
" tick_marks = np.arange(len(classes))\n",
|
|||
|
|
" plt.xticks(tick_marks, classes, rotation=45)\n",
|
|||
|
|
" plt.yticks(tick_marks, classes)\n",
|
|||
|
|
"\n",
|
|||
|
|
" fmt = '.2f' if normalize else 'd'\n",
|
|||
|
|
" thresh = cm.max() / 2.\n",
|
|||
|
|
" for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n",
|
|||
|
|
" plt.text(j, i, format(cm[i, j], fmt),\n",
|
|||
|
|
" horizontalalignment=\"center\",\n",
|
|||
|
|
" color=\"white\" if cm[i, j] > thresh else \"black\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" plt.ylabel('True label')\n",
|
|||
|
|
" plt.xlabel('Predicted label')\n",
|
|||
|
|
" plt.tight_layout()"
|
|||
|
|
],
|
|||
|
|
"execution_count": 21,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "yWaT13H78YZ9"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"### Inspecting the output\n",
|
|||
|
|
"\n",
|
|||
|
|
"It's always a good idea to inspect the output and make sure everything looks sane. Here we'll look at some examples it gets right, and some examples it gets wrong."
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "LFECu0nE8YZ-",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 1000
|
|||
|
|
},
|
|||
|
|
"outputId": "f534eaf7-27ee-4861-8df4-aa8d9c03b489"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# model.predict outputs the probability of each class; np.argmax picks the highest-probability class\n",
|
|||
|
|
"# according to the trained classifier for each input example.\n",
|
|||
|
|
"predict = model.predict(X_test)\n",
|
|||
|
|
"predicted_classes = np.argmax(predict,axis=1)\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Check which items we got right / wrong\n",
|
|||
|
|
"correct_indices = np.nonzero(predicted_classes == y_test)[0]\n",
|
|||
|
|
"incorrect_indices = np.nonzero(predicted_classes != y_test)[0]\n",
|
|||
|
|
"cnf_matrix = confusion_matrix(y_test, predicted_classes)\n",
|
|||
|
|
"\n",
|
|||
|
|
"class_names = [str(i) for i in range(10)]\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Plot non-normalized confusion matrix\n",
|
|||
|
|
"plt.figure()\n",
|
|||
|
|
"plot_confusion_matrix(cnf_matrix, classes=class_names,\n",
|
|||
|
|
" title='Confusion matrix, without normalization')\n",
|
|||
|
|
"\n",
|
|||
|
|
"plt.show()\n"
|
|||
|
|
],
|
|||
|
|
"execution_count": 22,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"313/313 [==============================] - 1s 2ms/step\n",
|
|||
|
|
"Confusion matrix, without normalization\n",
|
|||
|
|
"[[ 974 1 0 0 0 0 0 2 2 1]\n",
|
|||
|
|
" [ 0 1124 3 0 0 0 2 1 5 0]\n",
|
|||
|
|
" [ 2 1 1012 3 3 0 2 6 2 1]\n",
|
|||
|
|
" [ 0 1 2 992 0 5 0 4 1 5]\n",
|
|||
|
|
" [ 0 0 4 0 952 0 4 1 0 21]\n",
|
|||
|
|
" [ 2 0 0 6 1 872 2 2 2 5]\n",
|
|||
|
|
" [ 4 2 0 0 7 4 941 0 0 0]\n",
|
|||
|
|
" [ 0 4 5 3 0 0 0 1012 0 4]\n",
|
|||
|
|
" [ 1 0 6 3 4 5 2 6 940 7]\n",
|
|||
|
|
" [ 1 2 0 3 3 0 0 7 0 993]]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 2 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1UAAAN6CAYAAABmBWMlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAACzsElEQVR4nOzdd3gU1dvG8XsTSCE9QBIiLYBShIBUQ1EEBBFELCgKEopgoRhRRCw0keqriFIEFVDhR1GRIqD0Ir0pooI0QSAJCCQklECy7x+RxTVBEibJzCbfD9dcurOzs/eeTGb3yTlz1ma32+0CAAAAANwUN7MDAAAAAIAro6gCAAAAAAMoqgAAAADAAIoqAAAAADCAogoAAAAADKCoAgAAAAADKKoAAAAAwACKKgAAAAAwoJDZAQAAAADcnIsXLyolJcXsGFni4eEhLy8vs2PkCooqAAAAwAVdvHhR3n5FpSvnzY6SJWFhYTp06FC+LKwoqgAAAAAXlJKSIl05L88q0ZK7h9lx/ltqimJ/ma6UlBSKKgAAAAAW4+4hm8WLKrvZAXIZRRUAAADgymxu6YuVWT2fQfn71QEAAABALqOoAgAAAAADKKoAAAAAwACuqQIAAABcmU2SzWZ2iv9m8XhG0VMFAAAAAAZQVAEAAACAAQz/AwAAAFwZU6qbLn+/OgAAAADIZRRVAAAAAGAAw/8AAAAAV2azucDsfxbPZxA9VQAAAABgAEUVAAAAABhAUQUAAAAABnBNFQAAAODKmFLddPn71QEAAABALqOoAgAAAAADGP4HAAAAuDKmVDcdPVUAAAAAYABFFQAAAAAYQFEFAAAAAAZwTRUAAADg0lxgSvV83peTv18dAAAAAOQyiioAAAAAMIDhfwAAAIArY0p109FTBQAAAAAGUFQBAAAAgAEM/wMAAABcmc0FZv+zej6D8verAwAAAIBcRlEFAAAAAAZQVAEAAACAAVxTBQAAALgyplQ3HT1VAAAAAGAARRUAAAAAGMDwPwAAAMCVMaW66fL3qwMAAACAXEZRBQAAAAAGMPwPAAAAcGXM/mc6eqoAAAAAwACKKgAAAAAwgKIKAAAAAAzgmioAAADAlTGluuny96sDAAAAgFxGUQUAAAAABjD8DwAAAHBlNpv1h9cxpToAAAAA4HooqgAAAADAAIb/AQAAAK7MzZa+WJnV8xlETxUAAAAAGEBRBQAAAAAGUFQBAAAAgAFcUwUAAAC4MpubC0ypbvF8BuXvVwcAAAAAuYyiCgAAAAAMYPgfAAAA4MpstvTFyqyezyB6qgAAAADAAIoqAAAAADCAogoAAAAADOCaKgAAAMCVMaW66fL3qwMAAACAXEZRBQAAAAAGMPwPAAAAcGVMqW46eqoAAAAAwACKKgAAAAAwgOF/AAAAgCtj9j/T5e9XBwAAAAC5jKIKAAAAAAygqAIAAAAAA7imCgAAAHBlTKluOnqqAAAAAMAAiioAAAAAMIDhfwAAAIArY0p10+XvVwcAAAAAuYyiCsjHfv/9dzVv3lwBAQGy2Wz65ptvcnT/hw8fls1m07Rp03J0v/lB2bJl1blzZ7NjZJCdn9nVbd95553cD4ZMDR48WLZ/Xdxt1rFl1WMaAKyAogrIZQcOHNAzzzyjcuXKycvLS/7+/mrQoIHef/99XbhwIVefOzo6Wrt379bbb7+tzz//XLVr187V58uPfvnlFw0ePFiHDx82O0quWbx4sQYPHmx2jAyGDx+e438IwH/bsGGDBg8erLNnz5odBUB2XJ39z+pLPsY1VUAu+vbbb9WuXTt5enqqU6dOqlq1qlJSUrR+/Xr169dPe/bs0eTJk3PluS9cuKCNGzfq9ddfV69evXLlOcqUKaMLFy6ocOHCubJ/K/jll180ZMgQNW7cWGXLls3y4/bu3Ss3N+v93Sqzn9nixYs1fvx4yxVWw4cP16OPPqq2bduaHcVScvPY2rBhg4YMGaLOnTsrMDAwz5
4XAFwdRRWQSw4dOqT27durTJkyWrlypUqUKOG4r2fPntq/f7++/fbbXHv+kydPSlKGD0Y5yWazycvLK9f272rsdrsuXrwob29veXp6mh0nU/zMjElOTpaPj4+pGcw6tqx6TAOAFfAnJyCXjB49WklJSfrkk0+cCqqrKlSooBdeeMFx+8qVK3rrrbdUvnx5eXp6qmzZsnrttdd06dIlp8eVLVtWrVu31vr161W3bl15eXmpXLly+uyzzxzbDB48WGXKlJEk9evXTzabzdHL0rlz50x7XDK7dmPZsmVq2LChAgMD5evrq4oVK+q1115z3H+963NWrlypRo0aycfHR4GBgXrwwQf166+/Zvp8+/fvd/xVPCAgQF26dNH58+ev37B/a9y4sapWraqffvpJd999t4oUKaIKFSroyy+/lCStWbNG9erVk7e3typWrKjly5c7Pf6PP/7Q888/r4oVK8rb21tFixZVu3btnIb5TZs2Te3atZMk3XPPPbLZbLLZbFq9erWkaz+L7777TrVr15a3t7c++ugjx31Xrz+x2+265557VLx4ccXHxzv2n5KSomrVqql8+fJKTk6+4Wv+p759+6po0aKy2+2Odb1795bNZtO4ceMc6+Li4mSz2TRx4kRJGX9mnTt31vjx4yXJ8fr+fRxI0uTJkx3HZp06dbR169YM22Tl557V489msyk5OVnTp093ZPqv63lWr14tm82mOXPm6O2331bJkiXl5eWlpk2bav/+/Rm2nzt3rmrVqiVvb28VK1ZMHTt21LFjxzJk9fX11YEDB3T//ffLz89PHTp0cOTr1auX5s6dqypVqsjb21tRUVHavXu3JOmjjz5ShQoV5OXlpcaNG2cYPrpu3Tq1a9dOpUuXlqenp0qVKqUXX3wxS0OC/31t0z9/bv9erj7vTz/9pM6dOzuGIYeFhalr167666+/nH4G/fr1kyRFRERk2Edm11QdPHhQ7dq1U3BwsIoUKaI777wzwx+LsvuzAQBXRE8VkEsWLlyocuXKqX79+lna/umnn9b06dP16KOP6qWXXtLmzZs1YsQI/frrr5o3b57Ttvv379ejjz6qbt26KTo6Wp9++qk6d+6sWrVq6fbbb9fDDz+swMBAvfjii3riiSd0//33y9fXN1v59+zZo9atWysyMlJDhw6Vp6en9u/frx9++OE/H7d8+XK1bNlS5cqV0+DBg3XhwgV98MEHatCggXbs2JHhA/Vjjz2miIgIjRgxQjt27NDHH3+skJAQjRo16oYZz5w5o9atW6t9+/Zq166dJk6cqPbt22vGjBmKiYnRs88+qyeffFJjxozRo48+qqNHj8rPz0+StHXrVm3YsEHt27dXyZIldfjwYU2cOFGNGzfWL7/8oiJFiuiuu+5Snz59NG7cOL322muqXLmyJDn+K6UPiXriiSf0zDPPqHv37qpYsWKGnDabTZ9++qkiIyP17LPP6uuvv5YkDRo0SHv27NHq1auz3fvRqFEjvffee9qzZ4+qVq0qKf2Dupubm9atW6c+ffo41knSXXfdlel+nnnmGR0/flzLli3T559/nuk2M2fO1Llz5/TMM8/IZrNp9OjRevjhh3Xw4EHHMMLs/txv5PPPP9fTTz+tunXrqkePHpKk8uXL3/BxI0eOlJubm15++WUlJCRo9OjR6tChgzZv3uzYZtq0aerSpYvq1KmjESNGKC4uTu+//75++OEH7dy506l398qVK2rRooUaNmyod955R0WKFHHct27dOi1YsEA9e/aUJI0YMUKtW7fWK6+8ogkTJuj555/XmTNnNHr0aHXt2lUrV650PHbu3Lk6f/68nnvuORUtWlRbtmzRBx98oD///FNz587Ndlv92xtvvKH4+HjH7/2yZct08OBBdenSRWFhYY6hx3v27NGmTZtks9n08MMPa9++ffrf//6n9957T8WKFZMkFS9ePNPnjYuLU/369XX+/Hn16dNHRYsW1fTp09WmTRt9+eWXeu
ihh7L9swFws1xgSvX83pdjB5DjEhIS7JLsDz74YJa237Vrl12S/emnn3Za//LLL9sl2VeuXOlYV6ZMGbsk+
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
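{
"cell_type": "markdown",
"metadata": {},
"source": [
"The confusion matrix also gives the accuracy directly. Correct predictions lie on the diagonal, so accuracy is the trace divided by the total count; dividing each diagonal entry by its row sum gives the per-class recall, since rows are true labels. For the matrix printed above, the diagonal sums to 9812 out of 10000, i.e. 0.9812, the same test accuracy reported by `model.evaluate`. The sketch below uses a hypothetical 2-class matrix; the helper `summarize_confusion` is illustrative."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"def summarize_confusion(cm):\n",
"    cm = np.asarray(cm)\n",
"    # Correct predictions lie on the diagonal\n",
"    accuracy = np.trace(cm) / cm.sum()\n",
"    # Rows are true labels, so diagonal / row sum is the per-class recall\n",
"    recall = np.diag(cm) / cm.sum(axis=1)\n",
"    return accuracy, recall\n",
"\n",
"toy_cm = np.array([[8, 2],\n",
"                   [1, 9]])  # hypothetical 2-class confusion matrix\n",
"toy_acc, toy_recall = summarize_confusion(toy_cm)\n",
"print(toy_acc)     # 0.85\n",
"print(toy_recall)  # [0.8 0.9]"
],
"execution_count": null,
"outputs": []
},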
|
|||
|
|
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "Ct14A0rsD-Ja"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"**Correct predictions**"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "ew6oWSoY8YaA",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 906
|
|||
|
|
},
|
|||
|
|
"outputId": "040cbfff-02a8-431e-da94-8a67510fa803"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"def show_samples(indices, preds, images, labels, count=3, names = []):\n",
|
|||
|
|
" plt.figure()\n",
|
|||
|
|
" for i, sample in enumerate(indices[:count**2]):\n",
|
|||
|
|
" pred_id = int(np.argmax(preds[sample]))\n",
|
|||
|
|
" real_id = int(labels[sample])\n",
|
|||
|
|
" pred_score = preds[sample][pred_id]\n",
|
|||
|
|
" real_score = preds[sample][real_id]\n",
|
|||
|
|
" plt.subplot(count,count,i+1)\n",
|
|||
|
|
" plt.imshow(images[sample].reshape(28,28), cmap='gray', interpolation='none')\n",
|
|||
|
|
" plt.axis('off')\n",
|
|||
|
|
" if len(names) > 0:\n",
|
|||
|
|
" plt.title(\"P: {} ({:.2f})\\nE: {} ({:.2f})\".format(names[pred_id], pred_score, names[real_id], real_score))\n",
|
|||
|
|
" else:\n",
|
|||
|
|
" plt.title(\"P: {} ({:.2f})\\nE: {} ({:.2f})\".format(pred_id, pred_score, real_id, real_score))\n",
|
|||
|
|
"\n",
|
|||
|
|
" plt.tight_layout()\n",
|
|||
|
|
"\n",
|
|||
|
|
"show_samples(correct_indices, predict, X_test, y_test, 5)"
|
|||
|
|
],
|
|||
|
|
"execution_count": 23,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 25 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA04AAAN5CAYAAAA/32uUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAADmT0lEQVR4nOzdd3RU1fbA8T2kF4EkhB5C7x1BaYYu0iwERQFRURRQERug+FCKCihYHkWKoAg+EOWBBRQQEURBmogh0gIh9FBCCZCQnN8fvsyPc+8kdyaZSTLw/ayV5dpnzj2zIzs39+TOucemlFICAAAAAMhWkYJOAAAAAAAKOyZOAAAAAGCBiRMAAAAAWGDiBAAAAAAWmDgBAAAAgAUmTgAAAABggYkTAAAAAFhg4gQAAAAAFpg4AQAAAIAFJk7/M3jwYOnYsWNBp5GjlStXSmhoqJw6daqgU0EhNHHiRKlZs6ZkZmYWdCrZiouLE19fX9m1a1dBp4JCiBqGt+NaAt6OGrag8tHcuXOViNi/AgICVLVq1dSQIUPU8ePHczXm6NGjtTGNXxs2bLAc48CBA8rPz0/9+OOPWvu0adNUbGysioqKUiKi+vfv71JuGRkZasKECapixYoqICBA1atXTy1cuNBh37i4OHXnnXeqkJAQFRYWpvr27atOnjxp6tegQQM1bNgwl/KA+3iihnfv3q1eeukl1aBBAxUaGqpKly6tunTpon7//Xenx0hJSVHh4eHq448/1tr/85//qD59+qiqVasqEVExMTEu5zd79mxVs2ZNFRAQoKpWrao++OADh/2SkpJUr169VLFixdQtt9yievToofbv32/q16NHD3Xvvfe6nAfcwxM1rJRS48aNU927d1clS5ZUIqJGjx7t0vHUMJzlqRp25Xe2I1xLwFmequHrffbZZ0pEVEhIiNPHUMPWCmTiNGbMGDV//nw1a9Ys1b9/f1WkSBFVqVIldenSJZfH/OOPP9T8+fNNX1FRUSosLExdvXrVcoyhQ4eq6tWrm9qjo6NVeHi46ty5s/L19XW5UEaMGKFERD3xxBNq5syZqmvXrkpE1Oeff671O3z4sCpRooSqUqWKev/999X48eNVWFiYatCggSn/adOmqeDgYHX+/HmXcoF7eKKGX3jhBVW8eHE1YMAA9dFHH6mJEyeqKlWqKB8fH7Vq1SqnxpgyZYoqWrSounz5stYeExOjQkNDVdu2bVVYWJjLF50zZsxQIqJ69uypZs6cqfr166dERL399ttavwsXLqhq1aqpkiVLqgkTJqjJkyerqKgoVb58eZWcnKz1/e6775SIqH379rmUC9zDEzWslFIiokqXLq3uvPPOXE2cqGE4y1M17Ozv7OxwLQFneaqGs1y4cEGVLVtWhYSEuDRxooatFcjEyfiX9Oeff16JiEt/2clJYmKistls6oknnrDsm5aWpkqUKKFGjRpleu3gwYMqMzNTKaVUSEiIS4WSlJSk/Pz81JAhQ+xtmZmZqnXr1qp8+fLq2rVr9vZBgwapoKAgdejQIXvbqlWrlIiojz76SBv3xIkTysfHR82ZM8fpXOA+nqjhLVu2qAsXLmhtycnJKjIyUrVs2dKpMerXr6/69u1rak9MTFQZGRlKKaXq1Knj0kVnamqqioiIUF27dtXa+/Tpo0JCQtSZM2fsbRMmTFAiojZv3mxv2717t/Lx8VEjR47Ujk9LS1NhYWHqtddeczoXuI+nzsMJCQlKKaVOnTqVq4kTNQxneaKGXfmd7QjXEnCFp6+Hhw8frmrUqGE/1zmDGnZOoVjj1K5dOxERSUhIsLft379f9u/fn6vxPv/8c1FKSZ8+fSz7btiwQZKTk6VDhw6m16Kjo8Vms+Uqh2XLlkl6eroMHjzY3maz2WTQoEGSlJQkv/76q739yy+/lG7dukmFChXsbR06dJDq1avL4sWLtXFLliwp9evXl2XLluUqL3hGXmq4SZMmEhoaqrVFRE
RI69atZffu3ZbHJyQkyM6dOx3WcFRUlBQpkrsf87Vr18rp06e1GhYRGTJkiFy6dEm+/fZbe9uSJUukadOm0rRpU3tbzZo1pX379qYa9vPzkzZt2lDDhUxez8MVK1bM9XtTw3CHvNSwK7+zHeFaAu7gjuvhvXv3ypQpU2Ty5Mni6+vr9HHUsHMKxcQpqyAiIiLsbe3bt5f27dvnarwFCxZIVFSU3HHHHZZ9N27cKDabTRo1apSr98rO9u3bJSQkRGrVqqW1N2vWzP66iMiRI0fk5MmTcuutt5rGaNasmb3f9Zo0aSIbN250a77IG3fXsIjI8ePHpUSJEpb9smqhcePGuX4vR7Jqz1ibTZo0kSJFithfz8zMlJ07d2Zbw/v375cLFy6Yxti1a5ecP3/erTkj9zxRw86ihuEOealhZ39nZ4drCbiDO87Dzz33nLRt21a6dOni0ntTw84pkIlTSkqKJCcnS1JSkixatEjGjBkjQUFB0q1btzyP/ddff8nOnTvlwQcfdGp2HB8fL+Hh4VK0aNE8v/f1jh07JqVKlTLlUKZMGREROXr0qL3f9e3GvmfOnJGrV69q7ZUrV5bk5GQ5efKkW3OG8zxZwyIi69evl19//VUeeOABy77x8fEiIlKpUiW3vHeWY8eOiY+Pj5QsWVJr9/f3l4iICHsNZ9VodjUs8v/1nqVy5cqSmZlpzx35z9M17ApqGLnhzhp29nd2driWQG64+zz87bffyg8//CCTJ092+Vhq2DnO38NzI+NtwOjoaFmwYIGUK1fO3nbw4MFcjb1gwQIREac+picicvr0aQkLC8vVe+Xk8uXLEhAQYGoPDAy0v379f636Xv96Vr7JycmmCwLkD0/W8MmTJ+Whhx6SSpUqycsvv2zZ//Tp0+Lr62v6uF9eXb58Wfz9/R2+FhgY6HINX+/6GkbB8GQNu4oaRm64s4ad/Z2dHa4lkBvurOG0tDQZNmyYPPXUU1K7dm2Xc6GGnVMgE6epU6dK9erVxdfXV0qVKiU1atTI9WfYr6eUkoULF0rdunWlfv36Lh3nbkFBQaaZsYjIlStX7K9f/19n+mbJyje3nzdF3nmqhi9duiTdunWTCxcuyIYNG9x+IemKoKAgSUtLc/jalStXqGEv56kaLkyo4RubO2vY2d/ZOeFaAq5yZw1PmTJFkpOT5Y033sh1PtSwtQKZODVr1szhZxjz6pdffpFDhw7JW2+95fQxERERcvbsWbfnUqZMGVm7dq0opbR/0KxbkWXLlrX3u779eseOHZPw8HDT7DsrX2fWv8AzPFHDaWlpct9998nOnTvl+++/l7p16zp1XEREhFy7dk0uXLggt9xyi9vyKVOmjGRkZMjJkye1v+SkpaXJ6dOn7TWcVaPZ1bDI/9d7Fmq44HnqPJwb1DByw5017Ozv7OxwLYHccFcNp6SkyLhx42Tw4MFy/vx5+9rLixcvilJKDh48KMHBwTnelaGGnXND/XlxwYIFYrPZ5KGHHnL6mJo1a8rZs2clJSXFrbk0bNhQUlNTTU9F27Rpk/11EZFy5cpJZGSkbNmyxTTG5s2b7f2ul5CQICVKlJDIyEi35oyCk5mZKQ8//LCsWbNGFi5cKDExMU4fW7NmTRHRn8LjDlm1Z6zNLVu2SGZmpv31IkWKSL169RzW8KZNm6Ry5cqmi+GEhAQpUqSIVK9e3a05wztRwyhozv7Ozg7XEihIZ8+elYsXL8rEiROlUqVK9q8vv/xSUlNTpVKlSjJw4MAcx6CGnVNoJ06uPn4xPT1dvvjiC2nVqpX2GEMrzZs3F6WUbN26NTdpisg/M/34+Hit2O6++27x8/OTadOm2duUUjJjxgwpV66ctGjRwt7es2dP+eabb+Tw4cP2tjVr1siePXukV69epvfbunWrNG/ePNf5In+4UsPPPPOMLFq0SKZNmyb33XefS++TVQuOTjbOSk1Nlfj4eG29Rrt27SQ8PFymT5+u9Z0+fboEBwdL165d7W
2xsbHy+++/azn8/fff8uOPP2Zbw3Xq1JFixYrlOmd4Xl62hXAFNQxPcbaGXfmd7QjXEvAUZ2q4ZMmSsnTpU
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "fhZp92ZxEGn4"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"**Wrong predictions**"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "FrAIwG1CEFxe",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 906
|
|||
|
|
},
|
|||
|
|
"outputId": "4717e081-20c2-410b-8cf9-c8161eed590d"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"show_samples(incorrect_indices, predict, X_test, y_test, 5)"
|
|||
|
|
],
|
|||
|
|
"execution_count": 24,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 25 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA04AAAN5CAYAAAA/32uUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeXxMV/8H8M9k32xJLEFELLHWltLad7WVKlqqqFpqV12oomqJ1vJQlBBankdpqdaDtqitbZQqqbWkFKl9SagliSSS7++PPpmfc+9k7kRmTCKf9+uVV1/fM+eee0a/c+eeuffcYxIRAREREREREWXJxdkdICIiIiIiyu04cCIiIiIiIjLAgRMREREREZEBDpyIiIiIiIgMcOBERERERERkgAMnIiIiIiIiAxw4ERERERERGeDAiYiIiIiIyAAHTkRERERERAY4cPqfoUOHonXr1s7uhlVbtmyBn58frl+/7uyuUC7EHKa8bubMmahcuTIyMjKc3ZUsHT9+HG5ubjh27Jizu0K5EI/DlNflhRxevHgxypQpg5SUlEe/c3mEli9fLgDMf56enlKxYkUZNmyYXLly5aHbPXXqlHTt2lUKFy4s3t7e0rBhQ9m5c6fN2585c0bc3d0tbrNs2TKpXLmyeHp6SoUKFWT+/Pk2tblr1y7lvT74t3fvXnO9xMRE+fjjj6V169ZSokQJ8fPzk1q1asmiRYvk/v37unZr1qwpo0ePtvm9kX3lpxzWmjZtmgCQatWq6V7bunWrvPrqq1KtWjVxcXGRkJCQLNthDjuXI3L4xIkT8vbbb0vNmjXFz89PSpQoIe3bt5f9+/fb3MatW7fE399fPv30U6X8iy++kF69ekmFChUEgDRt2jTb/bP1M3DhwgXp3r27FCpUSAoUKCCdOnWS06dP6+p16tRJunTpku1+kH046jicnp4uM2bMkLJly4qnp6c88cQTsnr1apu3d9Rx+OTJk/Liiy9KqVKlxNvbWypVqiSTJ0+WxMREcx2eS+QtjsjhSZMmZXneCUB2795t2IajcvjAgQPyzDPPSIECBcTPz09at24tBw8e1NVLTU2V999/X0JDQ8XDw0NCQ0Nl6tSpkpaWptRLTk6W4sWLy7x582zug704ZeA0ZcoUWblypSxdulT69u0rLi4uEhoaqhwEbHXu3DkJDAyU4sWLS0REhHz00UdSs2ZNcXNzkx9//NGmNkaNGiVhYWG68sWLFwsA6dq1q0RFRUnv3r0FgHz44YeGbWYOnEaOHCkrV65U/q5fv26ud/ToUTGZTNKqVSuZOXOmLF68WLp06SIApE+fPrp2Fy1aJD4+PnL79m2b3hvZV37K4QedP39efHx8xNfX1+LAqW/fvuLl5SUNGjSQ0qVLWx04MYedyxE5/Oabb0rhwoWlf//+smTJEpk5c6aUL19eXF1dZdu2bTa1MXfuXClYsKAkJycr5U2bNhU/Pz9p3ry5FClSJNsDJ1s/A3fu3JGKFStKsWLFZMaMGTJnzhwJDg6W0qVLS3x8vFL3u+++EwDy559/ZqsvZB+OyGERkXfeeUcAyMCBAyUqKko6dOggAOTzzz+3aXtHHIfPnTsnhQsXlpCQEPnggw9kyZIl8sorrwgA6dSpk7kezyXyFkfk8OHDh3XnmytXrpTg4GApUqSIpKSkGLbhiByOiYkRLy8vqVixosyePVtmzpwpZcuWlYIFC0psbKxS94UXXhCTyST9+/eXyMhI6du3r/kzqTVmzBgJCQmRjIwMwz7Yk1MGTtpfId944w0BkK1fdjINHTpU3NzclH/8xMRECQ4Oljp16hhun5qaKoGBgTJhwgSlPCkpSQICAqRDhw5Kea9evcTX11du3Lhhtd3MgdOXX35ptd7169fl2LFjuvJ+/foJADl16pRSfvXqVXF1dZVPPvnEarvkGPkphx/04osvSosWLaRp06YWB04XL16U1NRUERHp0KGD1YETc9i5HJHDBw
4ckDt37ihl8fHxUrRoUWnYsKFNbdSoUUNefvllXfm5c+ckPT1dRESqVauWrYFTdj4DM2bMEADy66+/mstOnDghrq6uMm7cOGX71NRUKVKkiEycONHmvpD9OCKHL1y4IO7u7jJs2DBzWUZGhjRu3FhKly5t8arNgxx1HI6IiBAAuvOEPn36CADz9jyXyFsckcOWnDt3Tkwmk8WBh5ajcrh9+/ZSpEgR5QeoS5cuiZ+fnzz//PPmsl9//VUA6I6rb775pphMJjl8+LBSfuDAAQEgO3bsMHxv9pQr5ji1aNECAHD27Flz2enTp3H69GnDbaOjo1G7dm1UqlTJXObj44NOnTrht99+w6lTp6xuv3v3bsTHx6NVq1ZK+a5du5CQkIChQ4cq5cOGDUNiYiK+/fZbw75lunPnDu7fv2/xtcDAQFSrVk1X3qVLFwDAiRMnlPJixYqhRo0a2LBhg837J8d7nHP4p59+wrp16/DRRx9lWadkyZJwd3e3qT3mcO6UkxwODw+Hn5+fUhYQEIDGjRvrjmGWnD17FkeOHNHlMAAEBwfDxeXhvqqy8xlYt24d6tati7p165rLKleujJYtW2Lt2rXK9u7u7mjWrBlzOJfJSQ5v2LABaWlpSq6YTCYMGTIEFy5cwN69e61u76jj8O3btwEAxYsXV8qDgoLg4uICDw8PADyXeFzkJIct+fzzzyEi6NWrl2FdR+VwdHQ0WrVqhYCAAHNZUFAQmjZtim+++QZ379411wOAHj16KNv36NEDIoI1a9Yo5eHh4fD393/kOZwrBk6ZCfHgP2rLli3RsmVLw21TUlLg7e2tK/fx8QEAxMTEWN1+z549MJlMqF27tlJ+8OBBAMCTTz6plIeHh8PFxcX8upF+/fqhYMGC8PLyQvPmzXHgwAGbtrty5QqAfw6GWuHh4dizZ49N7dCj8bjmcHp6OkaMGIEBAwbgiSeeMKxvK+Zw7pOTHM7KlStXLB7DtDJzoU6dOg+9L0ts/QxkZGTgyJEjunoAUK9ePZw+fRp37tzRtXHs2DHziS05X05y+ODBg/D19UWVKlWU8nr16plft8ZRx+FmzZoBAPr3749Dhw7h/PnzWLNmDSIjIzFy5Ej4+vpa3Z7nEnmLvY/Dq1atQnBwMJo0aWJY11E5bO0cJzU11fygncwHPWjrWjsXqlOnDn7++Wer+7c3pwycbt26hfj4eFy4cAFr1qzBlClT4O3tjY4dO2a7rUqVKuHIkSO6L7Xdu3cDAC5evGh1+9jYWPj7+6NgwYJK+eXLl+Hq6opixYop5R4eHggICMClS5estuvh4YGuXbti3rx52LBhA6ZNm4ajR4+icePGhkmWmpqKjz76CKGhocqvn5nKlSuH+Ph4XLt2zWo75Dj5IYeBf55c89dff2Hq1Km2vBWbMYedz545bEl0dDT27t2LF1980bBubGwsACA0NNQu+85k62fgxo0bSElJQVBQkK6NzDLt56VcuXLIyMgw950ePXvm8OXLl1G8eHGYTCalPKv//1qOOg63bdsWU6dOxbZt21C7dm2UKVMGPXr0wIgRIzB37lyr2/JcIvdz5HH4999/x5EjR9CzZ09dXlviqByuVKkSfvnlF6Snp5vLUlNTsW/fPgD/f46TedeNdiCUeSXK0rlQuXLlcPz4ccP3Zk9uj3Rv/6O9DBgSEoJVq1ahVKlS5rK4uDib2hoyZAg2bdqEF198EREREfD19cWiRYvMV3aSk5Otbp+QkIAiRYroypOTk82XwLW8vLwM223QoAEaNGhgjjt16oRu3bqhRo0aGDduHLZs2ZLltsOHD8fx48fx7bffws1N/78os7/x8fG6RKZHIz/kcEJCAt577z1MnDgRRYsWtem92Io57Hz2zGGta9eu4aWXXkJoaCjGjBljWD8hIQFubm662/1yytbPQOZ/PT09LdZ7sE6mB3OYnMOeOZycnJyt//9ajjoOA0DZsmXRpEkTdO3aFQEBAfj2228xffp0lC
hRAsOHD89yO55L5H6OPA6vWrUKAGy6TQ9wXA4PHToUQ4YMQf/+/TFmzBhkZGRg2rRpuHz5srl9AGjfvj1CQ
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "6ZzALozi8YaD"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# Try experimenting with the batch size!\n",
|
|||
|
|
"\n",
|
|||
|
|
" * How does increasing the batch size to 10,000 affect the training time and test accuracy?\n",
|
|||
|
|
" * How about a batch size of 32?\n",
|
|||
|
|
" * Is there any difference in results between the students? If so, why?\n",
|
|||
|
|
" * Experiment with the learning rate in the optimizer"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "uPUgO1BY8YaD"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# Introducing Convolution! What is it?"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "vs75ypBn8YaE"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"Before, we built a network that accepts the normalized pixel values of each image and operates solely on those values. What if we could instead feed different features (e.g. **curvature, edges**) of each image into a network, and have the network learn which features are important for classifying an image?\n",
|
|||
|
|
"\n",
|
|||
|
|
"This is possible through convolution! Convolution applies **kernels** (filters) that slide across each image and generate **feature maps**."
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "eVIKBobD8YaF"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"<img src = 'https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/convolution.gif' >"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "gYJDfK4p8YaG"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"In the above example, the image is a 5 x 5 matrix and the kernel sliding over it is a 3 x 3 matrix. At each position, the kernel is multiplied element-wise with the 3 x 3 patch of the image beneath it and the products are summed (a dot product), giving one value of the convolved feature. Each kernel in a CNN learns to detect a different characteristic of an image.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Kernels are often used in photo-editing software to apply blurring, edge detection, sharpening, etc."
|
|||
|
|
]
|
|||
|
|
},
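{
"cell_type": "markdown",
"metadata": {},
"source": [
"The operation described above can be sketched in NumPy as a 'valid' convolution (no padding, stride 1), producing a 3 x 3 output from a 5 x 5 image and a 3 x 3 kernel. Like `Conv2D` in Keras, it does not flip the kernel, which is technically cross-correlation. The image and kernel values are illustrative and may not match the animation; `conv2d_valid` is a hypothetical helper."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"def conv2d_valid(image, kernel):\n",
"    kh, kw = kernel.shape\n",
"    out_h = image.shape[0] - kh + 1\n",
"    out_w = image.shape[1] - kw + 1\n",
"    out = np.zeros((out_h, out_w), dtype=image.dtype)\n",
"    for i in range(out_h):\n",
"        for j in range(out_w):\n",
"            # Element-wise product of the kernel and the patch under it, summed\n",
"            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)\n",
"    return out\n",
"\n",
"demo_image = np.array([[1, 1, 1, 0, 0],\n",
"                       [0, 1, 1, 1, 0],\n",
"                       [0, 0, 1, 1, 1],\n",
"                       [0, 0, 1, 1, 0],\n",
"                       [0, 1, 1, 0, 0]])  # hypothetical 5 x 5 binary image\n",
"demo_kernel = np.array([[1, 0, 1],\n",
"                        [0, 1, 0],\n",
"                        [1, 0, 1]])  # hypothetical 3 x 3 kernel\n",
"print(conv2d_valid(demo_image, demo_kernel))\n",
"# [[4 3 4]\n",
"#  [2 4 3]\n",
"#  [2 3 4]]"
],
"execution_count": null,
"outputs": []
},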
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "qm9CKpnr8YaH"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"<img src = 'https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/kernels.png' >"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "QdHIoa-E8YaH"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"Kernels in deep learning networks are used in similar ways, i.e. to highlight some feature. Combined with an operation called **max pooling**, which keeps only the largest value in each small window of a feature map, the weaker (non-highlighted) responses are discarded, leaving only the features of interest. This shrinks the feature maps, which reduces the number of parameters needed in subsequent layers and decreases the computational cost (e.g. system memory)."
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "ceCEMpHf8YaI"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"<img src = 'https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/max_pooling.png' >"
|
|||
|
|
]
|
|||
|
|
},
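{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch of 2 x 2 max pooling with stride 2, applied to a hypothetical 4 x 4 feature map: each non-overlapping 2 x 2 block is replaced by its largest value, so the map shrinks to 2 x 2. This mirrors what Keras' `MaxPooling2D(pool_size=(2, 2))` does on each channel, though `max_pool_2x2` itself is an illustrative helper. Pooling has no learnable parameters of its own; it reduces parameters only indirectly, by shrinking the input to later layers."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"def max_pool_2x2(fmap):\n",
"    h, w = fmap.shape\n",
"    # Drop a trailing odd row/column, as 'valid' pooling does\n",
"    fmap = fmap[:h - h % 2, :w - w % 2]\n",
"    # Split into non-overlapping 2x2 blocks and keep the maximum of each block\n",
"    return fmap.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))\n",
"\n",
"demo_fmap = np.array([[1, 1, 2, 4],\n",
"                      [5, 6, 7, 8],\n",
"                      [3, 2, 1, 0],\n",
"                      [1, 2, 3, 4]])  # hypothetical feature map\n",
"print(max_pool_2x2(demo_fmap))\n",
"# [[6 8]\n",
"#  [3 4]]"
],
"execution_count": null,
"outputs": []
},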
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "poGu943u8YaJ"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"We can also take convolutions of convolutions: we can stack as many convolutional layers as we want, as long as the feature maps still have enough pixels to fit a kernel.\n",
|
|||
|
|
"\n",
|
|||
|
|
"*Warning: What you may find down there in those deep convolutions may not appear recognizable to you.*"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "u7KmEU3m8YaJ"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"<img src = 'https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/go_deeper.jpg' >"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "h88R-Bsa8YaK"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## Building a \"Deep\" Convolutional Neural Network"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "qoXIbzAG8YaL"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# import some additional tools\n",
|
|||
|
|
"\n",
|
|||
|
|
"from keras.preprocessing.image import ImageDataGenerator\n",
|
|||
|
|
"from keras.layers import Conv2D, MaxPooling2D, ZeroPadding2D, GlobalAveragePooling2D, Flatten\n",
|
|||
|
|
"from keras.layers import BatchNormalization"
|
|||
|
|
],
|
|||
|
|
"execution_count": 25,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "w0MBTlCL8YaM"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# Reload the MNIST data\n",
|
|||
|
|
"(X_train, y_train), (X_test, y_test) = mnist.load_data()"
|
|||
|
|
],
|
|||
|
|
"execution_count": 26,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "TbBjjLil8YaO",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"outputId": "5273e5ab-be9b-4dd9-c807-e28ccd847a5f"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# Again, do some formatting\n",
|
|||
|
|
"# Except we do not flatten each image into a 784-length vector because we want to perform convolutions first\n",
|
|||
|
|
"\n",
|
|||
|
|
"X_train = X_train.reshape(60000, 28, 28, 1) # add an extra dimension for the single (grayscale) channel\n",
|
|||
|
|
"X_test = X_test.reshape(10000, 28, 28, 1)\n",
|
|||
|
|
"\n",
|
|||
|
|
"X_train = X_train.astype('float32') # change integers to 32-bit floating point numbers\n",
|
|||
|
|
"X_test = X_test.astype('float32')\n",
|
|||
|
|
"\n",
|
|||
|
|
"X_train /= 255 # scale each pixel value from [0, 255] to [0, 1]\n",
|
|||
|
|
"X_test /= 255\n",
|
|||
|
|
"\n",
|
|||
|
|
"print(\"Training matrix shape\", X_train.shape)\n",
|
|||
|
|
"print(\"Testing matrix shape\", X_test.shape)"
|
|||
|
|
],
|
|||
|
|
"execution_count": 27,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"Training matrix shape (60000, 28, 28, 1)\n",
|
|||
|
|
"Testing matrix shape (10000, 28, 28, 1)\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "aNNglEKV8YaR"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# one-hot format classes\n",
|
|||
|
|
"\n",
|
|||
|
|
"nb_classes = 10 # number of unique digits\n",
|
|||
|
|
"\n",
|
|||
|
|
"Y_train = to_categorical(y_train, nb_classes)\n",
|
|||
|
|
"Y_test = to_categorical(y_test, nb_classes)"
|
|||
|
|
],
|
|||
|
|
"execution_count": 28,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "buX5gLwP8YaT"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"model = Sequential() # Linear stacking of layers\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Convolution Layer 1\n",
|
|||
|
|
"model.add(Conv2D(16, (5, 5), input_shape=(28,28,1))) # 16 different 5x5 kernels -- so 16 feature maps\n",
|
|||
|
|
"model.add(Activation('relu')) # activation\n",
|
|||
|
|
"model.add(MaxPooling2D(pool_size=(2,2))) # Pool the max values over a 2x2 kernel\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Convolution Layer 2\n",
|
|||
|
|
"model.add(Conv2D(32, (5, 5))) # 32 different 5x5 kernels -- so 32 feature maps\n",
|
|||
|
|
"model.add(Activation('relu')) # activation\n",
|
|||
|
|
"model.add(MaxPooling2D(pool_size=(2,2))) # Pool the max values over a 2x2 kernel\n",
|
|||
|
|
"\n",
|
|||
|
|
"model.add(Flatten()) # Flatten final output matrix into a vector\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Fully Connected Layer\n",
|
|||
|
|
"model.add(Dense(128)) # 128 FC nodes\n",
|
|||
|
|
"model.add(Activation('relu')) # activation\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Fully Connected Layer\n",
|
|||
|
|
"model.add(Dense(10)) # final 10 FC nodes\n",
|
|||
|
|
"model.add(Activation('softmax')) # softmax activation"
|
|||
|
|
],
|
|||
|
|
"execution_count": 29,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "srtd-OZV8YaV",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"outputId": "93630a9d-2080-4624-9586-3ad2820774de"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"model.summary()"
|
|||
|
|
],
|
|||
|
|
"execution_count": 30,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"Model: \"sequential_1\"\n",
|
|||
|
|
"_________________________________________________________________\n",
|
|||
|
|
" Layer (type) Output Shape Param # \n",
|
|||
|
|
"=================================================================\n",
|
|||
|
|
" conv2d (Conv2D) (None, 24, 24, 16) 416 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_3 (Activation) (None, 24, 24, 16) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" max_pooling2d (MaxPooling2 (None, 12, 12, 16) 0 \n",
|
|||
|
|
" D) \n",
|
|||
|
|
" \n",
|
|||
|
|
" conv2d_1 (Conv2D) (None, 8, 8, 32) 12832 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_4 (Activation) (None, 8, 8, 32) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" max_pooling2d_1 (MaxPoolin (None, 4, 4, 32) 0 \n",
|
|||
|
|
" g2D) \n",
|
|||
|
|
" \n",
|
|||
|
|
" flatten (Flatten) (None, 512) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" dense_3 (Dense) (None, 128) 65664 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_5 (Activation) (None, 128) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" dense_4 (Dense) (None, 10) 1290 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_6 (Activation) (None, 10) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
"=================================================================\n",
|
|||
|
|
"Total params: 80202 (313.29 KB)\n",
|
|||
|
|
"Trainable params: 80202 (313.29 KB)\n",
|
|||
|
|
"Non-trainable params: 0 (0.00 Byte)\n",
|
|||
|
|
"_________________________________________________________________\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
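{
"cell_type": "markdown",
"metadata": {},
"source": [
"The output shapes and parameter counts in the summary above can be checked by hand. This short sketch assumes the layer settings from the model definition (5 x 5 kernels, 'valid' padding, 2 x 2 pooling). It also assumes each `Conv2D` and `Dense` layer has one bias per filter or unit, which is the Keras default (`use_bias=True`)."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# spatial size: a 'valid' convolution gives input - kernel + 1; 2 x 2 pooling halves it\n",
"side = 28\n",
"side = side - 5 + 1       # conv1: 24\n",
"side = side // 2          # pool1: 12\n",
"side = side - 5 + 1       # conv2: 8\n",
"side = side // 2          # pool2: 4\n",
"flat = side * side * 32   # flatten: 512\n",
"\n",
"# parameters: kernel_h * kernel_w * in_channels * filters (or inputs * units) + one bias each\n",
"conv1 = 5 * 5 * 1 * 16 + 16     # 416\n",
"conv2 = 5 * 5 * 16 * 32 + 32    # 12832\n",
"dense1 = flat * 128 + 128       # 65664\n",
"dense2 = 128 * 10 + 10          # 1290\n",
"print(flat, conv1 + conv2 + dense1 + dense2)  # 512 80202"
],
"execution_count": null,
"outputs": []
},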
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "TJ5vnJet8YaX"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# we'll use the same optimizer\n",
|
|||
|
|
"adam = tf.optimizers.Adam(learning_rate=0.001)\n",
|
|||
|
|
"model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])"
|
|||
|
|
],
|
|||
|
|
"execution_count": 31,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "WShqV9h28Yaa"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# data augmentation helps reduce overfitting by randomly applying small transformations to the training images\n",
"# Keras provides a built-in tool for this, ImageDataGenerator, which augments images on the fly\n",
|
|||
|
|
"\n",
|
|||
|
|
"gen = ImageDataGenerator(rotation_range=8, width_shift_range=0.08, shear_range=0.3,\n",
|
|||
|
|
" height_shift_range=0.08, zoom_range=0.08, validation_split=0.2)\n",
|
|||
|
|
"\n",
|
|||
|
|
"test_gen = ImageDataGenerator()"
|
|||
|
|
],
|
|||
|
|
"execution_count": 32,
|
|||
|
|
"outputs": []
|
|||
|
|
},
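{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what the augmentation actually does, one option is to draw a few random transformations of a single training image. This sketch assumes the `gen`, `X_train` and `plt` objects defined earlier in this notebook. It uses `ImageDataGenerator.random_transform`, which applies a randomly sampled set of the configured transformations to one (height, width, channels) image."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# draw four random augmentations of the first training image (uses gen, X_train, plt from above)\n",
"plt.figure()\n",
"for i in range(4):\n",
"    augmented = gen.random_transform(X_train[0])\n",
"    plt.subplot(1, 4, i + 1)\n",
"    plt.imshow(augmented[:, :, 0], cmap='gray')\n",
"    plt.axis('off')\n",
"plt.show()"
],
"execution_count": null,
"outputs": []
},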
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "VsGRzqqd8Yab"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# We can then feed our augmented data to the model in batches.\n",
"# Besides the loss-function considerations discussed earlier, each batch of augmented images is generated\n",
"# on the fly, so the augmented copies never all have to be held in memory at once.\n",
"\n",
"# Note: flow() still reads from X_train, which is already fully loaded into memory. Larger memory savings\n",
"# come from streaming images from disk, e.g. with flow_from_directory().\n",
|
|||
|
|
"\n",
|
|||
|
|
"train_generator = gen.flow(X_train, Y_train, batch_size=128, subset='training')\n",
|
|||
|
|
"# the validation subset comes from the same generator, so it is augmented as well\n",
"valid_generator = gen.flow(X_train, Y_train, batch_size=128, subset='validation')\n",
|
|||
|
|
"test_generator = test_gen.flow(X_test, Y_test, batch_size=128)"
|
|||
|
|
],
|
|||
|
|
"execution_count": 33,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "_DXSGa-z8Yae",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"outputId": "b898cc08-08ba-4466-83d2-bb9427cd790a"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# We can now train the model, which is fed data by our batch generators.\n",
"# Steps per epoch is usually the size of the training set divided by the batch size;\n",
"# with validation_split=0.2 the training subset has 48000 images (48000 // 128 = 375 steps)\n",
"# and the validation subset has 12000 images (12000 // 128 = 93 steps).\n",
|
|||
|
|
"\n",
|
|||
|
|
"model.fit(train_generator, steps_per_epoch=48000//128, epochs=5, verbose=1, validation_data=valid_generator, validation_steps = 12000 // 128)"
|
|||
|
|
],
|
|||
|
|
"execution_count": 34,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"Epoch 1/5\n",
|
|||
|
|
"375/375 [==============================] - 24s 52ms/step - loss: 0.4303 - accuracy: 0.8682 - val_loss: 0.1679 - val_accuracy: 0.9474\n",
|
|||
|
|
"Epoch 2/5\n",
|
|||
|
|
"375/375 [==============================] - 19s 52ms/step - loss: 0.1356 - accuracy: 0.9595 - val_loss: 0.1257 - val_accuracy: 0.9621\n",
|
|||
|
|
"Epoch 3/5\n",
|
|||
|
|
"375/375 [==============================] - 18s 47ms/step - loss: 0.0981 - accuracy: 0.9704 - val_loss: 0.0902 - val_accuracy: 0.9724\n",
|
|||
|
|
"Epoch 4/5\n",
|
|||
|
|
"375/375 [==============================] - 17s 46ms/step - loss: 0.0783 - accuracy: 0.9763 - val_loss: 0.0735 - val_accuracy: 0.9785\n",
|
|||
|
|
"Epoch 5/5\n",
|
|||
|
|
"375/375 [==============================] - 19s 52ms/step - loss: 0.0666 - accuracy: 0.9791 - val_loss: 0.0713 - val_accuracy: 0.9780\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "execute_result",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<keras.src.callbacks.History at 0x7f0cf03883a0>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"execution_count": 34
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "6YZaV3U-8Yah",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"outputId": "d83dfad8-59c1-454c-9808-87122e10f625"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"score = model.evaluate(X_test, Y_test)\n",
|
|||
|
|
"print('Test score:', score[0])\n",
|
|||
|
|
"print('Test accuracy:', score[1])"
|
|||
|
|
],
|
|||
|
|
"execution_count": 35,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"313/313 [==============================] - 1s 3ms/step - loss: 0.0291 - accuracy: 0.9902\n",
|
|||
|
|
"Test score: 0.0290591549128294\n",
|
|||
|
|
"Test accuracy: 0.9901999831199646\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "WqvF3eS2pnLp",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 1000
|
|||
|
|
},
|
|||
|
|
"outputId": "cee59a44-ec0c-480e-9247-e60c74c92660"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# model.predict returns the predicted probability of each class for every input example;\n",
"# np.argmax then picks the highest-probability class (this replaces predict_classes, which newer Keras versions removed).\n",
|
|||
|
|
"predict = model.predict(X_test)\n",
|
|||
|
|
"predicted_classes = np.argmax(predict,axis=1)\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Check which items we got right / wrong\n",
|
|||
|
|
"correct_indices = np.nonzero(predicted_classes == y_test)[0]\n",
|
|||
|
|
"\n",
|
|||
|
|
"incorrect_indices = np.nonzero(predicted_classes != y_test)[0]\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"cnf_matrix = confusion_matrix(y_test, predicted_classes)\n",
|
|||
|
|
"\n",
|
|||
|
|
"class_names = [str(i) for i in range(10)]\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Plot non-normalized confusion matrix\n",
|
|||
|
|
"plt.figure()\n",
|
|||
|
|
"plot_confusion_matrix(cnf_matrix, classes=class_names,\n",
|
|||
|
|
" title='Confusion matrix, without normalization')\n",
|
|||
|
|
"\n",
|
|||
|
|
"plt.show()"
|
|||
|
|
],
|
|||
|
|
"execution_count": 36,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"313/313 [==============================] - 1s 2ms/step\n",
|
|||
|
|
"Confusion matrix, without normalization\n",
|
|||
|
|
"[[ 976 0 1 0 0 0 0 1 2 0]\n",
|
|||
|
|
" [ 0 1122 1 1 1 0 2 2 5 1]\n",
|
|||
|
|
" [ 1 1 1017 0 0 0 0 11 2 0]\n",
|
|||
|
|
" [ 0 0 1 1004 0 0 0 2 2 1]\n",
|
|||
|
|
" [ 0 0 0 0 973 0 1 0 0 8]\n",
|
|||
|
|
" [ 0 0 0 4 0 884 2 1 0 1]\n",
|
|||
|
|
" [ 1 2 1 0 1 5 944 0 4 0]\n",
|
|||
|
|
" [ 0 1 2 1 3 0 0 1014 0 7]\n",
|
|||
|
|
" [ 0 0 1 0 0 1 0 1 965 6]\n",
|
|||
|
|
" [ 0 0 0 1 5 0 0 0 0 1003]]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 2 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1UAAAN6CAYAAABmBWMlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAACp+ElEQVR4nOzdd3gUVd/G8XsTSKGkUJIQelF6QGnSLIAgTbCgICWAYqMYUNqj0hQR9FUEkWKhqEhRqQJKkyK9iqggTRBIQk1IgASSff+IbFwTJGGSndnk+/Ga63l2dnb3nrOzy/5yzpyx2e12uwAAAAAAt8XD7AAAAAAA4M4oqgAAAADAAIoqAAAAADCAogoAAAAADKCoAgAAAAADKKoAAAAAwACKKgAAAAAwgKIKAAAAAAzIY3YAAAAAALfn6tWrSkxMNDtGhnh5ecnHx8fsGNmCogoAAABwQ1evXpVvwcLS9ctmR8mQkJAQHT16NEcWVhRVAAAAgBtKTEyUrl+Wd5VwydPL7Dj/LSlRkb/OVGJiIkUVAAAAAIvx9JLN4kWV3ewA2YyiCgAAAHBnNo+Uxcqsns+gnL13AAAAAJDNKKoAAAAAwACKKgAAAAAwgHOqAAAAAHdmk2SzmZ3iv1k8nlH0VAEAAACAARRVAAAAAGAAw/8AAAAAd8aU6qbL2XsHAAAAANmMogoAAAAADGD4HwAAAODObDY3mP3P4vkMoqcKAAAAAAygqAIAAAAAAyiqAAAAAMAAzqkCAAAA3BlTqpsuZ+8dAAAAAGQziioAAAAAMIDhfwAAAIA7Y0p109FTBQAAAAAGUFQBAAAAgAEUVQAAAABgAOdUAQAAAG7NDaZUz+F9OTl77wAAAAAgm1FUAQAAAIABDP8DAAAA3BlTqpuOnioAAAAAMICiCgAAAAAMYPgfAAAA4M5sbjD7n9XzGZSz9w4AAAAAshlFFQAAAAAYQFEFAAAAAAZwThUAAADgzphS3XT0VAEAAACAARRVAAAAAGAAw/8AAAAAd8aU6qbL2XsHAAAAANmMogoAAAAADGD4HwAAAODOmP3PdPRUAQAAAIABFFUAAAAAYABFFQAAAAAYwDlVAAAAgDtjSnXT5ey9AwAAAIBsRlEFAAAAAAYw/A8AAABwZzab9YfXMaU6AAAAAOBmKKoAAAAAwACG/wEAAADuzMOWsliZ1fMZRE8VAAAAABhAUQUAAAAABlBUAQAAAIABnFMFAAAAuDObhxtMqW7xfAbl7L0DAAAAgGxGUQUAAAAABjD8DwAAAHBnNlvKYmVWz2cQPVUAAAAAYABFFQAAAAAYQFEFAAAAAAZwThUAAADgzphS3XQ5e+8AAAAAIJtRVAEAAACAAQz/AwAAANwZU6qbjp4qAAAAADCAogoAAAAADGD4HwAAAODOmP3PdDl77wAAAAAgm1FUAQAAAIABFFUAAAAAYADnVAEAAADujCnVTUdPFQAAAAAYQFEFAAAAAAYw/A8AAABwZ0ypbrqcvXcAAAAAkM0oqoAc7I8//lDz5s3l7+8vm82mhQsXZunzHzt2TDabTTNmzMjS580JypQpo+7du5sdI43MvGc3tn333XezPxjSNWLECNn+dXK3WceWVY9pALACiiogmx0+fFjPPfecypUrJx8fH/n5+alhw4b64IMPdOXKlWx97fDwcO3bt0+jR4/W559/rtq1a2fr6+VEv/76q0aMGKFjx46ZHSXbLFu2TCNGjDA7RhpvvfVWlv8hAP9t06ZNGjFihC5evGh2FACZcWP2P6svORjnVAHZ6LvvvlOHDh3k7e2tbt26qVq1akpMTNTGjRs1cOBA7d+/X9OmTcuW175y5Yo2b96sV199VX369MmW1yhdurSuXLmivHnzZsvzW8Gvv/6qkSNH6v7771eZMmUy/LgDBw7Iw8N6f7dK7z1btmyZJk2aZLnC6q233tLjjz+u9u3bmx3FUrLz2Nq0aZNGjhyp7t27KyAgwGWvCw
DujqIKyCZHjx5Vx44dVbp0aa1Zs0bFihVz3Ne7d28dOnRI3333Xba9/pkzZyQpzQ+jrGSz2eTj45Ntz+9u7Ha7rl69Kl9fX3l7e5sdJ128Z8bEx8crf/78pmYw69iy6jENAFbAn5yAbDJu3DjFxcXp008/dSqobqhQoYJeeuklx+3r16/rjTfeUPny5eXt7a0yZcrof//7nxISEpweV6ZMGbVp00YbN25U3bp15ePjo3LlymnWrFmObUaMGKHSpUtLkgYOHCibzeboZenevXu6PS7pnbuxcuVKNWrUSAEBASpQoIAqVqyo//3vf477b3Z+zpo1a9S4cWPlz59fAQEBateunX777bd0X+/QoUOOv4r7+/urR48eunz58s0b9m/333+/qlWrpp9//ln33Xef8uXLpwoVKujrr7+WJK1bt0716tWTr6+vKlasqFWrVjk9/s8//9SLL76oihUrytfXV4ULF1aHDh2chvnNmDFDHTp0kCQ98MADstlsstls+vHHHyWlvhfff/+9ateuLV9fX02dOtVx343zT+x2ux544AEVLVpU0dHRjudPTExU9erVVb58ecXHx99yn/9pwIABKly4sOx2u2Nd3759ZbPZNGHCBMe6qKgo2Ww2TZ48WVLa96x79+6aNGmSJDn279/HgSRNmzbNcWzWqVNH27dvT7NNRt73jB5/NptN8fHxmjlzpiPTf53P8+OPP8pms2nevHkaPXq0SpQoIR8fHzVt2lSHDh1Ks/38+fNVq1Yt+fr6qkiRIurSpYtOnjyZJmuBAgV0+PBhtWrVSgULFlTnzp0d+fr06aP58+erSpUq8vX1Vf369bVv3z5J0tSpU1WhQgX5+Pjo/vvvTzN8dMOGDerQoYNKlSolb29vlSxZUv3798/QkOB/n9v0z/ft38uN1/3555/VvXt3xzDkkJAQ9ezZU+fOnXN6DwYOHChJKlu2bJrnSO+cqiNHjqhDhw4qVKiQ8uXLp3vuuSfNH4sy+94AgDuipwrIJkuWLFG5cuXUoEGDDG3/zDPPaObMmXr88cf18ssva+vWrRozZox+++03LViwwGnbQ4cO6fHHH9fTTz+t8PBwffbZZ+revbtq1aqlqlWr6tFHH1VAQID69++vTp06qVWrVipQoECm8u/fv19t2rRRWFiYRo0aJW9vbx06dEg//fTTfz5u1apVatmypcqVK6cRI0boypUrmjhxoho2bKhdu3al+UH9xBNPqGzZshozZox27dqlTz75REFBQRo7duwtM164cEFt2rRRx44d1aFDB02ePFkdO3bUl19+qYiICD3//PN66qmn9M477+jxxx/XiRMnVLBgQUnS9u3btWnTJnXs2FElSpTQsWPHNHnyZN1///369ddflS9fPt17773q16+fJkyYoP/973+qXLmyJDn+V0oZEtWpUyc999xz6tWrlypWrJgmp81m02effaawsDA9//zz+vbbbyVJw4cP1/79+/Xjjz9muvejcePGev/997V//35Vq1ZNUsoPdQ8PD23YsEH9+vVzrJOke++9N93nee6553Tq1CmtXLlSn3/+ebrbzJ49W5cuXdJzzz0nm82mcePG6dFHH9WRI0ccwwgz+77fyueff65nnnlGdevW1bPPPitJKl++/C0f9/bbb8vDw0OvvPKKYmJiNG7cOHXu3Flbt251bDNjxgz16NFDderU0ZgxYxQVFaUPPvhAP/30k3bv3u3Uu3v9+nW1aNFCjRo10rvvvqt8+fI57tuwYYMWL16s3r17S5LGjBmjNm3aaNCgQfroo4/04osv6sKFCxo3bpx69uypNWvWOB47f/58Xb58WS+88IIKFy6sbdu2aeLEifrrr780f/78TLfVv7322muKjo52fO5XrlypI0eOqEePHgoJCXEMPd6/f7+2bNkim82mRx99VAcPHtRXX32l999/X0WKFJEkFS1aNN3XjYqKUoMGDXT58mX169dPhQsX1syZM/Xwww/r66+/1i
OPPJLp9wbA7XKDKdVzel+OHUCWi4mJsUuyt2vXLkPb79mzxy7J/swzzzitf+WVV+yS7GvWrHGsK126tF2Sf
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "Wn7FDt7DqSKB",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 906
|
|||
|
|
},
|
|||
|
|
"outputId": "8598c5ba-17e3-4da4-a20e-c2c9cfc4b060"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"show_samples(correct_indices, predict, X_test, y_test, 5)"
|
|||
|
|
],
|
|||
|
|
"execution_count": 37,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 25 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA04AAAN5CAYAAAA/32uUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAADoeUlEQVR4nOzdd3hU1fbw8TWkF4GQhBIIIXSkV6UZmqg0C0FRQFSUK2DDBij+UIpX8AqWSxFEUAQvCHLBAlJEBVGQLkKkBUIgCKGEEiCQ7PcPb+ZlnzOZM5PMJBn4fp4nj8/as8+eFVk5OTtn9tk2pZQSAAAAAECeShR1AgAAAABQ3DFxAgAAAAALTJwAAAAAwAITJwAAAACwwMQJAAAAACwwcQIAAAAAC0ycAAAAAMACEycAAAAAsMDECQAAAAAsMHH6n8GDB8vtt99e1Gk4tXz5cgkPD5cTJ04UdSoohiZMmCC1a9eWnJycok4lT7t27RJ/f3/ZuXNnUaeCYogahq/jWgK+jhq2oArRrFmzlIjYv4KCglSNGjXUkCFD1LFjx/I15qhRo7QxjV/r1q2zHOPAgQMqICBAff/991r7lClTVGJiooqNjVUiovr37+9WbtnZ2Wr8+PGqSpUqKigoSNWvX1/NmzfPYd9du3apO+64Q4WFhamIiAjVt29fdfz4cVO/hg0bqqFDh7qVBzzHGzW8e/du9dJLL6mGDRuq8PBwVb58edWlSxf122+/uTxGRkaGKlOmjPr444+19v/85z+qT58+qnr16kpEVEJCgtv5ffTRR6p27doqKChIVa9eXb3//vsO+6WmpqpevXqpUqVKqZtuukn16NFD7d+/39SvR48e6t5773U7D3iGN2pYKaXGjh2runfvrsqWLatERI0aNcqt46lhuMpbNezO72xHuJaAq7xVw9f67LPPlIiosLAwl4+hhq0VycRp9OjRas6cOWrGjBmqf//+qkSJEio+Pl5duHDB7TG3b9+u5syZY/qKjY1VERER6vLly5ZjPPvss6pmzZqm9ri4OFWmTBl15513Kn9/f7cLZfjw4UpE1BNPPKGmT5+uunbtqkREff7551q/w4cPq6ioKFWtWjX13nvvqXHjxqmIiAjVsGFDU/5TpkxRoaGh6uzZs27lAs/wRg2/8MILqnTp0mrAgAHqww8/VBMmTFDVqlVTfn5+auXKlS6NMWnSJFWyZEl18eJFrT0hIUGFh4er9u3bq4iICLcvOqdNm6ZERPXs2VNNnz5d9evXT4mIeuutt7R+586dUzVq1FBly5ZV48ePVxMnTlSxsbGqUqVKKj09Xev77bffKhFR+/btcysXeIY3algppURElS9fXt1xxx35mjhRw3CVt2rY1d/ZeeFaAq7yVg3nOnfunIqJiVFhYWFuTZyoYWtFMnEy/iX9+eefVyLi1l92nElJSVE2m0098cQTln2zsrJUVFSUGjlypOm1gwcPqpycHKWUUmFhYW4VSmpqqgoICFBDhgyxt+Xk5Ki2bduqSpUqqatXr9rbBw0apEJCQtShQ4fsbStXrlQioj788ENt3L/++kv5+fmpmTNnupwLPMcbNbxp0yZ17tw5rS09PV1FR0er1q1buzRGgwYNVN++fU3tKSkpKjs7WymlVN26dd266MzMzFSRkZGqa9euWnufPn1UWFiYOnXqlL1t/PjxSkTUxo0b7W27d+9Wfn5+asSIEdrxWVlZKiIiQr322msu5wLP8dZ5ODk5WSml1IkTJ/I1caKG4Spv1LA7v7Md4VoC7vD29fCwYcNUrVq17Oc6V1DDrikWa5w6dOggIiLJycn2tv3798v+/fvzNd7nn38uSinp06ePZd9169ZJenq6dOrUyfRaXFyc2Gy2fOWwZMkSuXLligwePNjeZrPZZNCgQZKamiq//PKLvX3RokXSrVs3qVy5sr2tU6dOUrNmTVmwYIE2btmyZaVBgwayZMmSfOUF7yhIDTdt2lTCw8
O1tsjISGnbtq3s3r3b8vjk5GTZsWOHwxqOjY2VEiXy92O+Zs0aOXnypFbDIiJDhgyRCxcuyDfffGNvW7hwoTRv3lyaN29ub6tdu7Z07NjRVMMBAQHSrl07ariYKeh5uEqVKvl+b2oYnlCQGnbnd7YjXEvAEzxxPbx3716ZNGmSTJw4Ufz9/V0+jhp2TbGYOOUWRGRkpL2tY8eO0rFjx3yNN3fuXImNjZXbbrvNsu/69evFZrNJ48aN8/Veedm6dauEhYVJnTp1tPYWLVrYXxcROXLkiBw/flyaNWtmGqNFixb2ftdq2rSprF+/3qP5omA8XcMiIseOHZOoqCjLfrm10KRJk3y/lyO5tWeszaZNm0qJEiXsr+fk5MiOHTvyrOH9+/fLuXPnTGPs3LlTzp4969GckX/eqGFXUcPwhILUsKu/s/PCtQQ8wRPn4eeee07at28vXbp0ceu9qWHXFMnEKSMjQ9LT0yU1NVXmz58vo0ePlpCQEOnWrVuBx/7jjz9kx44d8uCDD7o0O05KSpIyZcpIyZIlC/ze10pLS5Ny5cqZcqhQoYKIiBw9etTe79p2Y99Tp07J5cuXtfaqVatKenq6HD9+3KM5w3XerGERkbVr18ovv/wiDzzwgGXfpKQkERGJj4/3yHvnSktLEz8/PylbtqzWHhgYKJGRkfYazq3RvGpY5P/Xe66qVatKTk6OPXcUPm/XsDuoYeSHJ2vY1d/ZeeFaAvnh6fPwN998IytWrJCJEye6fSw17BrX7+F5kPE2YFxcnMydO1cqVqxobzt48GC+xp47d66IiEsf0xMROXnypEREROTrvZy5ePGiBAUFmdqDg4Ptr1/7X6u+176em296errpggCFw5s1fPz4cXnooYckPj5eXn75Zcv+J0+eFH9/f9PH/Qrq4sWLEhgY6PC14OBgt2v4WtfWMIqGN2vYXdQw8sOTNezq7+y8cC2B/PBkDWdlZcnQoUPlySeflJtvvtntXKhh1xTJxGny5MlSs2ZN8ff3l3LlykmtWrXy/Rn2aymlZN68eVKvXj1p0KCBW8d5WkhIiGlmLCJy6dIl++vX/teVvrly883v501RcN6q4QsXLki3bt3k3Llzsm7dOo9fSLojJCREsrKyHL526dIlatjHeauGixNq+PrmyRp29Xe2M1xLwF2erOFJkyZJenq6vPHGG/nOhxq2ViQTpxYtWjj8DGNB/fzzz3Lo0CH55z//6fIxkZGRcvr0aY/nUqFCBVmzZo0opbR/0NxbkTExMfZ+17ZfKy0tTcqUKWOafefm68r6F3iHN2o4KytL7rvvPtmxY4d89913Uq9ePZeOi4yMlKtXr8q5c+fkpptu8lg+FSpUkOzsbDl+/Lj2l5ysrCw5efKkvYZzazSvGhb5//Weixouet46D+cHNYz88GQNu/o7Oy9cSyA/PFXDGRkZMnbsWBk8eLCcPXvWvvby/PnzopSSgwcPSmhoqNO7MtSwa66rPy/OnTtXbDabPPTQQy4fU7t2bTl9+rRkZGR4NJdGjRpJZmam6aloGzZssL8uIlKxYkWJjo6WTZs2mcbYuHGjvd+1kpOTJSoqSqKjoz2aM4pOTk6OPPzww7J69WqZN2+eJCQkuHxs7dq1RUR/Co8n5NaesTY3bdokOTk59tdLlCgh9evXd1jDGzZskKpVq5ouhpOTk6VEiRJSs2ZNj+YM30QNo6i5+js7L1xLoCidPn1azp8/LxMmTJD4+Hj716JFiyQzM1Pi4+Nl4MCBTseghl1TbCdO7j5+8cqVK/LFF19ImzZttMcYWmnZsqUopWTz5s35SVNE/p7pJyUlacV29913S0BAgEyZMsXeppSSadOmScWKFaVVq1b29p49e8rXX38thw8ftretXr1a9uzZI7169TK93+bNm6Vly5b5zheFw50afvrpp2X+/PkyZcoUue+++9x6n9xacHSycVVmZqYkJSVp6zU6dOggZcqUkalTp2p9p06dKqGhodK1a1
d7W2Jiovz2229aDn/++ad8//33edZw3bp1pVSpUvnOGd5XkG0h3EENw1tcrWF3fmc7wrUEvMWVGi5btqwsX
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "r0m_gom9qL3o",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 906
|
|||
|
|
},
|
|||
|
|
"outputId": "c03db89b-056c-4412-e1f0-687c03007376"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"show_samples(incorrect_indices, predict, X_test, y_test, 5)"
|
|||
|
|
],
|
|||
|
|
"execution_count": 38,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 25 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA04AAAN5CAYAAAA/32uUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeXxM1/8/8Ndk32xZEFvEltiXoPad2qr2pbaqpdYobamlVUtQlKLE1tKqtLQoqijqU7HUEktpBEUQaxKESCSRnN8fvpmfc+9k7iRmMom8no9HHo++z5x77hl95+ScmXvu1QkhBIiIiIiIiChDNtbuABERERERUU7HhRMREREREZEGLpyIiIiIiIg0cOFERERERESkgQsnIiIiIiIiDVw4ERERERERaeDCiYiIiIiISAMXTkRERERERBq4cCIiIiIiItLAhdP/GTlyJFq3bm3tbhi1YsUKlCpVCklJSdbuCuVAuSGHd+/eDTc3N0RHR1u7K5QDMYcpt2MOU26XG3LYqvNhkY3Wrl0rAOh/HB0dRfny5cWoUaPE3bt3s9zu5cuXRbdu3UTBggWFs7OzaNiwofjzzz9NPv7q1avC3t7e4DFr1qwR/v7+wtHRUZQrV04sWbLEpDYPHDggvdeXf44ePaqqf/jwYdGwYUPh7OwsihQpIsaMGSOePHki1UlMTBRFihQRixcvNvm9kXlZIoenTZuWYa4AEIcOHdJswxI5LIQQJ0+eFG+++abIly+fcHNzE61btxanT59W1duzZ4947733ROXKlYWNjY3w8fHJsM3q1auLcePGmdwHMi9LjcMv++GHHwQA4erqavIxlsjhgQMHGv3dioqK0tcNCgoSb7zxhvD09NSfZ+zYseL+/fuqdpnD1pWXxuHMzCU4DuceeWk+LITpc4nk5GTx+eefC19fX+Hg4CB8fX3FzJkzRUpKilTPmvNhqyycZsyYIdavXy9Wr14tBg4cKGxsbISvr694+vRpptu8ceOG8PT0FEWKFBFBQUHiq6++EtWrVxd2dnbir7/+MqmNsWPHigoVKqjKV6xYIQCIbt26iVWrVon+/fsLAGLu3LmabaYPdoGBgWL9+vXST3R0tFT39OnTwsnJSdSsWVMEBweLKVOmCEdHR9G2bVtVuxMmTBA+Pj4iLS3NpPdG5mWJHD579qwqR9avXy9KliwpChUqJJKSkjTbsEQOh4WFCScnJ1G+fHmxYMECMW/ePFG6dGmRP39+ERERIdUdOHCgcHJyEg0aNBAlSpQw+gd7+fLlwsXFRTx+/FizD2R+lsjhlz158kQUK1ZMuLq6ZmrhZIkcPnLkiOr36vvvvxcuLi6iUqVKUt2uXbuK999/XyxatEisWbNGfPjhhyJ//vyiXLlyIj4+XqrLHLauvDQOZ2YuwXE498hL8+HMzCV69uwpdDqdGDx4sAgODtZ/+DV06FBVu9aaD1tl4XTixAmpfPz48QKACAkJyXSbI0eOFHZ2dtI//tOnT0XJkiVFrVq1NI9PTk4Wnp6eYurUqVJ5QkKC8PDwEB06dJDK+/btK1xdXcWDBw+Mtps+2P3888+afWjXrp3w9vYWcXFx+rLVq1cLAGLPnj1S3ZMnTwoAYv/+/ZrtkvlZIocNuXHjhtDpdAYHCyVL5XD79u1FoUKFRExMjL7s9u3bws3NTXTt2lWqe+vWLZGcnCyEEKJDhw5G/2Dfu3dP2Nraim+++UbzvZH5WTqHJ06cKPz8/PR5ZgpL5bAhoaGhAoAICgrSrPvLL78IAOLHH3+UypnD1pWXxuHMzCU4DuceeWk+bOpc4vjx4wKA+PTTT6XjP/zwQ6HT6cTZs2elcmvNh3PEHqcWLVoAAK5du6Yvu3LlCq5cuaJ5bGhoKGrWrAk/Pz99mYuLCzp16oRTp07h8uXLRo8/dOgQYmJi0KpVK6n8wIEDiI2NxciRI6XyUaNG4enTp9i5c6dm39I9efIEz58/N/ja48ePsX
fvXvTr1w/58+fXlw8YMABubm7YtGmTVD8gIADu7u7Ytm2byecny3uVHDbkxx9/hBACffv21axrqRwODQ1Fq1at4OHhoS/z9vZG06ZN8dtvvyE+Pl5fXqxYMdjb25vy1lC4cGFUq1aNOZzDmCOHL1++jEWLFmHhwoWws7Mz+bjsGIfThYSEQKfT4Z133tGsW7p0aQDAo0ePpHLmcM70Oo7DLzM2lwA4Dr8OXsf5sKlzidDQUABA7969peN79+4NIQQ2btwolVtrPpwjFk7pCfHyP2rLli3RsmVLzWOTkpLg7OysKndxcQEAhIWFGT3+yJEj0Ol0qFmzplR++vRpAEDt2rWl8oCAANjY2Ohf1zJo0CDkz58fTk5OaN68OU6ePCm9fu7cOTx//lx1HgcHB9SoUcPgeWrVqoXDhw+bdH7KHq+Sw4Zs2LABJUuWRJMmTTTrWiqHjf1uJScn4/z585p9y0hAQACOHDmS5ePJ/MyRwx988AGaN2+O9u3bZ+rclh6H06WkpGDTpk1o0KCBflH0MiEEYmJicPfuXYSGhiIwMBC2trZo1qyZqi5zOOd5HcfhdFpziaxgDuc8r+N82NS5RPqNHpR1jfXfGvNh0z8SNKO4uDjExMTg2bNnOHz4MGbMmAFnZ2d07Ngx0235+fkhNDQUT548Qb58+fTlhw4dAgDcunXL6PERERFwd3eXvu0BgDt37sDW1haFCxeWyh0cHODh4YHbt28bbdfBwQHdunVD+/bt4enpifDwcCxYsACNGzfGkSNH9Il5584dAC9W30re3t76FfjLypQpg/Xr1xs9P1mWOXNY6d9//8U///yDCRMmQKfTada3VA77+fnh77//RmpqKmxtbQEAycnJOHbsGADt3y1jypQpg5iYGNy/f1/VP8oe5s7hnTt34o8//sDZs2czfaylclhpz549iI2NzfAbhHv37kljcYkSJRASEgJ/f39VXeaw9eWFcdjUuURWMIetLy/Mh02dS6R/U3b48GH4+vrqj0+fBxvqvzXmw1ZZOCm/BvTx8cGGDRtQvHhxfVlkZKRJbY0YMQI7duxAr169EBQUBFdXVyxfvlz/aUxiYqLR42NjY1GoUCFVeWJiIhwcHAwe4+TkpNlugwYN0KBBA33cqVMndO/eHdWqVcOkSZOwe/duqX+Ojo4mn6dQoUJITExEQkKCfiVO2cucOay0YcMGADDp8hDAcjk8cuRIjBgxAoMHD8aECROQlpaGWbNm6Rf7Wscbk97fmJgY/sG2EnPmcHJyMsaNG4fhw4ejUqVKme6LpXJYKSQkBPb29ujZs6fB193d3bF37148e/YMp0+fxpYtW6RLUl/GHLa+vDAOmzqXyArmsPXlhfmwqXOJ9u3bw8fHBx999BFcXFwQEBCAY8eOYcqUKbCzs8sx82GrLJyWLVuGChUqwM7ODkWKFIGfnx9sbLJ21WC7du2wdOlSfPLJJ6hVqxYAoFy5cggKCsKECRPg5uam2YYQQlXm7OyM5ORkg/WfPXtm8GtHLeXKlcPbb7+NLVu26Ffe6e0Yuhd9RudJ768pn4KRZZgzh18mhEBISAiqVKmCatWqZeo4pVfN4eHDh+PmzZuYP38+vvvuOwAvvqqfMGECgoKCTPrd0uovc9h6zJnDixYtQkxMDKZPn57l/lh6HI6Pj8e2bdvw5ptvSpfBvMzBwUE/kenYsSNatmyJhg0bonDhwqpPgJnD1pcXxmFDDM0lsoI5bH15YT5s6lzCyckJO3fuRM+ePdGtWzcAL75UmDdvXoZzDmvksFX2ONWtWxetWrVCs2bNULFixVce6EaPHo179+7hyJEjOHnyJCIiIlCgQAEAQIUKFYwe6+HhgYcPH6rKvb29kZqaivv370vlycnJiI2NRbFixbLU15IlSyI5ORlPnz7Vnwf4/5fsvezOnTsGz/Pw4UO4uLhkacAl8zB3Dqc7fPgwrl+/bvKnnIBlczgoKAj37t1DaGgo/vnnH5w4cQJpaW
kAtH+3jEnvr6enZ5bboFdjrhyOi4vDrFmzMHToUDx+/BiRkZGIjIxEfHw8hBCIjIxU5aBSdozDv/76KxISE
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "RsY5uPgU8Yak"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## Great results!\n",
|
|||
|
|
"\n",
|
|||
|
|
"But wouldn't it be nice if we could visualize those convolutions so that we can see what the model is seeing?"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "BgbunToX8Yal"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"def visualize2(model, layer, img, r, c):\n",
|
|||
|
|
" from keras import Model\n",
|
|||
|
|
" # expand dimensions so that it represents a single 'sample'\n",
|
|||
|
|
" img = np.expand_dims(img, axis=0)\n",
|
|||
|
|
" outputs = [layer.output]\n",
|
|||
|
|
" model = Model(inputs=model.inputs, outputs=outputs)\n",
|
|||
|
|
" # get the feature maps produced by the chosen layer\n",
|
|||
|
|
" feature_maps = model.predict(img)\n",
|
|||
|
|
" for fmap in feature_maps:\n",
|
|||
|
|
" # plot the first r*c feature maps in an r x c grid\n",
|
|||
|
|
" ix = 1\n",
|
|||
|
|
" for _ in range(r):\n",
|
|||
|
|
" for _ in range(c):\n",
|
|||
|
|
" # specify subplot and turn off axis ticks\n",
|
|||
|
|
" ax = plt.subplot(r, c, ix)\n",
|
|||
|
|
" ax.set_xticks([])\n",
|
|||
|
|
" ax.set_yticks([])\n",
|
|||
|
|
" # plot filter channel in grayscale\n",
|
|||
|
|
" plt.imshow(fmap[:, :, ix-1], cmap='gray')\n",
|
|||
|
|
" ix += 1"
|
|||
|
|
],
|
|||
|
|
"execution_count": 39,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "ZiDemlgp8Yan",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 770
|
|||
|
|
},
|
|||
|
|
"outputId": "1fd81fd0-2e08-40de-fcc2-2f8d1cccb533"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"plt.figure()\n",
|
|||
|
|
"img = X_test[11]\n",
|
|||
|
|
"plt.imshow(img[:,:,0], cmap='gray', interpolation='none')"
|
|||
|
|
],
|
|||
|
|
"execution_count": 40,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "execute_result",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<matplotlib.image.AxesImage at 0x7f0d78ec5a20>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"execution_count": 40
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 1 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAuQAAALgCAYAAADcNWJzAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAl3UlEQVR4nO3df6yW9X34/9cB5IjtOTc9IpxzxoGiVF2KUGvhSLTMViJg64ryB7ouwYZp6w6sQjodG0qtJmR+0mm6UTs3I+tSbGsiWl3GoiAYK9gU44xLy4CwAuGHLQnn8GMCgev7x74926mAHLgvXufH45HcCee+r/N+v8zV++TZK9e5T01RFEUAAAApBmQPAAAA/ZkgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASDcoe4HedOHEidu3aFXV1dVFTU5M9DgAAdFtRFHHgwIFobm6OAQNOfw28xwX5rl27oqWlJXsMAAA4Zzt27IiRI0ee9pged8tKXV1d9ggAAFAVZ9K2PS7I3aYCAEBfcSZt2+OCHAAA+hNBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiUoL8mXLlsXHP/7xuPDCC6O1tTV+9rOflbUVAAD0WqUE+Y9+9KNYuHBhLFmyJN56662YMGFCTJs2Ld57770ytgMAgF6rpiiKotqLtra2xsSJE+Pv/u7vIiLixIkT0dLSEvPnz4+/+Iu/OO33dnR0RKVSqfZIAABw3rW3t0d9ff1pj6n6FfKjR4/Gxo0bY+rUqf+7yYABMXXq1Fi/fv0Hjj9y5Eh0dHR0eQAAQH9R9SD/zW9+E8ePH48RI0Z0eX7EiBGxZ8+eDxy/dOnSqFQqnY+WlpZqjwQAAD1W+qesLFq0KNrb2zsfO3bsyB4JAADOm0HVXnDYsGExcODA2Lt3b5fn9+7dG42NjR84vra2Nmpra6s9BgAA9ApVv0I+ePDguOaaa2L16tWdz504cSJWr14dkydPrvZ2AADQq1X9CnlExMKFC2POnDnxmc98JiZNmhSPP/54HDp0KL7yla+UsR0AAPRapQT57Nmz49e//nU8+OCDsWfPnvjUpz4Vq1at+sAvegIAQH9XyueQnwufQw4AQF+R8jnkAADAmRPkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQaFD2AED/UVtbW/oeP/3pT0vf4+qrry59jxdffLHU9WfOnFnq+gCcOVfIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASDQoewCg56itrS11/ccee6zU9SMiPvWpT5W+R1EUpe+xcePG0vcAoGdwhRwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEg7IHAHqOP/uzPyt1/bvvvrvU9SMi1qxZU/oeDz74YO
l7bNiwofQ9AOgZXCEHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgUdWD/Jvf/GbU1NR0eVx55ZXV3gYAAPqEUv5S5yc/+cl45ZVX/neTQf4gKAAAnEwppTxo0KBobGw8o2OPHDkSR44c6fy6o6OjjJEAAKBHKuUe8s2bN0dzc3Nceuml8eUvfzm2b99+ymOXLl0alUql89HS0lLGSAAA0CNVPchbW1tj+fLlsWrVqnjiiSdi27Zt8dnPfjYOHDhw0uMXLVoU7e3tnY8dO3ZUeyQAAOixqn7LyowZMzr/PX78+GhtbY3Ro0fHj3/845g7d+4Hjq+trY3a2tpqjwEAAL1C6R97OHTo0Lj88stjy5YtZW8FAAC9TulBfvDgwdi6dWs0NTWVvRUAAPQ6VQ/yb3zjG7Fu3br4r//6r3jjjTfi1ltvjYEDB8Ydd9xR7a0AAKDXq/o95Dt37ow77rgj9u3bF5dccklcf/31sWHDhrjkkkuqvRUAAPR6VQ/yH/7wh9VeEgAA+qzS7yEHAABOTZADAEAiQQ4AAIkEOQAAJKr6L3UCvVdjY2P2COfslVdeKX2PDRs2lL4HAP2HK+QAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQQ5AAAkGpQ9ANBz1NXVlbr+sWPHSl0/IuKVV14pfQ8AqCZXyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEg0KHsA4Mw0NzeXvsfcuXNLXf+NN94odf2IiLfeeqv0PQCgmlwhBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAINGg7AGAM7N48eLsEYASXHvttaXv0dLSUvoe58O///u/l7r+f/7nf5a6PpyKK+QAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQZlDwCcmS984QvZI5yzp556KnsE+pgnnnii9D3Kfu997GMfK3X9iIghQ4aUvsf50NHRUer6jz32WKnrR0Q8/PDDpe9B7+MKOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIm6HeSvvfZa3HLLLdHc3Bw1NTXx/PPPd3m9KIp48MEHo6mpKYYMGRJTp06NzZs3V2teAADoU7od5IcOHYoJEybEsmXLTvr6o48+Gt/5znfie9/7Xrz55pvxkY98JKZNmxbvv//+OQ8LAAB9zaDufsOMGTNixowZJ32tKIp4/PHHY/HixfGlL30pIiK+//3vx4gRI+L555+P22+//dymBQCAPqaq95Bv27Yt9uzZE1OnTu18rlKpRGtra6xfv/6k33PkyJHo6Ojo8gAAgP6iqkG+Z8+eiIgYMWJEl+dHjBjR+drvWrp0aVQqlc5HS0tLNUcCAIAeLf1TVhYtWhTt7e2djx07dmSPBAAA501Vg7yxsTEiIvbu3dvl+b1793a+9rtqa2ujvr6+ywMAAPqLqgb5mDFjorGxMVavXt35XEdHR7z55p
sxefLkam4FAAB9Qrc/ZeXgwYOxZcuWzq+3bdsWb7/9djQ0NMSoUaPi3nvvjUceeSQ+8YlPxJgxY+KBBx6I5
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "UklZv03C8Yao",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 747
},
"outputId": "a31a7e23-e735-46c1-a323-3c49d430527a"
},
"source": [
"visualize2(model, model.layers[1], img, 4, 4)"
],
"execution_count": 41,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1/1 [==============================] - 0s 82ms/step\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 900x900 with 16 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAswAAALJCAYAAACgHHWpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABfcUlEQVR4nO3deXBd933e/y+IjdgXAiAIENwXcackLqIoWaJl0VIcSd4Ub0oqr43HaZy0TeW47bhuJ7aTSV1Pm8Zubbe2lVi1LUeyI5mkRO20xMWkuO8rQAIgCBDLxUKs9/fHb5Sxos/zIQ54L5aL9+sfzTyH53sOLs49+PCK50FaPB6PBwAAAACmKWN9AgAAAMB4xsAMAAAAOBiYAQAAAAcDMwAAAOBgYAYAAAAcDMwAAACAg4EZAAAAcDAwAwAAAI6M4fyhoaGhUF9fHwoKCkJaWlqyzwmTSDweD7FYLFRVVYUpU0b3729c10gWrmukIq5rpKLhXtfDGpjr6+tDTU1Nwk4O+Ofq6urCzJkzR/WYXNdINq5rpCKua6Si613XwxqYCwoKEnZCiVZZWWnmH/zgB818w4YNci31W8J/9KMfmfn27duvc3YYrrG4xt46ZmZm5rj7xGJgYMDMh4aGRvlMbkxWVtZYn8KYiMfjob+/f0yv6/FInduKFSsi5SHo+/Uzzzxj5vX19dc5u3fKzMw08/7+/shrpRKu67dTc0hZWVmkPx9CCOvXrzfzhQsXmnlnZ6eZ//KXv5TH2LNnj5m3trbKfSaD611jwxqYx9sw8dvUx+fZ2dlmnpubK9dSN2B100TijMU19tYx09LSxt01Pt7OZ6RS5esYqbG8rscjdW4ZGfaPInUfD0H/5TGR/1RgPL+WY4nr+u3UNZeenm7m6noPIYSpU6eauZpdBgcHzdybW0b7n9NMFNe7xnjVAAAAAAcDMwAAAOBgYAYAAAAcw/o3zGNt2rRpctt73vMeM587d66Zf+9735NrbdmyJdqJjUB+fr6Zq3+4j9TlPcCn/l0aMN55z4ksWrTIzDdu3Gjmc+bMkWs1Nzeb+d13323mx44dM/O6ujp5jPb2drkNk0tpaWnkbeqaO3DggFxLlQncddddZq7mo8uXL8tjYGT4hBkAAABwMDADAAAADgZmAAAAwMHADAAAADgYmAEAAAAHAzMAAADgmBC1ciUlJXKb+nXWf/EXf2Hmo1G1oqqTQgghKyvLzA8fPpys08E41d/fP9anACScd79ev369mav74o9//GO51vnz5818xYoVZq7uy11dXfIYTU1Nchsml97eXrktkT+/BwYGzLyhocHMVVWt9+u3vV+bDY1PmAEAAAAHAzMAAADgYGAGAAAAHAzMAAAAgIOBGQAAAHBMiJYMz8svv2zmiWzDqK6uNvMFCxaY+ZUrV+Razc3NCTkn4Ealp6eb+Qc/+EEznzp1qlxr7969Zn7mzJnoJ4YJzWvJmDLF/ozmySefNPMjR45EPv6FCxfMvLS01MyvXr0a+RiYfLw2lURS7RYzZ840846ODjP35pBYLBb9xMAnzAAAAICHgRkAAABwMDADAAAADgZmAAAAwMHADAAAADjGVUuGemo/Oztb7pPINgx1/I0bN5q5etK0sbFRHkM90YrJRzUGeNtyc3PN3HuPeE9LW/Lz88382rVrcp/u7u5Ix8DEp67RgYEBuc+LL75o5iNpw4iKNgwMR1pampnH4/GEHUPdY0MIYdOmTWau3leHDh0y84aGhugnBhefMAMAAAAOBmYAAADAwcAMAAAAOBiYAQAAAAcDMwAAAOAYVy0ZlZWVZj5//ny5T29vr5mr3/s+NDQk1xocHDTzCxcumLlq6GhtbZXHUE/aTp061czLy8vlWlVVVWZeUFBg5l5DR319vZk3NTWZeV9fn1wLb6eebvauRbVNPcHtXXNKRob99n/zzTfNvLa2Vq6l3m/qeh+t60c1OaivXVGv+2RWWFho5l5Lxv
HjxxN2/MzMTDPv7+9P2DEUdf2o+7inp6fHzNXPIyRXRUWFmSeykeuBBx6Q29QcpN47qq1rJFTbkpopQtD3AfUeUTNbCPprUflovNd/G58wAwAAAA4GZgAAAMDBwAwAAAA4GJgBAAAABwMzAAAA4GBgBgAAABzjqlZu5syZZq5qd0II4fTp08k6nese/8qVK2auamlCCOH2228384ceesjMN27cKNdSVS9Xr1418zNnzsi19u7da+avvfaamavqMXXsySyR9VDt7e0JW0tV16mKOq+W0KsSG0vqa4xaa5eeni63Ra2oSxWq/mru3Llyn7Nnz5q5+j55dVbqfaWuRVVxmJ+fL49RVlZm5rm5uWauagy946vKLK/GzKvmwo1J5P1a1VEuWLBA7tPd3W3mag5Rf17VzoYQwrx58yKdl/rzIYQwe/ZsM58+fbqZe/feAwcOmPmePXvMfN++fWZ+6dIleQz1PhwOPmEGAAAAHAzMAAAAgIOBGQAAAHAwMAMAAAAOBmYAAADAMa4e71ZPm2/fvn2Uz+TtVOuFelL64x//uFxr06ZNZl5UVGTm27Ztk2v94he/MPPz58+beWFhoVwrKyvLzJubm83ca0zAxKDeb42NjWY+XpswRoP35LzVoHEjT2JPFOopeI9qw1BmzZolt6kWiWvXrpm5uo97jQU1NTVmPmPGDDPv7++Xa02bNs3MVROHuveGoH8u7Ny5U+6D4fFe96hUY4y6RkPQ7VudnZ1mvnr1ajP32rqqq6vNfPHixWael5cn18rMzDRz1RCydOlSudaSJUvMfMWKFZHO68UXX5TH8NpnrodPmAEAAAAHAzMAAADgYGAGAAAAHAzMAAAAgIOBGQAAAHCMSUtGaWmpmaumhtF64vy2224z8wcffNDM1ROl6gnqEEJ48sknzfyrX/3qdc5ubKgmkIKCAjNvbW1N5umMa16TgkW1VISgW1NaWloiHcPT09OTsLUms76+vrE+haRS9wDVmrJr166EHVs1A4QQQnFxcaRcPYG/fPlyeYzKykoz7+7uNnP1WoWgn+hft26dmatmAG+fr3/962a+Y8cOudZkVVJSYuaJ/Bk2e/ZsM1ezTgj63q+uX/Ue8X4eXbhwwcx/+ctfmrlq7hiJm266SW77xCc+Yebr1683c/WanDx5Uh6jqanpHdlwZ0w+YQYAAAAcDMwAAACAg4EZAAAAcDAwAwAAAA4GZgAAAMDBwAwAAAA4xqRWrry83MyvXr2a9GOrOp4QQvj0pz9t5kuXLjXz+vp6M3//+98vjzHRqryuXbtm5qo+SVX1hJD6lXOqZmsk0tPTE7ZWqlB1jdnZ2XKf8+fPJ+lsJg9VA6oqDhN5H/fqntR5rV692sxVTZu6j4cQwssvv2zmqmaroaFBrqXuD6rW7jOf+Yxcq7q62sxvvvlmMz969KiZj8bP3PFqaGgo6ceYP3++mU+fPl3uo+rgamtrzVxdi6+//vp1zm5sHD9+XG77yU9+YuZqBlPvg6qqKnmMQ4cOvSOLx+Ohv79f7vMWPmEGAAAAHAzMAAAAgIOBGQAAAHAwMAMAAAAOBmYAAADAMSYtGZmZmWZ+8uTJyGsVFBSYuWp3UE9QhxBCZWWlmf/oRz8y8//1v/6Xf3JJtmbNGjNXT4N3dHTItX7+85+buXqSuLm52cxVk0EI9vc9Ho8ntF1iIvG+7qamplE8k9G3fv16ue1zn/ucmauGm+LiYrnWiy++aOb/5//8HzN/5ZVX5FqTVVlZmZmre+xIFBUVmblqwgghhNtuu83M1T3+zTffNPPHH39cHqOtrU1uS5QLFy6Y+bPPPiv3+dCHPmTmGRn2j/S8vDwz976Hvb2978ji8fiotEuMhvb29oStNWfOHDNfvHixmXvtL3v27DHzF154IfJ5TTSHDx82c/XeLSwsNHOvOclqoKIlAwAAAEgABmYAAADAwcAMAAAAOBiYAQAAAAcDMwAAAOAYk5YM1dbQ09MTea0pU+yZX/
2OcZWHEMIvf/lLM//ud78b+byiuueee8z8rrvukvtcunTJzPv6+sy8q6tLrlVSUmLmra2tch9LeXm53GY9k
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "gt8S9bzR8Yar",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 670
},
"outputId": "97bb3961-0315-449e-8200-3637b42ce121"
},
"source": [
"visualize2(model, model.layers[4], img, 4, 8)"
],
"execution_count": 42,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1/1 [==============================] - 0s 169ms/step\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 900x900 with 32 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAs0AAAJ8CAYAAAAF2ZxRAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAp/ElEQVR4nO3de5Dd9Vk/8O9uLpvbJiElpElsCiRgREqgk7FMQSAMtsgIdErtRCs4Umi0aJlWvNDilNoZrXS0rdVmMLVcOlq5dLiEFnAgFRAmasBQSpNCbuSUza2ksJtkN3s55/eHRnF+nPM8m3PO5uzu6zVz/vo8+3w/++zZs+89O/v9tFUqlUoBAABU1X6sNwAAAK1OaAYAgIDQDAAAAaEZAAACQjMAAASEZgAACAjNAAAQmJgpKpfLRVdXV9HZ2Vm0tbU1e0+jVqVSKXp6eooFCxYU7e3D+33EjHPMuPnMuPnMuPnMuPnMuPnMeGSk51xJKJVKlaIoPJKPUqmUGasZm3FLP8zYjMfCw4zNeCw8zLg15px6p7mzszNTxn87mnllP2bixNSXrBgcHBz2HkaTZs64kZYvXx7WPP7446lev/7rvx7WfOc730n1WrBgQdW1crlc7N69u64Zz5s3r+Zv61deeWWq37Jly8KaX/mVXwlrnn/++dT11q9fH9Y88cQTqV6bNm2qulYul4sdO3Y09Xm8atWqVN0tt9wS1syaNSvVqxWNlteKjIsuuihV99RTT4U1hw8frnc7/2MszbhVmfHIiGaWSmDe1h+eo5lX9mN8Lf5LM2fcSJlfcmbOnJnqNWnSpHq38z8yf+arZ8bt7e01r9HR0ZHqN23atLAmM78ZM2akrjd16tSwJvt1aPaMI9kZZ59/o9Voea3IyD73Rnr/Y2nGrcqMR0Y0M/8ICAAAAaEZAAACQjMAAASEZgAACAjNAAAQEJoBACCQu+kvLeOUU05J1f3whz9s8k7I+OIXvxjWbN26NdXr/vvvr3M3/2vv3r1V1yqVSt39V6xYUUyePLnq+nnnnZfq89Of/jSsWbNmTVjzl3/5l6nrZe5du2PHjlSvY23JkiWpul/+5V9u8k7IyNwerLe3N9XrtNNOC2uee+65VC/gf3mnGQAAAkIzAAAEhGYAAAgIzQAAEBCaAQAgIDQDAEBAaAYAgIDQDAAAAaEZAAACTgQcZcrl8rHeAsPQ2dkZ1qxevXoEdvJ/9ff3N7X/8uXLi6lTp1ZdnzJlSqrPK6+8EtZkTlLLXi9zIuAZZ5yR6tXX11d1bWhoKH0S5NH68z//81Tdrl27mroPcs4///yw5qyzzkr1mjNnTljz/ve/P9Vr8+bNVdcGBgaKhx56KNVnrFm1alVYc9ttt6V6Nfv1mMbxTjMAAASEZgAACAjNAAAQEJoBACAgNAMAQEBoBgCAgNAMAAABoRkAAALj4nCTD33oQ2HNvffeOwI7iZ1//vnFxInVvyw/+MEPRnA31Ov2228Pa7785S83fR8jrVKpFJVKper6hg0bUn0yh5L8x3/8R1gzffr01PVWrFgR1vT09KR6LV++vOpab29v8alPfSrV52iN9KEll156aapu7dq1Td7J6HTccceFNbt37071ynztL7/88lSvCy64oOrawYMHx+3hJrfeeuux3gLDcPHFF9dcHxwcLB577LGwj3eaAQAgIDQDAEBAaAYAgIDQDAAAAaEZAAACQjMAAASEZgAACAjNAAAQGBeHmwwMDBzrLaQ9+OCDxcyZM6uut7W1jeBuqNf27dvDmssuuyzV6+Mf/3hYs23btob1qseWLVuKjo6Oquuvvvpqqs9ZZ50V1pRKpbBmcHAwdb3M99fOnTtTvYaGhqqu9ff3p3qMJsfi0JJly5ZVXRsaGhpVh0Hdd999Det1xRVXhDUvvfRSqte3vvWtqmu1nuPQSh555JGG9PFOMwAABIRmAAAICM0AABAQmgEAICA0AwBAQGgGAICA0AwAAA
GhGQAAAkIzAAAEGnoi4OTJk1N1I30a1gMPPDCi16vHgQMHivZ2v8uMBhdffHFYc8MNN4Q173nPe1LXmzRpUlhz5513pno126OPPlrzebxly5ZUnx/+8IdhTeZUss2bN6eu9+///u+pOlrDJZdcUnXt8OHDLXMi4IwZM8Kak08+OazJnJBZFEWxdevWsOazn/1sqhfwv6QzAAAICM0AABAQmgEAICA0AwBAQGgGAICA0AwAAAGhGQAAAkIzAAAEGnq4Sa0bzb/ZCy+8ENZkbs4+bdq01PUOHToU1sydOzfVa9++fam6o/W1r32tmDJlStX1VtknRfELv/ALYc2Pf/zjsOapp55KXa9UKoU1q1evTvVqtm3btjWkz4svvtiQPsfC9OnTq65VKpXU69J4tWjRolTdhg0bqq4NDg7WvY/58+fXPKTn1VdfTfU5cOBAWJP5/v7MZz6Tul5fX19Ys2nTplSv1157LVU3GkydOjWs6e3tHYGdMFp5pxkAAAJCMwAABIRmAAAICM0AABAQmgEAICA0AwBAQGgGAICA0AwAAAGhGQAAAsM6EfBTn/pU0dHRUXU9e4rTfffdF9Y899xzDelTFEXxjW98I6zp6upK9Wq29vb2midQfeITn0j1+ZM/+ZNGbYkqMs/RP/3TPx2BnQzf3/7t31Zd6+3tLW644YYR3A38Xzt37mxo3dE655xzikmTJlVd//jHP57qc+6554Y1zzzzTFjT39+fut6ePXvCmvPOOy/Vq9aMh4aGio0bN6b6VPPRj360mDx5ctX1VatWpfosW7YsrLn99tvDmnXr1qWu981vfjNVR3PdeOONqbo/+7M/q7ne3d1dzJo1K+zjnWYAAAgIzQAAEBCaAQAgIDQDAEBAaAYAgIDQDAAAAaEZAAACqfs0VyqVoiiK4vDhwzXrent7Uxft7u4Oaw4cOBDWRPs5olwup+oa5ci8juZjos9ppD+XVlXPjBtlYGCgof1GUq3v1b6+vqIoWmPGo1mtWRxZM+Pmq2fG0ff4wYMHU/0yP/MyvbL3aT506FBYk339GhoaCtfqmXH0OWWyQFHkZpzJKNkZjzSvFW8tmwOj58eR9XBmlYRSqVQpisIj+SiVSpmxmrEZt/TDjM14LDzM2IzHwsOMW2PObZVK/KtIuVwuurq6is7OzqKtrS0qH7cqlUrR09NTLFiwoOapfm/FjHPMuPnMuPnMuPnMuPnMuPnMeGRk55wKzQAAMJ75R0AAAAgIzQAAEBCaAQAgIDQDAEBAaAYAgIDQDAAAAaEZAAACQjMAAASEZgAACAjNAAAQEJoBACAgNAMAQEBoBgCAgNAMAAABoRkAAAJCMwAABIRmAAAICM0AABAQmgEAICA0AwBAQGgGAICA0AwAAAGhGQAAAkIzAAAEhGYAAAgIzQAAEBCaAQAgIDQDAEBAaAYAgIDQDAAAAaEZAAACQjMAAASEZgAACAjNAAAQEJoBACAgNAMAQEBoBgCAgNAMAAABoRkAAAJCMwAABIRmAAAICM0AABAQmgEAICA0AwBAQGgGAICA0AwAAAGhGQAAAkIzAAAEhGYAAAgIzQAAEBCaAQAgIDQDAEBAaAYAgIDQDAAAAaEZAAACQjMAAASEZgAACAjNAAAQEJoBACAgNAMAQEBoBgCAgNAMAAABoRkAAAJCMwAABIRmAAAICM0AABAQmgEAICA0AwBAQGgGAICA0AwAAAGhGQAAAkIzAAAEhGYAAAgIzQAAEBCaAQAgIDQDAEBAaAYAgIDQDAAAAaEZAAACQjMAAASEZgAACAjNAAAQEJoBACAgNAMAQEBoBgCAgNAMAAABoRkAAAJCMwAABIRmAAAICM0AABAQmgEAICA0AwBAQGgGAICA0AwAAAGhGQAAAkIzAAAEhGYAAAgIzQAAEBCaAQAgIDQDAEBAaAYAgIDQDAAAgYmZonK5XHR1dRWdnZ1FW1tbs/c0alUqlaKnp6dYsGBB0d4+vN9HzDjHjJvPjJ
vPjJvPjJvPjJvPjEdGes6VhFKpVCmKwiP5KJVKmbGasRm39MOMzXgsPMzYjMfCw4xbY86pd5o7OzuLoiiKc
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "okUztZOd79by"
},
"source": [
"# Transfer learning"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8MRwLZgGhK-K"
},
"source": [
"## Helper functions"
]
},
{
"cell_type": "code",
"metadata": {
"id": "1AIiBfX78luv"
},
"source": [
"# Loading the specified subset of digits from the MNIST dataset\n",
"def get_mnist(digits):\n",
"    # Reload the MNIST data\n",
"    (X_train, y_train), (X_test, y_test) = mnist.load_data()\n",
"\n",
"    # Keep only the samples whose label is one of `digits`\n",
"    train_ids = np.isin(y_train, digits)\n",
"    test_ids = np.isin(y_test, digits)\n",
"    X_train = X_train[train_ids]\n",
"    X_test = X_test[test_ids]\n",
"    y_train = y_train[train_ids]\n",
"    y_test = y_test[test_ids]\n",
"\n",
"    # Add a channel axis: (N, 28, 28) -> (N, 28, 28, 1)\n",
"    X_train = np.expand_dims(X_train, axis=-1)\n",
"    X_test = np.expand_dims(X_test, axis=-1)\n",
"\n",
"    X_train = X_train.astype('float32')  # change integers to 32-bit floating point numbers\n",
"    X_test = X_test.astype('float32')\n",
"\n",
"    X_train /= 255  # scale pixel values from [0, 255] to [0, 1]\n",
"    X_test /= 255\n",
"\n",
"    return X_train, y_train, X_test, y_test"
],
"execution_count": 43,
"outputs": []
},
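{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the subset selection in `get_mnist`, run on a made-up label array instead of MNIST: `np.isin` returns a boolean mask that is `True` wherever the label is one of `digits`, and applying the same mask to images and labels keeps the two aligned."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical toy data standing in for y_train / X_train\n",
"labels = np.array([0, 7, 3, 5, 9, 1, 4])\n",
"images = np.arange(len(labels))  # stand-in 'images': just their positions\n",
"\n",
"digits = [0, 1, 2, 3, 4]\n",
"mask = np.isin(labels, digits)  # True where the label belongs to the subset\n",
"\n",
"print(mask)\n",
"print(labels[mask])  # kept labels: 0, 3, 1, 4\n",
"print(images[mask])  # positions of the kept samples: 0, 2, 5, 6"
],
"execution_count": null,
"outputs": []
},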
{
"cell_type": "code",
"metadata": {
"id": "Zf0T6tdPBC70"
},
"source": [
"# Preparing the base model\n",
"def prepare_model(input_shape, class_count):\n",
"    model = Sequential()  # Linear stacking of layers\n",
"\n",
"    # Convolution Layer 1\n",
"    model.add(Conv2D(16, (3, 3), padding=\"same\", input_shape=input_shape))\n",
"    model.add(Activation('relu'))\n",
"    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"\n",
"    # Convolution Layer 2\n",
"    model.add(Conv2D(32, (3, 3), padding=\"same\"))\n",
"    model.add(Activation('relu'))\n",
"    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"\n",
"    # Convolution Layer 3\n",
"    model.add(Conv2D(64, (3, 3), padding=\"same\"))\n",
"    model.add(Activation('relu'))\n",
"    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"\n",
"    model.add(Flatten())\n",
"\n",
"    # Fully Connected Layer\n",
"    model.add(Dense(64))\n",
"    model.add(Activation('relu'))\n",
"\n",
"    # Output layer: one softmax probability per class\n",
"    model.add(Dense(class_count))\n",
"    model.add(Activation('softmax'))\n",
"\n",
"    return model"
],
"execution_count": 44,
"outputs": []
},
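{
"cell_type": "markdown",
"metadata": {},
"source": [
"The feature-map sizes behind `prepare_model` can be checked with plain arithmetic (a sketch, not part of the training pipeline): each `Conv2D` uses 'same' padding with stride 1, so it keeps the spatial size, while each `MaxPooling2D(pool_size=(2, 2))` halves it and rounds down. Three blocks give 28 -> 14 -> 7 -> 3, so `Flatten` produces 3 * 3 * 64 = 576 features for the first `Dense` layer."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def after_pooling(size, pool=2):\n",
"    # MaxPooling2D with default 'valid' padding and strides equal to pool_size: floor division\n",
"    return size // pool\n",
"\n",
"size = 28  # input images are 28x28x1\n",
"sizes = []\n",
"for block in range(3):\n",
"    # Conv2D with padding='same' and stride 1 leaves the spatial size unchanged\n",
"    size = after_pooling(size)\n",
"    sizes.append(size)\n",
"\n",
"flat_features = size * size * 64  # 64 filters in the last Conv2D\n",
"print(sizes, flat_features)  # [14, 7, 3] 576"
],
"execution_count": null,
"outputs": []
},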
{
"cell_type": "code",
"metadata": {
"id": "EX6bXUPBGDyx"
},
"source": [
"# One-hot encoding of non-consecutive labels (e.g. digits 5..9)\n",
"def one_hot(labels):\n",
"    from sklearn.preprocessing import OneHotEncoder\n",
"    # Columns follow the sorted unique labels: column i is the i-th smallest digit present\n",
"    encoder = OneHotEncoder(categories='auto')\n",
"    l = labels.reshape(-1, 1)\n",
"    output = encoder.fit_transform(l)\n",
"    return output.toarray()"
],
"execution_count": 45,
"outputs": []
},
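{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why `test_model` can map predictions back with `digits[x]`: `OneHotEncoder` assigns columns to the sorted unique labels, so for digits 5..9 column 0 means 5 and column 4 means 9. The toy labels below are hypothetical. This relies on every digit of the subset appearing in the labels; a missing digit would silently drop its column and shift the mapping."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"\n",
"# Hypothetical labels drawn from the second task (digits 5..9)\n",
"labels = np.array([7, 5, 9, 5, 8, 6])\n",
"encoder = OneHotEncoder(categories='auto')\n",
"encoded = encoder.fit_transform(labels.reshape(-1, 1)).toarray()\n",
"\n",
"print(encoder.categories_[0])  # column order: the sorted unique labels 5..9\n",
"print(encoded.shape)           # 6 samples x 5 classes\n",
"\n",
"# Map the argmax column index back to a digit, as test_model does\n",
"new_digits = [5, 6, 7, 8, 9]\n",
"recovered = [new_digits[i] for i in np.argmax(encoded, axis=1)]\n",
"print(recovered)  # same as the original labels"
],
"execution_count": null,
"outputs": []
},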
{
"cell_type": "code",
"metadata": {
"id": "-VP_9oFFiE6a"
},
"source": [
"def test_model(model, X_test, Y_test, y_test, digits):\n",
"    score = model.evaluate(X_test, Y_test)\n",
"    print('Test score:', score[0])  # score[0] is the test loss\n",
"    print('Test accuracy:', score[1])\n",
"\n",
"    # model.predict returns class probabilities; argmax picks the most probable\n",
"    # class index for each input example.\n",
"    predicted = model.predict(X_test)\n",
"    predicted_classes = np.argmax(predicted, axis=1)\n",
"\n",
"    # Map class indices (0..len(digits)-1) back to the actual digit labels\n",
"    predicted_digits = np.array([digits[x] for x in predicted_classes])\n",
"\n",
"    # Check which items we got right / wrong\n",
"    correct_indices = np.nonzero(predicted_digits == y_test)[0]\n",
"    incorrect_indices = np.nonzero(predicted_digits != y_test)[0]\n",
"\n",
"    cnf_matrix = confusion_matrix(y_test, predicted_digits)\n",
"\n",
"    class_names = [str(i) for i in digits]\n",
"\n",
"    # Plot non-normalized confusion matrix\n",
"    plt.figure()\n",
"    plot_confusion_matrix(cnf_matrix, classes=class_names,\n",
"                          title='Confusion matrix, without normalization')\n",
"\n",
"    plt.show()"
],
"execution_count": 46,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Fi7BwsX8h5Dm"
},
"source": [
"## Training the base model: digits 0..4"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZyAj3Kk4Bw4d",
"outputId": "079eade4-dd17-4e2f-e7df-e68c2fc499a7"
},
"source": [
"digits = [0, 1, 2, 3, 4]\n",
"\n",
"X_train, y_train, X_test, y_test = get_mnist(digits)\n",
"\n",
"Y_train = one_hot(y_train)\n",
"Y_test = one_hot(y_test)\n",
"\n",
"model = prepare_model( (28, 28, 1), 5)\n",
"\n",
"adam = tf.optimizers.Adam(learning_rate=0.001)\n",
"model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])\n",
"\n",
"model.summary()\n",
"\n"
],
"execution_count": 47,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model: \"sequential_2\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" conv2d_2 (Conv2D) (None, 28, 28, 16) 160 \n",
" \n",
" activation_7 (Activation) (None, 28, 28, 16) 0 \n",
" \n",
" max_pooling2d_2 (MaxPoolin (None, 14, 14, 16) 0 \n",
" g2D) \n",
" \n",
" conv2d_3 (Conv2D) (None, 14, 14, 32) 4640 \n",
" \n",
" activation_8 (Activation) (None, 14, 14, 32) 0 \n",
" \n",
" max_pooling2d_3 (MaxPoolin (None, 7, 7, 32) 0 \n",
" g2D) \n",
" \n",
" conv2d_4 (Conv2D) (None, 7, 7, 64) 18496 \n",
" \n",
" activation_9 (Activation) (None, 7, 7, 64) 0 \n",
" \n",
" max_pooling2d_4 (MaxPoolin (None, 3, 3, 64) 0 \n",
" g2D) \n",
" \n",
" flatten_1 (Flatten) (None, 576) 0 \n",
" \n",
" dense_5 (Dense) (None, 64) 36928 \n",
" \n",
" activation_10 (Activation) (None, 64) 0 \n",
" \n",
" dense_6 (Dense) (None, 5) 325 \n",
" \n",
" activation_11 (Activation) (None, 5) 0 \n",
" \n",
"=================================================================\n",
"Total params: 60549 (236.52 KB)\n",
"Trainable params: 60549 (236.52 KB)\n",
"Non-trainable params: 0 (0.00 Byte)\n",
"_________________________________________________________________\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "BjEZRMnvChne"
},
"source": [
"gen = ImageDataGenerator(rotation_range=8, width_shift_range=0.08, shear_range=0.3,\n",
" height_shift_range=0.08, zoom_range=0.08, validation_split=0.2)\n",
"\n",
"test_gen = ImageDataGenerator()\n",
"\n",
"train_generator = gen.flow(X_train, Y_train, batch_size=128, subset='training')\n",
"valid_generator = gen.flow(X_train, Y_train, batch_size=128, subset='validation')\n",
"test_generator = test_gen.flow(X_test, Y_test, batch_size=128)"
],
"execution_count": 48,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "byGKttIXDfE3",
"outputId": "0f8b64ce-1306-4187-b53a-5255a36b2711"
},
"source": [
"# Model.fit accepts the iterators directly (fit_generator is deprecated). Without steps_per_epoch,\n",
"# Keras takes the number of batches from len(train_generator), so no epoch runs out of data.\n",
"model.fit(train_generator, epochs=5, verbose=1, validation_data=valid_generator)"
],
"execution_count": 49,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/5\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"<ipython-input-49-1bedfa515cd7>:1: UserWarning: `Model.fit_generator` is deprecated and will be removed in a future version. Please use `Model.fit`, which supports generators.\n",
" model.fit_generator(train_generator, steps_per_epoch=25000//128, epochs=5, verbose=1, validation_data=valid_generator, validation_steps = 5000 // 128)\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"191/195 [============================>.] - ETA: 0s - loss: 0.2532 - accuracy: 0.9235"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 975 batches). You may need to use the repeat() function when building your dataset.\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r195/195 [==============================] - 11s 43ms/step - loss: 0.2527 - accuracy: 0.9238 - val_loss: 0.0744 - val_accuracy: 0.9784\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<keras.src.callbacks.History at 0x7f0d801654e0>"
]
},
"metadata": {},
"execution_count": 49
}
]
},
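{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Your input ran out of data` warning in the output above means that `steps_per_epoch` asked for more batches than the training iterator holds, so training stopped after the first epoch. A back-of-the-envelope check, assuming the standard MNIST training-set counts for digits 0..4 and that `validation_split=0.2` reserves `int(0.2 * n)` samples for validation (both are assumptions, not values printed in this notebook):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import math\n",
"\n",
"def batches_per_epoch(n_samples, batch_size):\n",
"    # An iterator over n samples yields ceil(n / batch_size) batches per pass\n",
"    return math.ceil(n_samples / batch_size)\n",
"\n",
"# Assumption: MNIST training-set counts for digits 0, 1, 2, 3, 4\n",
"counts = [5923, 6742, 5958, 6131, 5842]\n",
"n_total = sum(counts)\n",
"n_valid = int(n_total * 0.2)  # samples reserved by validation_split=0.2\n",
"n_train = n_total - n_valid\n",
"\n",
"available = batches_per_epoch(n_train, 128)\n",
"requested = 25000 // 128  # steps_per_epoch used in the recorded run\n",
"print(n_train, available, requested)  # 24477 192 195\n",
"print('total requested over 5 epochs:', requested * 5)  # 975, as in the warning"
],
"execution_count": null,
"outputs": []
},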
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "3sUGcowEMLHk",
"outputId": "c94d3f21-d4d1-46b8-841d-e918e87c8217"
},
"source": [
"test_model(model, X_test, Y_test, y_test, digits)"
],
"execution_count": 50,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"161/161 [==============================] - 1s 5ms/step - loss: 0.0275 - accuracy: 0.9930\n",
"Test score: 0.02751018851995468\n",
"Test accuracy: 0.9929947257041931\n",
"161/161 [==============================] - 1s 3ms/step\n",
"Confusion matrix, without normalization\n",
"[[ 977 0 2 0 1]\n",
" [ 0 1125 4 4 2]\n",
" [ 5 0 1015 9 3]\n",
" [ 0 0 3 1007 0]\n",
" [ 1 0 1 1 979]]\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 900x900 with 2 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1UAAAN6CAYAAABmBWMlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB5rUlEQVR4nOzdd3gUZdfH8d9sIIWQQktCJxQpSpEiRukgSFEQRXlECYioSBFQQB4FARUUG4I0G0VFiooKIkov0kF8EBClKQpJwEACoSQk+/4Rsq9LQBImYWbC95NrLt2Z2dkz2SHZk3Puewy32+0WAAAAAOCquKwOAAAAAACcjKQKAAAAAEwgqQIAAAAAE0iqAAAAAMAEkioAAAAAMIGkCgAAAABMIKkCAAAAABNIqgAAAADAhHxWBwAAAADg6pw9e1bJyclWh5Elvr6+8vf3tzqMXEFSBQAAADjQ2bNnFRBURDp/2upQsiQiIkIHDhzIk4kVSRUAAADgQMnJydL50/KrFi35+Fodzr9LTVbMrhlKTk4mqQIAAABgMz6+MmyeVLmtDiCXkVQBAAAATma40hc7s3t8JuXtswMAAACAXEZSBQAAAAAmkFQBAAAAgAmMqQIAAACczJBkGFZH8e9sHp5ZVKoAAAAAwASSKgAAAAAwgfY/AAAAwMmYUt1yefvsAAAAACCXkVQBAAAAgAm0/wEAAABOZhgOmP3P5vGZRKUKAAAAAEwgqQIAAAAAE0iqAAAAAMAExlQBAAAATsaU6pbL22cHAAAAALmMpAoAAAAATKD9DwAAAHAyplS3HJUqAAAAADCBpAoAAAAATCCpAgAAAAATGFMFAAAAOJoDplTP47WcvH12AAAAAJDLSKoAAAAAwATa/wAAAAAnY0p1y1GpAgAAAAATSKoAAAAAwATa/wAAAAAnMxww+5/d4zMpb58dAAAAAOQykioAAAAAMIGkCgAAAABMYEwVAAAA4GRMqW45KlUAAAAAYAJJFQAAAACYQPsfAAAA4GRMqW65vH12AAAAAJDLSKoAAAAAwATa/wAAAAAnY/Y/y1GpAgAAAAATSKoAAAAAwASSKgAAAAAwgTFVAAAAgJMxpbrl8vbZAQAAAEAuI6kCAAAAABNo/wMAAACczDDs317HlOoAAAAAgMshqQIAAAAAE2j/AwAAAJzMZaQvdmb3+EyiUgUAAAAAJpBUAQAAAIAJJFUAAAAAYAJjqgAAAAAnM1wOmFLd5vGZlLfPDgAAAAByGUkVAAAAAJhA+x8AAADgZIaRvtiZ3eMziUoVAAAAAJhAUgUAAAAAJpBUAQAAAIAJjKkCAAAAnIwp1S2Xt88OAAAAAHIZSRUAAAAAmED7HwAAAOBkTKluOSpVAAAAAGACSRUAAAAAmED7HwAAAOBkzP5nubx9dgAAAACQy0iqAAAAAMAEkioAAAAAMIExVQAAAICTMaW65ahUAQAAAIAJJFUAAAAAYALtfwAAAICTMaW65fL22QEAAABALiOpAvKw3377TS1btlRISIgMw9CXX36Zo8c/ePCgDMPQ9OnTc/S4eUG5cuXUrVs3q8PIJDvvWca+r7/+eu4HhksaMWKEjIsGd1t1bdn1mgYAOyCpAnLZvn379Pjjj6t8+fLy9/dXcHCwbr/9dr399ts6c+ZMrr52dHS0duzYoZdfflkfffSR6tatm6uvlxft2rVLI0aM0MGDB60OJdcsWrRII0aMsDqMTEaPHp3jfwjAv1u3bp1GjBihEydOWB0KgOzImP3P7ksexpgqIBd988036tSpk/z8/NS1a1fddNNNSk5O1tq1azVo0CDt3LlT7777bq689pkzZ7R+/Xo999xz6tOnT668RtmyZXXmzBnlz58/V45vB7t27dLIkSPVpEkTlStXLsvP27Nnj1wu+/3d6lLv2aJFizRx4kTbJVajR4/Wfffdpw4dOlgdiq3k5rW1bt06jRw5Ut26dVNoaOg1e10AcDqSKiCXHD
hwQJ07d1bZsmW1fPlyFS9e3LOtd+/e2rt3r7755ptce/2jR49KUqYPRjnJMAz5+/vn2vGdxu126+zZswoICJCfn5/V4VwS75k5SUlJCgwMtDQGq64tu17TAGAH/MkJyCVjx47VqVOn9MEHH3glVBkqVqyop556yvP4/PnzevHFF1WhQgX5+fmpXLly+u9//6tz5855Pa9cuXJq166d1q5dq1tuuUX+/v4qX768Zs6c6dlnxIgRKlu2rCRp0KBBMgzDU2Xp1q3bJSsulxq7sWTJEjVo0EChoaEqWLCgKleurP/+97+e7Zcbn7N8+XI1bNhQgYGBCg0NVfv27bV79+5Lvt7evXs9fxUPCQlR9+7ddfr06ct/Yy9o0qSJbrrpJv3vf/9T48aNVaBAAVWsWFGfffaZJGnVqlWqX7++AgICVLlyZS1dutTr+b///ruefPJJVa5cWQEBASpSpIg6derk1eY3ffp0derUSZLUtGlTGYYhwzC0cuVKSf//Xnz33XeqW7euAgICNHXqVM+2jPEnbrdbTZs2VbFixRQXF+c5fnJysqpXr64KFSooKSnpiuf8TwMHDlSRIkXkdrs96/r27SvDMDR+/HjPutjYWBmGocmTJ0vK/J5169ZNEydOlCTP+V18HUjSu+++67k269Wrp82bN2faJyvve1avP8MwlJSUpBkzZnhi+rfxPCtXrpRhGJo7d65efvlllSpVSv7+/mrevLn27t2baf958+apTp06CggIUNGiRfXQQw/pr7/+yhRrwYIFtW/fPrVp00ZBQUHq0qWLJ74+ffpo3rx5qlatmgICAhQVFaUdO3ZIkqZOnaqKFSvK399fTZo0ydQ+umbNGnXq1EllypSRn5+fSpcurQEDBmSpJfjisU3/fN8uXjJe93//+5+6devmaUOOiIjQI488or///tvrPRg0aJAkKTIyMtMxLjWmav/+/erUqZMKFy6sAgUK6NZbb830x6LsvjcA4ERUqoBcsmDBApUvX1633XZblvZ/9NFHNWPGDN133316+umntXHjRo0ZM0a7d+/W/Pnzvfbdu3ev7rvvPvXo0UPR0dH68MMP1a1bN9WpU0c33nijOnbsqNDQUA0YMED/+c9/1KZNGxUsWDBb8e/cuVPt2rVTjRo1NGrUKPn5+Wnv3r364Ycf/vV5S5cuVevWrVW+fHmNGDFCZ86c0YQJE3T77bdr27ZtmT5Q33///YqMjNSYMWO0bds2vf/++woLC9Orr756xRiPHz+udu3aqXPnzurUqZMmT56szp0765NPPlH//v31xBNP6MEHH9Rrr72m++67T4cOHVJQUJAkafPmzVq3bp06d+6sUqVK6eDBg5o8ebKaNGmiXbt2qUCBAmrUqJH69eun8ePH67///a+qVq0qSZ7/SuktUf/5z3/0+OOPq2fPnqpcuXKmOA3D0IcffqgaNWroiSee0BdffCFJeuGFF7Rz506tXLky29WPhg0b6q233tLOnTt10003SUr/oO5yubRmzRr169fPs06SGjVqdMnjPP744zp8+LCWLFmijz766JL7zJo1SydPntTjjz8uwzA0duxYdezYUfv37/e0EWb3fb+Sjz76SI8++qhuueUWPfbYY5KkChUqXPF5r7zyilwul5555hklJCRo7Nix6tKlizZu3OjZZ/r06erevbvq1aunMWPGKDY2Vm+//bZ++OEH/fjjj17V3fPnz6tVq1Zq0KCBXn/9dRUoUMCzbc2aNfr666/Vu3dvSdKYMWPUrl07DR48WJMmTdKTTz6p48ePa+zYsXrkkUe0fPlyz3PnzZun06dPq1evXipSpIg2bdqkCRMm6M8//9S8efOy/b262PPPP6+4uDjPv/slS5Zo//796t69uyIiIjytxzt37tSGDRtkGIY6duyoX3/9VZ9++qneeustFS1aVJJUrFixS75ubGysbrvtNp0+fVr9+vVTkSJFNGPGDN1999367LPPdM8992T7vQFwtRwwpX
per+W4AeS4hIQEtyR3+/bts7T/9u3b3ZLcjz76qNf6Z555xi3JvXz5cs+6smXLuiW5V69e7VkXFxfn9vPzc
},
"metadata": {}
}
]
},
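{
"cell_type": "markdown",
"metadata": {},
"source": [
"The confusion matrix printed above also gives the per-digit picture that a single accuracy number hides. A small sketch using the numbers copied from that output (not re-computed from the model): accuracy is the trace over the total, recall divides each diagonal entry by its row sum (true digit), and precision divides it by its column sum (predicted digit), because `confusion_matrix(y_true, y_pred)` puts the true labels on the rows."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"# Confusion matrix copied from the output above (rows: true digit 0..4, columns: predicted)\n",
"cm = np.array([\n",
"    [977, 0, 2, 0, 1],\n",
"    [0, 1125, 4, 4, 2],\n",
"    [5, 0, 1015, 9, 3],\n",
"    [0, 0, 3, 1007, 0],\n",
"    [1, 0, 1, 1, 979],\n",
"])\n",
"\n",
"accuracy = np.trace(cm) / cm.sum()        # correct predictions / all test samples\n",
"recall = np.diag(cm) / cm.sum(axis=1)     # per true digit\n",
"precision = np.diag(cm) / cm.sum(axis=0)  # per predicted digit\n",
"\n",
"print('accuracy:', round(float(accuracy), 4))  # 0.993, as reported by model.evaluate\n",
"print('recall:', np.round(recall, 4))\n",
"print('precision:', np.round(precision, 4))\n",
"print('digit with the lowest recall:', int(np.argmin(recall)))"
],
"execution_count": null,
"outputs": []
},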
{
"cell_type": "markdown",
"metadata": {
"id": "0fkbvsjbiIi5"
},
"source": [
"## Transferring knowledge to a new model: digits 5..9"
]
},
{
"cell_type": "code",
"metadata": {
"id": "MfJ3UmNrMwfC"
},
"source": [
"# Extract specified range of layers to a new sequential model\n",
"def extract_layers(main_model, starting_layer_ix, ending_layer_ix):\n",
"    # create an empty model\n",
"    new_model = Sequential()\n",
"    for ix in range(starting_layer_ix, ending_layer_ix + 1):\n",
"        curr_layer = main_model.get_layer(index=ix)\n",
"        # add the same layer object (not a copy), so its trained weights are shared with main_model\n",
"        new_model.add(curr_layer)\n",
"    return new_model"
],
"execution_count": 51,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "aAY_aLnOiTct"
},
"source": [
"The new model is built on top of the trained base model. The part responsible for detecting features in the image (the convolutional layers) is reused as the foundation of the new network. The classification part is added from scratch, with randomly initialized weights.\n",
"\n",
"The convolutional layers were already trained in the base model, so they are frozen. Only the classification part is trained."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qzDfXKIMNXMK",
"outputId": "21ec04f6-f91c-47c2-f616-989d5da80821"
},
"source": [
"# Reuse layers 0..9 of the base model: the three Conv2D/Activation/MaxPooling2D blocks and Flatten\n",
"new_model = extract_layers(model, 0, 9)\n",
"\n",
"# Fully Connected Layer (new, randomly initialized)\n",
"new_model.add(Dense(64))\n",
"new_model.add(Activation('relu'))\n",
"\n",
"# Output layer: one softmax probability per new digit\n",
"new_model.add(Dense(5))\n",
"new_model.add(Activation('softmax'))\n",
"\n",
"# Freeze the transferred layers so only the new classifier is trained\n",
"for ix in range(0, 9 + 1):\n",
"    new_model.get_layer(index=ix).trainable = False\n",
"\n",
"new_digits = [5, 6, 7, 8, 9]\n",
"\n",
"X_train, y_train, X_test, y_test = get_mnist(new_digits)\n",
"\n",
"Y_train = one_hot(y_train)\n",
"Y_test = one_hot(y_test)\n",
"\n",
"new_model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.legacy.Adam(), metrics=['accuracy'])\n",
"\n",
"new_model.summary()"
],
"execution_count": 52,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model: \"sequential_3\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" conv2d_2 (Conv2D) (None, 28, 28, 16) 160 \n",
" \n",
" activation_7 (Activation) (None, 28, 28, 16) 0 \n",
" \n",
" max_pooling2d_2 (MaxPoolin (None, 14, 14, 16) 0 \n",
" g2D) \n",
" \n",
" conv2d_3 (Conv2D) (None, 14, 14, 32) 4640 \n",
" \n",
" activation_8 (Activation) (None, 14, 14, 32) 0 \n",
" \n",
" max_pooling2d_3 (MaxPoolin (None, 7, 7, 32) 0 \n",
" g2D) \n",
" \n",
" conv2d_4 (Conv2D) (None, 7, 7, 64) 18496 \n",
" \n",
" activation_9 (Activation) (None, 7, 7, 64) 0 \n",
" \n",
" max_pooling2d_4 (MaxPoolin (None, 3, 3, 64) 0 \n",
" g2D) \n",
" \n",
" flatten_1 (Flatten) (None, 576) 0 \n",
" \n",
" dense_7 (Dense) (None, 64) 36928 \n",
" \n",
" activation_12 (Activation) (None, 64) 0 \n",
" \n",
" dense_8 (Dense) (None, 5) 325 \n",
" \n",
" activation_13 (Activation) (None, 5) 0 \n",
" \n",
"=================================================================\n",
"Total params: 60549 (236.52 KB)\n",
"Trainable params: 37253 (145.52 KB)\n",
"Non-trainable params: 23296 (91.00 KB)\n",
"_________________________________________________________________\n"
]
}
]
},
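{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Trainable params: 37253` / `Non-trainable params: 23296` split in the summary follows directly from the layer sizes. A sketch of the count using the standard formulas for `Conv2D` and `Dense` layers with biases: the frozen part is the three convolutions (activation, pooling and flatten layers have no weights), and the trainable part is the two new `Dense` layers."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def conv2d_params(kernel, in_channels, filters):\n",
"    # one kernel x kernel x in_channels weight block per filter, plus one bias per filter\n",
"    return kernel * kernel * in_channels * filters + filters\n",
"\n",
"def dense_params(inputs, units):\n",
"    # full weight matrix plus one bias per unit\n",
"    return inputs * units + units\n",
"\n",
"frozen = conv2d_params(3, 1, 16) + conv2d_params(3, 16, 32) + conv2d_params(3, 32, 64)\n",
"flat = 3 * 3 * 64  # Flatten output after the last MaxPooling2D\n",
"trainable = dense_params(flat, 64) + dense_params(64, 5)\n",
"\n",
"print(frozen, trainable, frozen + trainable)  # 23296 37253 60549"
],
"execution_count": null,
"outputs": []
},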
{
"cell_type": "code",
"metadata": {
"id": "quqhC141PBTO"
},
"source": [
"gen = ImageDataGenerator(rotation_range=8, width_shift_range=0.08, shear_range=0.3,\n",
" height_shift_range=0.08, zoom_range=0.08, validation_split=0.2)\n",
"\n",
"test_gen = ImageDataGenerator()\n",
"\n",
"train_generator = gen.flow(X_train, Y_train, batch_size=128, subset='training')\n",
"valid_generator = gen.flow(X_train, Y_train, batch_size=128, subset='validation')\n",
"test_generator = test_gen.flow(X_test, Y_test, batch_size=128)"
],
"execution_count": 53,
"outputs": []
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "oqy-_NxblqDy"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"Training the new model"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"id": "l0H5XRxWPKvC",
|
|||
|
|
"outputId": "4ea974a3-fee9-4bfa-a584-7f40b0cd5d2f"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
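"# Model.fit accepts generators directly (fit_generator is deprecated); without steps_per_epoch,\n",
"# Keras uses len(train_generator), so the input cannot run out of data before 5 epochs\n",
"new_model.fit(train_generator, epochs=5, verbose=1, validation_data=valid_generator)"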
|
|||
|
|
],
|
|||
|
|
"execution_count": 54,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"Epoch 1/5\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stderr",
|
|||
|
|
"text": [
|
|||
|
|
"<ipython-input-54-05ca407e5c2c>:1: UserWarning: `Model.fit_generator` is deprecated and will be removed in a future version. Please use `Model.fit`, which supports generators.\n",
|
|||
|
|
" new_model.fit_generator(train_generator, steps_per_epoch=24000//128, epochs=5, verbose=1, validation_data=valid_generator, validation_steps = 6000 // 128)\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"183/187 [============================>.] - ETA: 0s - loss: 0.4288 - accuracy: 0.8646"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stderr",
|
|||
|
|
"text": [
|
|||
|
|
"WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 935 batches). You may need to use the repeat() function when building your dataset.\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r187/187 [==============================] - 11s 51ms/step - loss: 0.4275 - accuracy: 0.8651 - val_loss: 0.2110 - val_accuracy: 0.9355\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "execute_result",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<keras.src.callbacks.History at 0x7f0cfa1a8b80>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"execution_count": 54
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
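{
"cell_type": "markdown",
"metadata": {},
"source": [
"The 'Your input ran out of data' warning in the run above comes from `steps_per_epoch=24000//128`, which requests 187 batches per epoch. A generator returned by `ImageDataGenerator.flow(...)` yields `ceil(n_samples / batch_size)` batches per pass. If the `'training'` subset yields fewer than 187 batches (that is, it holds at most 186 x 128 = 23,808 images), the iterator runs dry partway through the first epoch and Keras interrupts training instead of running all 5 epochs. In the run above, the progress bar shows 183/187 just before the warning. The cell below is a minimal sketch of that arithmetic only, and the subset sizes in it are hypothetical. In the notebook itself, `len(train_generator)` gives the exact number of batches, or `steps_per_epoch` can be omitted so that Keras uses it automatically."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import math\n",
"\n",
"def batches_per_epoch(n_samples, batch_size):\n",
"    # ImageDataGenerator.flow(...) returns an iterator with ceil(n_samples / batch_size) batches\n",
"    return math.ceil(n_samples / batch_size)\n",
"\n",
"BATCH_SIZE = 128\n",
"requested_steps = 24000 // BATCH_SIZE  # steps_per_epoch used in the run above: 187\n",
"\n",
"# hypothetical sizes of the 'training' subset, for illustration only\n",
"for n_train in (23000, 23808, 23809, 24000):\n",
"    available = batches_per_epoch(n_train, BATCH_SIZE)\n",
"    status = 'enough' if available >= requested_steps else 'runs out of data'\n",
"    print(n_train, 'images ->', available, 'batches:', status)"
],
"execution_count": null,
"outputs": []
},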
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 1000
|
|||
|
|
},
|
|||
|
|
"id": "FRuh3MN-PeDr",
|
|||
|
|
"outputId": "c8a9d191-d9d8-4d96-8202-f20f6d69580c"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"test_model(new_model, X_test, Y_test, y_test, new_digits)"
|
|||
|
|
],
|
|||
|
|
"execution_count": 55,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"152/152 [==============================] - 1s 4ms/step - loss: 0.1271 - accuracy: 0.9636\n",
|
|||
|
|
"Test score: 0.1270837038755417\n",
|
|||
|
|
"Test accuracy: 0.963587760925293\n",
|
|||
|
|
"152/152 [==============================] - 0s 3ms/step\n",
|
|||
|
|
"Confusion matrix, without normalization\n",
|
|||
|
|
"[[865 3 3 14 7]\n",
|
|||
|
|
" [ 3 946 0 8 1]\n",
|
|||
|
|
" [ 0 1 968 8 51]\n",
|
|||
|
|
" [ 8 2 9 930 25]\n",
|
|||
|
|
" [ 10 5 11 8 975]]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 2 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0wAAAN6CAYAAAC9vskHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB/jklEQVR4nOzdd3gU1dvG8Xs2gSRACi0JkRZAOlhAMVQVpKOIoiBIE7BQBASEn9IFpCOIIKIUARGwggjSi3SQjvQmLUAgoSaQ7PtHZF9jmJgom5mF7yfXXLqzs7v3ZodknzznnDWcTqdTAAAAAIBkHFYHAAAAAAC7omACAAAAABMUTAAAAABggoIJAAAAAExQMAEAAACACQomAAAAADBBwQQAAAAAJiiYAAAAAMCEt9UBAAAAAPw7N27cUFxcnNUxUiVjxozy9fW1OkaaUTABAAAAHujGjRvy888u3bpmdZRUCQ0N1ZEjRzyuaKJgAgAAADxQXFycdOuafIo3l7wyWh0nZfFxOrNnquLi4iiYAAAAAKQjr4wybF4wOa0O8B9QMAEAAACezHAkbnZm93wp8NzkAAAAAOBmFEwAAAAAYIKCCQAAAABMMIcJAAAA8GSGJMOwOkXKbB4vJXSYAAAAAMAEBRMAAAAAmGBIHgAAAODJWFbcrTw3OQAAAAC4GQUTAAAAAJhgSB4AAADgyQzDA1bJs3m+FNBhAgAAAAATFEwAAAAAYIKCCQAAAABMMIcJAAAA8GQsK+5WnpscAAAAANyMggkAAAAATDAkDwAAAPBkLCvuVnSYAAAAAMAEBRMAAAAAmKBgAgAAAAATzGECAAAAPJoHLCvuwX0az00OAAAAAG5GwQQAAAAAJhiSBwAAAHgylhV3KzpMAAAAAGCCggkAAAAATDAkDwAAAPBkhgeskmf3fCnw3OQAAAAA4GYUTAAAAABggoIJAAAAAEwwhwkAAADwZCwr7lZ0mAAAAADABAUTAAAAAJhgSB4AAADgyVhW3K08NzkAAAAAuBkFEwAAAACYYEgeAAAA4MlYJc+t6DABAAAAgAkKJgAAAAAwQcEEAAAAACaYwwQAAAB4MpYVdyvPTQ4AAAAAbkbBBAAAAAAmGJIHAAAAeDLDsP+QN5YVBwAAAIB7DwUTAAAAAJhgSB4AAADgyRxG4mZnds+XAjpMAAAAAGCCggkAAAAATFAwAQAAAIAJ5jABAAAAnsxweMCy4jbPlwLPTQ4AAAAAbkbBBAAAAAAmGJIHAAAAeDLDSNzszO75UkCHCQAAAABMUDABAAAAgAkKJgAAAAAwwRwmAAAAwJOxrLhbeW5yAAAAAHAzCiYAAAAAMMGQPAAAAMCTsay4W9FhAgAAAAATFEwAAAAAYIIheQAAAIAnY5U8t/Lc5AAAAADgZhRMAAAAAGCCggkAAAAATDCHCQAAAPBkLCvuVnSYAAAAAMAEBRMAAAAAmGBIHgAAAODJWFbcrTw3OQAAAAC4GQUTcA87cOCAqlevrsDAQBmGoe+///6u3v/Ro0dlGIamTJlyV+/3XpA/f361aNHC6hjJpOU1u33s8OHD3R8Md9S3b18Zf5sobdW5ZddzGgDcjYIJcLNDhw7p9ddfV4ECBeTr66uAgABVqFBBH330ka5fv+7Wx27evLl27typgQMH6ssvv1TZsmXd+nj3oj179qhv3746evSo1VHcZsGCBerbt6/VMZIZNGjQXS/ykbK1a9eqb9++unTpktVRAKTF7VXy7L55KOYwAW70008/qWHDhvLx8VGzZs1UsmRJxcXFac2aNerWrZt2796tiRMnuuWxr1+/rnXr1um9995T+/bt3fIY+fLl0/Xr15UhQwa33L8d7NmzR/369dOTTz6p/Pnzp/p2+/btk8Nhv79J3ek1W7BggcaNG2e7omnQoEF68cUXVb9+fauj2Io7z621a9eqX79+atGihYKCgtLtcQHAziiYADc5cuSIGjVqpHz58mnZsmXKlS
uX67p27drp4MGD+umnn9z2+OfOnZOkZG967ibDMOTr6+u2+/c0TqdTN27ckJ+fn3x8fKyOc0e8Zv/N1atXlTlzZkszWHVu2fWcBgB3409FgJsMHTpUV65c0eeff56kWLqtUKFCevvtt12Xb926pQEDBqhgwYLy8fFR/vz59b///U+xsbFJbpc/f37VrVtXa9as0eOPPy5fX18VKFBA06ZNcx3Tt29f5cuXT5LUrVs3GYbh6o60aNHijp2SO82VWLx4sSpWrKigoCBlyZJFRYoU0f/+9z/X9WbzYZYtW6ZKlSopc+bMCgoK0nPPPae9e/fe8fEOHjzo+mt2YGCgWrZsqWvXrpl/Y//05JNPqmTJktqxY4eqVKmiTJkyqVChQpo7d64kaeXKlSpXrpz8/PxUpEgRLVmyJMntjx07prfeektFihSRn5+fsmfProYNGyYZejdlyhQ1bNhQkvTUU0/JMAwZhqEVK1ZI+v/XYtGiRSpbtqz8/Pz06aefuq67Pd/D6XTqqaeeUs6cORUZGem6/7i4OJUqVUoFCxbU1atX//E5/1WXLl2UPXt2OZ1O174OHTrIMAyNGTPGte/s2bMyDEPjx4+XlPw1a9GihcaNGydJruf39/NAkiZOnOg6Nx977DFt2rQp2TGped1Te/4ZhqGrV69q6tSprkwpzZ9ZsWKFDMPQ7NmzNXDgQOXOnVu+vr6qWrWqDh48mOz4OXPmqEyZMvLz81OOHDnUtGlTnTx5MlnWLFmy6NChQ6pdu7b8/f3VpEkTV7727dtrzpw5Kl68uPz8/BQREaGdO3dKkj799FMVKlRIvr6+evLJJ5MN6Vy9erUaNmyovHnzysfHR3ny5FHnzp1TNUz373OJ/vq6/X27/bg7duxQixYtXEODQ0ND1apVK124cCHJa9CtWzdJUnh4eLL7uNMcpsOHD6thw4bKli2bMmXKpCeeeCLZH4LS+toAgN3QYQLcZN68eSpQoIDKly+fquNbt26tqVOn6sUXX9Q777yjDRs2aPDgwdq7d6++++67JMcePHhQL774ol577TU1b95cX3zxhVq0aKEyZcqoRIkSatCggYKCgtS5c2c1btxYtWvXVpYsWdKUf/fu3apbt65Kly6t/v37y8fHRwcPHtSvv/6a4u2WLFmiWrVqqUCBAurbt6+uX7+usWPHqkKFCtq6dWuyN8svvfSSwsPDNXjwYG3dulWTJk1ScHCwhgwZ8o8ZL168qLp166pRo0Zq2LChxo8fr0aNGmnGjBnq1KmT3njjDb3yyisaNmyYXnzxRZ04cUL+/v6SpE2bNmnt2rVq1KiRcufOraNHj2r8+PF68skntWfPHmXKlEmVK1dWx44dNWbMGP3vf/9TsWLFJMn1XylxmFLjxo31+uuvq02bNipSpEiynIZh6IsvvlDp0qX1xhtv6Ntvv5Uk9enTR7t379aKFSvS3LWoVKmSRo0apd27d6tkyZKSEt+EOxwOrV69Wh07dnTtk6TKlSvf8X5ef/11nTp1SosXL9aXX355x2Nmzpypy5cv6/XXX5dhGBo6dKgaNGigw4cPu4b2pfV1/ydffvmlWrdurccff1xt27aVJBUsWPAfb/fhhx/K4XCoa9euio6O1tChQ9WkSRNt2LDBdcyUKVPUsmVLPfbYYxo8eLDOnj2rjz76SL/++qt+++23JF3ZW7duqUaNGqpYsaKGDx+uTJkyua5bvXq1fvzxR7Vr106SNHjwYNWtW1fdu3fXJ598orfeeksXL17U0KFD1apVKy1btsx12zlz5ujatWt68803lT17dm3cuFFjx47VH3/8oTlz5qT5e/V377//viIjI13/7hcvXqzDhw+rZcuWCg0NdQ0H3r17t9avXy/DMNSgQQPt379fX331lUaNGqUcOXJIknLmzHnHxz179qzKly+va9euqWPHjsqePbumTp2qZ599VnPnztXzzz+f5tcGwL/lAcuKe3KfxgngrouOjnZKcj733HOpOn
7btm1OSc7WrVsn2d+1a1enJOeyZctc+/Lly+eU5Fy1apVrX2RkpNPHx8f5zjvvuPYdOXLEKck5bNiwJPfZv
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "2zgQd-nJjU2F"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## Fine-tuning"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "7kctU-jel8k-"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
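"Once the classification head has been trained, all layers can be unfrozen again and the whole network fine-tuned on the new data. At this point there is no longer a risk from the large gradients of the first training phase, when a freshly initialized head would otherwise push large updates back into the pretrained convolutional layers."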
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"id": "yu4ks06pR1aO",
|
|||
|
|
"outputId": "36451d32-572f-4ca4-f1ee-8478e600ab68"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
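"# unfreeze layers 0-9 (conv/activation/pooling blocks and Flatten); the dense layers are already trainable\n",
"for ix in range(0, 9+1):\n",
"    new_model.get_layer(index=ix).trainable = True\n",
"\n",
"# recompile so that the new trainable flags take effect\n",
"new_model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.legacy.Adam(), metrics=['accuracy'])\n",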
|
|||
|
|
"new_model.summary()"
|
|||
|
|
],
|
|||
|
|
"execution_count": 56,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"Model: \"sequential_3\"\n",
|
|||
|
|
"_________________________________________________________________\n",
|
|||
|
|
" Layer (type) Output Shape Param # \n",
|
|||
|
|
"=================================================================\n",
|
|||
|
|
" conv2d_2 (Conv2D) (None, 28, 28, 16) 160 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_7 (Activation) (None, 28, 28, 16) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" max_pooling2d_2 (MaxPoolin (None, 14, 14, 16) 0 \n",
|
|||
|
|
" g2D) \n",
|
|||
|
|
" \n",
|
|||
|
|
" conv2d_3 (Conv2D) (None, 14, 14, 32) 4640 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_8 (Activation) (None, 14, 14, 32) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" max_pooling2d_3 (MaxPoolin (None, 7, 7, 32) 0 \n",
|
|||
|
|
" g2D) \n",
|
|||
|
|
" \n",
|
|||
|
|
" conv2d_4 (Conv2D) (None, 7, 7, 64) 18496 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_9 (Activation) (None, 7, 7, 64) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" max_pooling2d_4 (MaxPoolin (None, 3, 3, 64) 0 \n",
|
|||
|
|
" g2D) \n",
|
|||
|
|
" \n",
|
|||
|
|
" flatten_1 (Flatten) (None, 576) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" dense_7 (Dense) (None, 64) 36928 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_12 (Activation) (None, 64) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" dense_8 (Dense) (None, 5) 325 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_13 (Activation) (None, 5) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
"=================================================================\n",
|
|||
|
|
"Total params: 60549 (236.52 KB)\n",
|
|||
|
|
"Trainable params: 60549 (236.52 KB)\n",
|
|||
|
|
"Non-trainable params: 0 (0.00 Byte)\n",
|
|||
|
|
"_________________________________________________________________\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
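{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parameter counts in the two summaries can be checked by hand. A `Conv2D` layer has `(kernel_h * kernel_w * in_channels + 1) * filters` parameters and a `Dense` layer has `(inputs + 1) * units`, where the `+ 1` is the bias. The sketch below assumes 3x3 kernels, which is consistent with the reported counts (e.g. `(3*3*1 + 1) * 16 = 160` for `conv2d_2`). The three convolutional layers add up to 23,296, which matches the 'Non-trainable params' reported while they were frozen. After unfreezing, all 60,549 parameters are trainable."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def conv2d_params(kernel_h, kernel_w, in_channels, filters):\n",
"    # one kernel_h x kernel_w x in_channels weight tensor per filter, plus one bias per filter\n",
"    return (kernel_h * kernel_w * in_channels + 1) * filters\n",
"\n",
"def dense_params(inputs, units):\n",
"    # full weight matrix plus one bias per unit\n",
"    return (inputs + 1) * units\n",
"\n",
"# 3x3 kernels are an assumption, consistent with the summary above\n",
"conv_counts = [conv2d_params(3, 3, 1, 16), conv2d_params(3, 3, 16, 32), conv2d_params(3, 3, 32, 64)]\n",
"dense_counts = [dense_params(3 * 3 * 64, 64), dense_params(64, 5)]\n",
"\n",
"print('conv layers :', conv_counts, 'sum =', sum(conv_counts))     # frozen during the first training run\n",
"print('dense layers:', dense_counts, 'sum =', sum(dense_counts))   # trainable from the start\n",
"print('total       :', sum(conv_counts) + sum(dense_counts))"
],
"execution_count": null,
"outputs": []
},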
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"id": "g9x5tXRzS9yE",
|
|||
|
|
"outputId": "4b9b67da-1fad-4b25-e63d-f489278bb706"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# as above: Model.fit with steps inferred from len(train_generator)\n",
"new_model.fit(train_generator, epochs=2, verbose=1, validation_data=valid_generator)"
|
|||
|
|
],
|
|||
|
|
"execution_count": 57,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"Epoch 1/2\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stderr",
|
|||
|
|
"text": [
|
|||
|
|
"<ipython-input-57-5ed88bbb5f18>:1: UserWarning: `Model.fit_generator` is deprecated and will be removed in a future version. Please use `Model.fit`, which supports generators.\n",
|
|||
|
|
" new_model.fit_generator(train_generator, steps_per_epoch=24000//128, epochs=2, verbose=1, validation_data=valid_generator, validation_steps = 6000 // 128)\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"184/187 [============================>.] - ETA: 0s - loss: 0.1569 - accuracy: 0.9506"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stderr",
|
|||
|
|
"text": [
|
|||
|
|
"WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 374 batches). You may need to use the repeat() function when building your dataset.\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r187/187 [==============================] - 9s 44ms/step - loss: 0.1569 - accuracy: 0.9506 - val_loss: 0.0890 - val_accuracy: 0.9716\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "execute_result",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<keras.src.callbacks.History at 0x7f0d7834a3e0>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"execution_count": 57
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 1000
|
|||
|
|
},
|
|||
|
|
"id": "pbfKwpJ9TI-e",
|
|||
|
|
"outputId": "cd1cef33-22c7-42f1-e65e-56abdf5c6610"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"test_model(new_model, X_test, Y_test, y_test, new_digits)"
|
|||
|
|
],
|
|||
|
|
"execution_count": 58,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"152/152 [==============================] - 0s 3ms/step - loss: 0.0513 - accuracy: 0.9842\n",
|
|||
|
|
"Test score: 0.051340196281671524\n",
|
|||
|
|
"Test accuracy: 0.9841596484184265\n",
|
|||
|
|
"152/152 [==============================] - 0s 2ms/step\n",
|
|||
|
|
"Confusion matrix, without normalization\n",
|
|||
|
|
"[[ 880 1 2 6 3]\n",
|
|||
|
|
" [ 5 946 0 6 1]\n",
|
|||
|
|
" [ 1 0 1012 5 10]\n",
|
|||
|
|
" [ 9 0 4 957 4]\n",
|
|||
|
|
" [ 7 2 6 5 989]]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 2 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1UAAAN6CAYAAABmBWMlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAACAIklEQVR4nOzde3zO9f/H8ednYwdjm9M2cxxyJjmkOSaLRCVSROZcOYUQ33IshA4OyTGnpKJCJHIMOStyyvmUjBibYRvb9ftDrp8rqs1n8/lc87jv9rnV9bk+13U9P7su2/Xa6/1+X4bD4XAIAAAAAHBXPKwOAAAAAADujKIKAAAAAEygqAIAAAAAEyiqAAAAAMAEiioAAAAAMIGiCgAAAABMoKgCAAAAABMoqgAAAADAhExWBwAAAABwd+Lj45WYmGh1jBTx8vKSj4+P1THSBUUVAAAA4Ibi4+Plmy2ndP2K1VFSJCQkREePHs2QhRVFFQAAAOCGEhMTpetX5F0qUvL0sjrOv0tKVNTemUpMTKSoAgAAAGAznl4ybF5UOawOkM4oqgAAAAB3Znjc2OzM7vlMythnBwAAAADpjKIKAAAAAEygqAIAAABgO2vXrtVTTz2l0NBQGYahBQsWuFzvcDg0YMAA5cmTR76+voqIiNDBgwddjomOjlaLFi3k7++vwMBAtWvXTnFxcS7H/Prrr6pRo4Z8fHyUP39+jRw5MtVZKaoAAAAAd2ZIMgybb6k/rcuXL+vBBx/U+PHj73j9yJEjNXbsWE2cOFGbN2+Wn5+f6tWrp/j4eOcxLVq00J49e7R8+XItXrxYa9euVceOHZ3Xx8bGqm7duipYsKC2b9+uUaNGadCgQZo8eXLqngKHw5HRF+MAAAAAMpzY2FgFBATI+8GXZXh6Wx3nXzmSEpSwc5JiYmLk7++f6tsbhqH58+erUaNGN+7P4VBoaKhef/119erVS5IUExOj4OBgzZgxQ82aNdO+fftUqlQpbd26VZUqVZIkLV26VE8++aR+//13hYaGasKECXrzzTcVFRUlL68bKyj27dtXCxYs0G+//ZbifHSqAAAAANwTsbGxLltCQsJd3c/Ro0cVFRWliIgI576AgABVqVJFGzdulCRt3LhRgYGBzoJKkiIiIuTh4aHNmzc7j6lZs6azoJKkevXqaf/+/bpw4UKK81BUAQAAAO7s5pLqdt8k5c+fXwEBAc5t+PDhd3XKUVFRkqTg4GCX/cHBwc7roqKiFBQU5HJ9pkyZlCNHDpdj7nQftz5GSvA5VQAAAADuiZMnT7oM//P2tvewxZSiUwUAAADgnvD393fZ7raoCgkJkSSdOXPGZf+ZM2ec14WEhOjs2bMu11+/fl3R0dEux9zpPm59jJSgqAIAAADcmeUr+6VwS0NhYWEKCQnRypUrnftiY2O1efNmhYeHS5LCw8N18eJFbd++3XnMqlWrlJycrCpVqjiPWbt2ra5du+Y8Zvny5SpevLiyZ8+e4jwUVQAAAABsJy4uTjt27NCOHTsk3VicYseOHTpx4oQMw1D37t31zjvv6Ntvv9WuXbvUqlUrhYaGOlcILFmypJ544gl16NBBW7Zs0U8//aQuXbqoWbNmCg0NlSS9+OKL8vLyUrt27bRnzx59+eWXGjNmjHr27JmqrMypAgAAAGA727ZtU+3atZ2XbxY6kZGRmjFjhvr06aPLly+rY8eOunjxoqpXr66lS5fKx8fHeZvPPvtMXbp0UZ06deTh4aEmTZpo7NixzusDAgL0ww8/qHPnzqpYsaJy5cqlAQMGuHyWVUrwOVUAAACAG3J+TtVDndzjc6p++fiuP6fK7uhUAQAAAO7sliXLbcvu+UzK2GcHAAAAAOmMogoAAAAATGD4HwAAAODO0mHJ8jRn93wm0akCAAAAABMoqgAAAADABIoqAAAAADCBOVUAAACAW3ODJdUzeC8nY58dAAAAAKQziioAAAAAMIHhfwAAAIA7Y0
l1y9GpAgAAAAATKKoAAAAAwASG/wEAAADuzHCD1f/sns+kjH12AAAAAJDOKKoAAAAAwASKKgAAAAAwgTlVAAAAgDtjSXXL0akCAAAAABMoqgAAAADABIb/AQAAAO6MJdUtl7HPDgAAAADSGUUVAAAAAJjA8D8AAADAnbH6n+XoVAEAAACACRRVAAAAAGACRRUAAAAAmMCcKgAAAMCdsaS65TL22QEAAABAOqOoAgAAAAATGP4HAAAAuDPDsP/wOpZUBwAAAAD8E4oqAAAAADCB4X8AAACAO/Mwbmx2Zvd8JtGpAgAAAAATKKoAAAAAwASKKgAAAAAwgTlVAAAAgDszPNxgSXWb5zMpY58dAAAAAKQziioAAAAAMIHhfwAAAIA7M4wbm53ZPZ9JdKoAAAAAwASKKgAAAAAwgaIKAAAAAExgThUAAADgzlhS3XIZ++wAAAAAIJ1RVAEAAACACQz/AwAAANwZS6pbjk4VAAAAAJhAUQUAAAAAJjD8DwAAAHBnrP5nuYx9dgAAAACQziiqAAAAAMAEiioAAAAAMIE5VQAAAIA7Y0l1y9GpAgAAAAATKKoAAAAAwASG/wEAAADujCXVLZexzw4AAAAA0hlFFZCBHTx4UHXr1lVAQIAMw9CCBQvS9P6PHTsmwzA0Y8aMNL3fjKBQoUJq3bq11TFuk5rn7Oax7733XvoHwx0NGjRIxt8md1v12rLraxoA7ICiCkhnhw8f1ssvv6zChQvLx8dH/v7+qlatmsaMGaOrV6+m62NHRkZq165dGjp0qD799FNVqlQpXR8vI9q7d68GDRqkY8eOWR0l3SxZskSDBg2yOsZthg0bluZ/CMC/27BhgwYNGqSLFy9aHQVAatxc/c/uWwbGnCogHX333Xdq2rSpvL291apVK5UpU0aJiYlav369evfurT179mjy5Mnp8thXr17Vxo0b9eabb6pLly7p8hgFCxbU1atXlTlz5nS5fzvYu3evBg8erEcffVSFChVK8e32798vDw/7/d3qTs/ZkiVLNH78eNsVVsOGDdNzzz2nRo0aWR3FVtLztbVhwwYNHjxYrVu3VmBg4D17XABwdxRVQDo5evSomjVrpoIFC2rVqlXKkyeP87rOnTvr0KFD+u6779Lt8f/8809Juu2NUVoyDEM+Pj7pdv/uxuFwKD4+Xr6+vvL29rY6zh3xnJlz+fJl+fn5WZrBqteWXV/TAGAH/MkJSCcjR45UXFycPvnkE5eC6qaiRYvqtddec16+fv263n77bRUpUkTe3t4qVKiQ/ve//ykhIcHldoUKFVLDhg21fv16Pfzww/Lx8VHhwoU1a9Ys5zGDBg1SwYIFJUm9e/eWYRjOLkvr1q3v2HG509yN5cuXq3r16goMDFTWrFlVvHhx/e9//3Ne/0/zc1atWqUaNWrIz89PgYGBeuaZZ7Rv3747Pt6hQ4ecfxUPCAhQmzZtdOXKlX/+xv7l0UcfVZkyZfTrr7+qVq1aypIli4oWLaqvvvpKkvTjjz+qSpUq8vX1VfHixbVixQqX2x8/flydOnVS8eLF5evrq5w5c6pp06Yuw/xmzJihpk2bSpJq164twzBkGIbWrFkj6f+fi2XLlqlSpUry9fXVpEmTnNfdnH/icDhUu3Zt5c6dW2fPnnXef2JiosqWLasiRYro8uXL/3nOt+rZs6dy5swph8Ph3Ne1a1cZhqGxY8c69505c0aGYWjChAmSbn/OWrdurfHjx0uS8/z+/jqQpMmTJztfm5UrV9bWrVtvOyYlz3tKX3+GYejy5cuaOXOmM9O/zedZs2aNDMPQ3LlzNXToUOXLl08+Pj6qU6eODh06dNvx8+bNU8WKFeXr66tcuXKpZcuWOnXq1G1Zs2bNqsOHD+vJJ59UtmzZ1KJFC2e+Ll26aN68eSpVqpR8fX0VHh6uXbt2SZImTZqkokWLysfHR48++uhtw0fXrVunpk2bqkCBAvL29lb+/PnVo0ePFA0J/vvcpluft7
9vNx/3119/VevWrZ3DkENCQtS2bVudP3/e5Tno3bu3JCksLOy2+7jTnKojR46oadOmypEjh7JkyaJHHnnkt
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
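{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two confusion matrices make the effect of fine-tuning concrete. Overall accuracy is the sum of the diagonal divided by the total count, and per-class recall is each diagonal entry divided by its row sum. The sketch below recomputes both from the matrices printed above, copied by hand. It assumes rows are true labels and columns are predictions, as in scikit-learn's `confusion_matrix`. It reproduces the reported test accuracies of about 0.9636 and 0.9842. The largest single improvement is in the third class, whose misclassifications into the fifth class drop from 51 to 10."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Confusion matrices copied by hand from the two test_model outputs above\n",
"# (assumed layout: rows = true class, columns = predicted class)\n",
"before = [[865, 3, 3, 14, 7],\n",
"          [3, 946, 0, 8, 1],\n",
"          [0, 1, 968, 8, 51],\n",
"          [8, 2, 9, 930, 25],\n",
"          [10, 5, 11, 8, 975]]\n",
"after = [[880, 1, 2, 6, 3],\n",
"         [5, 946, 0, 6, 1],\n",
"         [1, 0, 1012, 5, 10],\n",
"         [9, 0, 4, 957, 4],\n",
"         [7, 2, 6, 5, 989]]\n",
"\n",
"def summarize(cm):\n",
"    total = sum(sum(row) for row in cm)\n",
"    correct = sum(cm[i][i] for i in range(len(cm)))  # trace = correctly classified samples\n",
"    recall = [round(cm[i][i] / sum(cm[i]), 4) for i in range(len(cm))]  # diagonal / row sum\n",
"    return correct / total, recall\n",
"\n",
"for name, cm in (('head only', before), ('fine-tuned', after)):\n",
"    accuracy, recall = summarize(cm)\n",
"    print(f'{name:10s} accuracy={accuracy:.4f} per-class recall={recall}')"
],
"execution_count": null,
"outputs": []
},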
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "e6ssF6eHeHro"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# Test on a different dataset"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "s00LerSteLv5",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"outputId": "9906bb92-ad9d-4521-9433-6a5a46257042"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"!wget https://raw.githubusercontent.com/googlecreativelab/quickdraw-dataset/master/categories.txt"
|
|||
|
|
],
|
|||
|
|
"execution_count": 59,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"--2024-04-19 14:32:03-- https://raw.githubusercontent.com/googlecreativelab/quickdraw-dataset/master/categories.txt\n",
|
|||
|
|
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
|
|||
|
|
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
|
|||
|
|
"HTTP request sent, awaiting response... 200 OK\n",
|
|||
|
|
"Length: 2791 (2.7K) [text/plain]\n",
|
|||
|
|
"Saving to: ‘categories.txt’\n",
|
|||
|
|
"\n",
|
|||
|
|
"categories.txt 100%[===================>] 2.73K --.-KB/s in 0s \n",
|
|||
|
|
"\n",
|
|||
|
|
"2024-04-19 14:32:04 (41.9 MB/s) - ‘categories.txt’ saved [2791/2791]\n",
|
|||
|
|
"\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "H4d-lVtqgzgC"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"import os\n",
"import urllib.request\n",
"\n",
"import numpy as np\n",
"\n",
"def download_and_load(class_names, test_split = 0.2, max_items_per_class = 10000):\n",
"    root = 'data'\n",
"    if not os.path.exists(root):\n",
"        os.makedirs(root)\n",
"\n",
"    print('downloading ...')\n",
"    base = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'\n",
"    for c in class_names:\n",
"        cc = c.replace('_', '%20')  # file names in the bucket use URL-encoded spaces\n",
"        path = base + cc + '.npy'\n",
"        print(path)\n",
"        if not os.path.exists(f'{root}/{cc}.npy'):\n",
"            urllib.request.urlretrieve(path, f'{root}/{cc}.npy')\n",
"        else:\n",
"            print(\"Already downloaded\")\n",
"    print('loading ...')\n",
"\n",
"    # initialize empty arrays (each bitmap is a flattened 28x28 = 784-pixel vector)\n",
"    x = np.empty([0, 784])\n",
"    y = np.empty([0])\n",
"\n",
"    # load each class file and label its samples with the class index\n",
"    for idx, file in enumerate(class_names):\n",
"        file = file.replace('_', '%20')\n",
"        data = np.load(f'{root}/{file}.npy')\n",
"        data = data[0: max_items_per_class, :]\n",
"        labels = np.full(data.shape[0], idx)\n",
"\n",
"        x = np.concatenate((x, data), axis=0)\n",
"        y = np.append(y, labels)\n",
"\n",
"        data = None\n",
"        labels = None\n",
"\n",
"    # shuffle the dataset (one permutation for both arrays keeps images and labels aligned)\n",
"    permutation = np.random.permutation(y.shape[0])\n",
"    x = x[permutation, :]\n",
"    y = y[permutation]\n",
"\n",
"    # reshape to 28x28 images and invert the colors\n",
"    x = 255 - np.reshape(x, (x.shape[0], 28, 28))\n",
"\n",
"    # split into testing and training sets\n",
"    test_size = int(x.shape[0] * test_split)\n",
"\n",
"    x_test = x[0:test_size, :]\n",
"    y_test = y[0:test_size]\n",
"\n",
"    x_train = x[test_size:x.shape[0], :]\n",
"    y_train = y[test_size:y.shape[0]]\n",
"\n",
"    print('Training Data : ', x_train.shape[0])\n",
"    print('Testing Data : ', x_test.shape[0])\n",
"    return x_train, y_train, x_test, y_test"
|
|||
|
|
],
|
|||
|
|
"execution_count": 60,
|
|||
|
|
"outputs": []
|
|||
|
|
},
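{
"cell_type": "markdown",
"metadata": {},
"source": [
"`download_and_load` shuffles with `np.random.permutation`, which draws from NumPy's global random state. `RANDOM_SEED`, defined in the next cell, is only passed to Python's `random.seed`. It therefore fixes which classes are sampled but not this shuffle, unless `np.random.seed` is called elsewhere. If a reproducible split is wanted, one option is an explicitly seeded `np.random.default_rng` (NumPy 1.17 or later). `shuffle_and_split` below is a hypothetical helper sketching that approach on toy data; it is not part of the original notebook."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"def shuffle_and_split(x, y, test_split=0.2, seed=None):\n",
"    # hypothetical helper: the same shuffle/split steps as download_and_load,\n",
"    # but driven by an explicitly seeded NumPy Generator\n",
"    rng = np.random.default_rng(seed)\n",
"    permutation = rng.permutation(len(y))  # one permutation for both arrays keeps pairs aligned\n",
"    x, y = x[permutation], y[permutation]\n",
"    test_size = int(len(y) * test_split)\n",
"    return x[test_size:], y[test_size:], x[:test_size], y[:test_size]\n",
"\n",
"# toy data: every pixel of sample i equals i, so alignment with the labels is easy to check\n",
"toy_x = np.repeat(np.arange(10)[:, None], 4, axis=1)\n",
"toy_y = np.arange(10)\n",
"x_tr, y_tr, x_te, y_te = shuffle_and_split(toy_x, toy_y, test_split=0.2, seed=1234)\n",
"print(len(y_tr), len(y_te))              # 8 2\n",
"print(bool((x_tr[:, 0] == y_tr).all()))  # True: images and labels still match"
],
"execution_count": null,
"outputs": []
},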
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "kAiJYOJBgXup"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# random seed for reproducible class selection; any integer works\n",
|
|||
|
|
"RANDOM_SEED = 1234"
|
|||
|
|
],
|
|||
|
|
"execution_count": 61,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "CsfWTt6NeUEJ",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 1000
|
|||
|
|
},
|
|||
|
|
"outputId": "8f5a3e43-397a-4069-8076-cb4ac25a702f"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"import random\n",
|
|||
|
|
"\n",
|
|||
|
|
"with open('categories.txt') as f:\n",
|
|||
|
|
" all_class_names = f.readlines()\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"random.seed(RANDOM_SEED)\n",
|
|||
|
|
"\n",
|
|||
|
|
"all_class_names = [x.strip().replace(' ', '_') for x in all_class_names]\n",
|
|||
|
|
"\n",
|
|||
|
|
"nb_classes = 10\n",
|
|||
|
|
"\n",
|
|||
|
|
"# select 10 random classes\n",
|
|||
|
|
"class_names = random.sample(all_class_names, nb_classes)\n",
|
|||
|
|
"\n",
|
|||
|
|
"# you can also try these \"hard\" classes instead\n",
|
|||
|
|
"#class_names = ['wheel', 'pizza', 'smiley_face', 'apple', 'potato', 'basketball', 'soccer_ball', 'brain', 'clock', 'circle']\n",
|
|||
|
|
"\n",
|
|||
|
|
"print(class_names)\n",
|
|||
|
|
"\n",
|
|||
|
|
"X_train, y_train, X_test, y_test = download_and_load(class_names)\n",
|
|||
|
|
"\n",
|
|||
|
|
"X_train = X_train.reshape(X_train.shape[0], 28, 28, 1) # add a channel dimension (single grayscale channel)\n",
|
|||
|
|
"X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)\n",
|
|||
|
|
"\n",
|
|||
|
|
"X_train = X_train.astype('float32') # change integers to 32-bit floating point numbers\n",
|
|||
|
|
"X_test = X_test.astype('float32')\n",
|
|||
|
|
"\n",
|
|||
|
|
"X_train /= 255 # scale pixel values from [0, 255] to [0, 1]\n",
|
|||
|
|
"X_test /= 255\n",
|
|||
|
|
"\n",
|
|||
|
|
"Y_train = to_categorical(y_train, nb_classes)\n",
|
|||
|
|
"Y_test = to_categorical(y_test, nb_classes)\n",
|
|||
|
|
"\n",
|
|||
|
|
"visualize_classes(X_train, y_train)"
|
|||
|
|
],
|
|||
|
|
"execution_count": 62,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"['pizza', 'cannon', 'ambulance', 'bulldozer', 'swing_set', 'baseball_bat', 'zebra', 'bridge', 'cactus', 'matches']\n",
|
|||
|
|
"downloading ...\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/pizza.npy\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/cannon.npy\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/ambulance.npy\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/bulldozer.npy\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/swing%20set.npy\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/baseball%20bat.npy\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/zebra.npy\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/bridge.npy\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/cactus.npy\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/matches.npy\n",
|
|||
|
|
"loading ...\n",
|
|||
|
|
"Training Data : 80000\n",
|
|||
|
|
"Testing Data : 20000\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 1000x2000 with 1 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxoAAAMaCAYAAAABQDBSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9Z3Nd2XklvG7OOeAiZzCAzahudlBbarVkWRp7pmpmyjV/YX7UfPCHcU1yTb225JFkWWpb3VIHRoAJOePmnPP7gV5P73sbIEESJAHyrCpUs0ng4txz99n7CetZS9ftdrvQoEGDBg0aNGjQoEGDhmOE/nVfgAYNGjRo0KBBgwYNGt48aImGBg0aNGjQoEGDBg0ajh1aoqFBgwYNGjRo0KBBg4Zjh5ZoaNCgQYMGDRo0aNCg4dihJRoaNGjQoEGDBg0aNGg4dmiJhgYNGjRo0KBBgwYNGo4dWqKhQYMGDRo0aNCgQYOGY4eWaGjQoEGDBg0aNGjQoOHYoSUaGjRo0KBBgwYNGjRoOHZoiYYGDRo0aNCgQYMGDRqOHcbXfQEaNGjQoOHVodvtAgB0Ot1rvpLXh06ng3K5jFqthlarhUqlgmazCaPRCJPJBIPBAJPJBJPJBL1eD7PZDJPJBJ1OB4PB8FbfOw0angXNZhOFQgHNZhN6vV6eH5vNBpvN9rov78Sj3W6jWCyiWq3KXqTX62EymWC1WqHXn/x+gZZoaNCgQcNbgm63K186ne5UHFIvA61WC6lUCqlUCpVKBbFYDKVSCTabDW63GyaTCQ6HA16vFyaTCW63G06nEzqdDlarFUajdnRq0HAUVKtV7O7uolgswmQywWKxwGg0IhQKwWKxvLV70FHRbDYRjUaRSCRkL7JYLHC5XFIIOenQdksNGjRoUKAG4+12G+12W4JybupqRftlVLfZdTgKdDqdXMNh19Jut+X9NJtNAIDJZILZbH6rqvPtdhutVguNRgOlUgm5XA6VSgWZTAaFQgF2ux2tVgtms1nuk3qPDAYDut2uHPCszqqfgQYNGr5Fq9VCsVhELpeD2WyWRN3hcKDZbMJgMGhdwgPAM6DVaqFcLiOXy8FiscBgMKDdbsNsNj/TOfE6oSUaGjRo0PBv6Ha7KBaLKBQKqFQqePjwIba2tmC32zE8PAyXywWz2Qy73S4HpNFoPPZDkgHxUQ4Si8UCh8MBo9Eof1avp1KpYGdnB9lsFvl8Htvb26jX67hw4QLee+892O32Y732k4pOp4P19XUsLS2hXC5jaWkJOzs7aDQaKBQKqNfrMBqNsNlsMBgMsFgssNls0Ov1sNlssFgsMJlM8Pl8sNvtCAQCmJ2dhcfjgcPhgMfjgcFgeN1vU4OGE4VcLocbN25ga2sLnU4H7XYbBoMB8/PzOH/+PJxOJ0ZGRhAMBl/3pZ4YdLtdlMtllMtlZLNZ/OlPf8Ldu3fhcrkwMTEBt9uNyclJ+P1+mM3m1325T4WWaGjQoEHDv6Hb7aJQKGB/fx+pVAq/+MUv8PnnnyMYDOJ73/sehoeH4XA4ZIO3WCywWq1P7Sg8KxqNBur1OjqdzhOTDZ1OB5fLhXA4DIvFAq/XC7vd3nMd1WoVy8vL2NzcxM7ODr766isUCgX8p//0n/DOO++8NYlGu93G6uoqfvvb3yKdTmNhYQFra2tyj0kn6/8s1b+zWq2IRCLwer2YnZ3Fz372MwwPD2NgYABOp1NLNDRo6EM2m8VXX32FhYUFVCoVFAoFGAwGfP/730elUoHf74fNZtMSDQVMNJLJJPb39/HFF1/gd7/7HQKBAC5evIhQKAQAeOedd17zlR4NWqKhQYMGDQo6nQ4ajQaazab8t1KpIJvNwmw2o1wuo9lswmw2CxXgZSYahJpwqL+nXC6j1WrBYrGgWq1KZZ4dl3Q6LfMIpAiVSiWUSiUUCgWhAen1ehl25p/539MMVlHr9TpKpRLS6TSy2SwajYZ0pMgb73a7cs87nc
6Bf65WqwAeV2pLpRLK5TLq9fqpoTG8KnQ6HenKqQmzyWSSGZfTvrY0PB2tVgvVahWlUgm1Wg21Wg16vR65XA6JRALtdhvJZBLBYBAmkwl2u/2tn4FiZz0ejyORSCCfz6NcLsNms6Fer6PZbB65430S8HZ/mho0aNDwb2BVu9Vqodlsot1uw+VyIRKJoN1u4/79+1hcXOxR+2AwDxxv0ETq1FFgNpuF7uNyueDxeGAymeD3++HxeJDP57GwsID9/X3UajUJlPf29vCHP/xBqD8OhwMmk0m6IhaLRQajTzPK5TJSqZTQpRYWFlCv1xEMBjE5OQmPx4OzZ88iEAhIUNRut1GpVFCtVoUjXalUUC6Xsbq6ilgsBrvdjpWVFZTLZRiNRoyNjcFisbzut/vawaSiXC4jHo+jXq/LutPpdBgZGcHAwIAkeadhmFXDs4P7abPZRKlUQj6fh06nkwHwnZ0dFItFuFwuxGIx3Lt3DwMDA3j33XcxNDT0Vs89NZtN3L59G7/4xS+QzWaxvLyMRqMhZ8JpuzdaoqFBgwYN/wZWX5lo2Gw2+Hw+5HI5rK6uIp1Ov+5LPBQ6nQ52ux1OpxMWiwWjo6OIRCIol8tYXl5GKpUSNSWTyYRUKoXFxUU4HA4EAgH4fD5YrVYMDQ3B6/XC4XDAbref+kSjVqshlUohn89jZ2cH6+vr0Ol0mJycxDvvvIPBwUH84Ac/wMTEBGq1GgqFAhqNBorFIjKZDFqtFtLpNPL5PGKxGDY3N5HJZOB0OhGNRtHtdjE0NHRqqosvGwwwK5UK4vG4dH3y+TwMBgOsViv8fj+63a5GNXuDoRZuqtUqyuVyzwxZPB7H+vo6bDYbqtUq9vb2MD09jZmZGQwODgJ4ezterVYLKysr+Kd/+ifpBLVaLRH1OG049YlGu91Go9Hoac9qEoQaNGh4FjSbTdRqNTQaDUSjUaytraFSqcBoNGJ0dBQ+nw8GgwH5fB61Wg3FYhHNZhPVahWVSgV6vV46CE+rNnW7XeRyOeTzeQCQqu5hHQRV7YqHTK1WQzKZlCoxv4eJBhVJSqUSKpWKVMN0Op1Uxdiat1qtMnRosVhQKBTgdDrh9/sBAB6PR7omp/Hgr1QqiEajyGQyKBaL4pURCAQwPDyMUCgkyVe324XNZuv5DJh0ApDuBYNkzum8zWdNt9tFo9EQul8ymRSVnK2tLVQqFVQqFZRKJekA1mo1OBwOSWr1ev1LEVXQ8HrABEPtDne7XdjtdgwODsJkMkkCajAY0Ol0kE6n4fF4hOpptVrfyrknqgPy/qkUKZ4TNpvtVCkGnvrdkW3xZrMpvDWj0YihoSE5KDVo0KDhSSgWi9jb20OxWMTvfvc7/L//9/9gMpnwk5/8BB999FHP925vb+PWrVvIZDLY2trCysoKzGYzfvCDH+D999+X+YjDKCH1eh1ffPEFvvnmG3S7XbhcLlitVgwODuLdd9+VQT/gcUWPgTHwbZVwe3sbv/zlL7G2tgaj0Qir1QqDwYBgMIhwOAzg8RDm7u6uVOlZFWOQvL6+jmg02kMBU83ppqam8Jd/+ZcYGxtDOBzGxMTEqVA4IXiv9vb28M///M+IRqPY29uD2+2G1+vF1atX8ZOf/AQ2mw2BQEAOb5vNhna7LZ2tVquFZDKJTCYjspzdbhdGoxFerxfBYFA8Nt5GtNttJBIJpNNpxGIx/Pa3v8Xy8jJqtRqy2awEmq1WSxJyn8+HgYEB/If/8B9w9erVN4amp+ExaIjJ2Yxms4lOp4Ph4WF88skn8Hg8KJfLUgi5f/8+7t69i0KhgJmZGZRKJQwNDWFubg5Op/N1v51XBlV+m4l7o9GQ+TBSW0OhENxu96lJwk59okH+X71elyEZi8WiKRho0KDhyGg0Gsjlcsjlctjc3MTCwgLcbjd+/vOfY2ZmRgySzGYzHj16hHw+D6fT2WNCNTY2hqtXr/Y4Sh8Eys0+ePAAAOD1euF0OjE8PIyLFy9iZGREvl
en00nVHPh2KNnj8eDLL7/E7u6uyO0ajUZ4PB4EAgF0Oh1ks1nZGxuNhlTlGfhVq1Wk02mpoKmD58DjIs6lS
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "W5Jm2OPyk4FF"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## Create the model"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "1k0Qnajnk3ey"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"model = Sequential() # Linear stacking of layers\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Convolution Layer 1\n",
|
|||
|
|
"model.add(Conv2D(64, (5, 5), input_shape=(28,28,1))) # 64 different 5x5 kernels -- so 64 feature maps\n",
|
|||
|
|
"model.add(BatchNormalization()) # normalize the conv outputs per channel (batch normalization)\n",
|
|||
|
|
"model.add(Activation('relu')) # activation\n",
|
|||
|
|
"model.add(MaxPooling2D(pool_size=(2,2))) # Pool the max values over a 2x2 kernel\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Convolution Layer 2\n",
|
|||
|
|
"model.add(Conv2D(32, (5, 5))) # 32 different 5x5 kernels -- so 32 feature maps\n",
|
|||
|
|
"model.add(BatchNormalization()) # normalize the conv outputs per channel (batch normalization)\n",
|
|||
|
|
"model.add(Activation('relu')) # activation\n",
|
|||
|
|
"model.add(MaxPooling2D(pool_size=(2,2))) # Pool the max values over a 2x2 kernel\n",
|
|||
|
|
"\n",
|
|||
|
|
"model.add(Flatten()) # Flatten final output matrix into a vector\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Fully Connected Layer\n",
|
|||
|
|
"model.add(Dense(128)) # 128 FC nodes\n",
|
|||
|
|
"model.add(Activation('relu')) # activation\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Fully Connected Layer\n",
|
|||
|
|
"model.add(Dense(10)) # final 10 FC nodes\n",
|
|||
|
|
"model.add(Activation('softmax')) # softmax activation\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"# we'll use the same optimizer\n",
|
|||
|
|
"adam = tf.optimizers.Adam(learning_rate=0.001)\n",
|
|||
|
|
"model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])"
|
|||
|
|
],
|
|||
|
|
"execution_count": 63,
|
|||
|
|
"outputs": []
|
|||
|
|
},
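{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch (not part of the original tutorial) that re-derives by hand the shapes and parameter counts `model.summary()` should report for the network above. It assumes the Keras defaults `padding='valid'` and stride 1 for `Conv2D`, non-overlapping 2x2 pooling, and 4 parameters per channel for `BatchNormalization` (gamma, beta and the two moving statistics). The helpers `conv_out` and `pool_out` are hypothetical names used only for illustration."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Hypothetical helpers (illustration only): output size of a 'valid' convolution\n",
"# with stride 1, and of non-overlapping max pooling\n",
"def conv_out(size, kernel):\n",
"    return size - kernel + 1\n",
"\n",
"def pool_out(size, pool):\n",
"    return size // pool\n",
"\n",
"side = 28\n",
"side = pool_out(conv_out(side, 5), 2)  # 28 -> 24 (conv) -> 12 (pool)\n",
"side = pool_out(conv_out(side, 5), 2)  # 12 -> 8 (conv) -> 4 (pool)\n",
"flat = side * side * 32                # Flatten: 4 * 4 * 32 feature values\n",
"print('Flatten output length:', flat)\n",
"\n",
"# Parameter counts: conv = k*k*in_channels*filters + filters (biases)\n",
"params = {\n",
"    'conv1': 5 * 5 * 1 * 64 + 64,\n",
"    'bn1': 4 * 64,\n",
"    'conv2': 5 * 5 * 64 * 32 + 32,\n",
"    'bn2': 4 * 32,\n",
"    'dense1': flat * 128 + 128,\n",
"    'dense2': 128 * 10 + 10,\n",
"}\n",
"print(params)\n",
"print('Total parameters:', sum(params.values()))"
],
"execution_count": null,
"outputs": []
},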
{
"cell_type": "code",
"metadata": {
"id": "HmR_tm0gmR0u"
},
"source": [
"# Data augmentation helps reduce overfitting by randomly perturbing the training images\n",
"# (small rotations, shifts, shears and zooms); Keras can do this on the fly.\n",
"\n",
"gen = ImageDataGenerator(rotation_range=8, width_shift_range=0.08, shear_range=0.3,\n",
"                         height_shift_range=0.08, zoom_range=0.08, validation_split=0.2)\n",
"\n",
"# Same validation_split but no augmentation, so validation images are left unchanged.\n",
"# The split in flow() is deterministic, so both generators hold out the same images.\n",
"val_gen = ImageDataGenerator(validation_split=0.2)\n",
"\n",
"test_gen = ImageDataGenerator()\n",
"\n",
"train_generator = gen.flow(X_train, Y_train, batch_size=128, subset='training')\n",
"valid_generator = val_gen.flow(X_train, Y_train, batch_size=128, subset='validation')\n",
"test_generator = test_gen.flow(X_test, Y_test, batch_size=128)"
],
"execution_count": 64,
"outputs": []
},
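{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sketch for inspecting what the augmentation does, assuming `ImageDataGenerator` is importable from `tensorflow.keras.preprocessing.image` (it still works but is deprecated in recent TensorFlow releases). `random_transform` applies one random draw of the configured transforms to a single image; the bright square below is a made-up test image, and plotting `img` next to `out` (or doing the same for a real digit) is a quick way to check that the transforms are not too aggressive."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
"\n",
"# Synthetic 28x28 single-channel image: a bright square on a dark background\n",
"img = np.zeros((28, 28, 1), dtype='float32')\n",
"img[10:18, 10:18, 0] = 1.0\n",
"\n",
"aug = ImageDataGenerator(rotation_range=8, width_shift_range=0.08, shear_range=0.3,\n",
"                         height_shift_range=0.08, zoom_range=0.08)\n",
"\n",
"# random_transform expects one 3D image (rows, cols, channels) and returns\n",
"# a randomly rotated/shifted/sheared/zoomed copy with the same shape\n",
"out = aug.random_transform(img, seed=0)\n",
"print(out.shape, 'changed pixels:', int(np.sum(~np.isclose(out, img))))"
],
"execution_count": null,
"outputs": []
},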
{
"cell_type": "code",
"metadata": {
"id": "p-jeOuI8nfl9",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "e9544c03-ba7e-4de9-d41b-6b63d26f9700"
},
"source": [
"# One epoch = one pass over each subset; len() gives the number of batches per pass\n",
"model.fit(train_generator, steps_per_epoch=len(train_generator), epochs=5, verbose=1,\n",
"          validation_data=valid_generator, validation_steps=len(valid_generator))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/5\n",
"500/500 [==============================] - 27s 49ms/step - loss: 0.7584 - accuracy: 0.7635 - val_loss: 1.1462 - val_accuracy: 0.6129\n"
]
}
]
},
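{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick arithmetic sketch of how many batches one pass over each subset takes, assuming the standard 60,000-image MNIST training set, `validation_split=0.2` and `batch_size=128`. For comparison, `64000//128 = 500` steps per epoch covers 64,000 samples, which is about 1.33 passes over the 48,000 training images."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import math\n",
"\n",
"n_total = 60000          # assumption: full MNIST training set\n",
"validation_split = 0.2\n",
"batch_size = 128\n",
"\n",
"n_val = int(n_total * validation_split)   # images held out for validation\n",
"n_train = n_total - n_val                 # images used for training\n",
"train_steps = math.ceil(n_train / batch_size)\n",
"val_steps = math.ceil(n_val / batch_size)\n",
"\n",
"print('train:', n_train, 'images,', train_steps, 'batches')\n",
"print('validation:', n_val, 'images,', val_steps, 'batches')\n",
"print('passes per epoch with 500 steps:', round(500 * batch_size / n_train, 2))"
],
"execution_count": null,
"outputs": []
},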
{
"cell_type": "code",
"metadata": {
"id": "OlzMJklYooX8"
},
"source": [
"score = model.evaluate(X_test, Y_test) # returns [loss, accuracy]\n",
"print('Test loss:', score[0])\n",
"print('Test accuracy:', score[1])"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "sAhuIjARoqhp"
},
"source": [
"# model.predict returns a probability for each class; np.argmax picks the\n",
"# highest-probability class for each input example.\n",
"predicted = model.predict(X_test)\n",
"predicted_classes = np.argmax(predicted, axis=1)\n",
"\n",
"# Check which items we got right / wrong\n",
"correct_indices = np.nonzero(predicted_classes == y_test)[0]\n",
"incorrect_indices = np.nonzero(predicted_classes != y_test)[0]\n",
"\n",
"cnf_matrix = confusion_matrix(y_test, predicted_classes)\n",
"\n",
"# Plot non-normalized confusion matrix\n",
"plt.figure()\n",
"plot_confusion_matrix(cnf_matrix, classes=class_names,\n",
"                      title='Confusion matrix, without normalization')\n",
"\n",
"plt.show()"
],
"execution_count": null,
"outputs": []
},
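{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the two steps above concrete, a NumPy-only sketch with made-up probabilities for 4 examples and 3 classes: `np.argmax` turns each row of probabilities into a predicted class, and the confusion matrix counts how often true class `i` (row) was predicted as class `j` (column), the same row/column convention scikit-learn's `confusion_matrix` uses."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"# Made-up softmax outputs for 4 examples and 3 classes (illustration only)\n",
"probs = np.array([[0.1, 0.7, 0.2],\n",
"                  [0.8, 0.1, 0.1],\n",
"                  [0.2, 0.3, 0.5],\n",
"                  [0.6, 0.3, 0.1]])\n",
"y_true = np.array([1, 0, 1, 2])\n",
"\n",
"y_pred = np.argmax(probs, axis=1)   # highest-probability class per row: [1, 0, 2, 0]\n",
"\n",
"# Count (true, predicted) pairs: rows = true class, columns = predicted class\n",
"cm = np.zeros((3, 3), dtype=int)\n",
"np.add.at(cm, (y_true, y_pred), 1)\n",
"\n",
"print('correct:', np.nonzero(y_pred == y_true)[0])\n",
"print('incorrect:', np.nonzero(y_pred != y_true)[0])\n",
"print(cm)"
],
"execution_count": null,
"outputs": []
},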
{
"cell_type": "code",
"metadata": {
"id": "BKDRjvPFow6L"
},
"source": [
"show_samples(correct_indices, predicted, X_test, y_test, 5, class_names)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "lgtxWrarps6b"
},
"source": [
"show_samples(incorrect_indices, predicted, X_test, y_test, 5, class_names)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "r29_OvwJtyM7"
},
"source": [
"# CIFAR-10"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nuo0JHXDzxT7"
},
"source": [
"## Loading the dataset"
]
},
{
"cell_type": "code",
"metadata": {
"id": "pPrx-UWet0Ng"
},
"source": [
"from keras.datasets import cifar10\n",
"\n",
"(X_train, y_train), (X_test, y_test) = cifar10.load_data()\n",
"\n",
"X_train = X_train.astype('float32') # change integers to 32-bit floating point numbers\n",
"X_test = X_test.astype('float32')\n",
"\n",
"X_train /= 255 # scale pixel values from [0, 255] to [0, 1]\n",
"X_test /= 255\n",
"\n",
"# CIFAR-10 labels come with shape (N, 1); flatten them to shape (N,)\n",
"y_train = y_train.reshape((1,-1))[0]\n",
"y_test = y_test.reshape((1,-1))[0]\n",
"\n",
"print(\"Training matrix shape\", X_train.shape, y_train.shape)\n",
"print(\"Testing matrix shape\", X_test.shape, y_test.shape)\n",
"\n",
"# convert class labels to one-hot vectors\n",
"nb_classes = 10\n",
"\n",
"Y_train = to_categorical(y_train, nb_classes)\n",
"Y_test = to_categorical(y_test, nb_classes)\n",
"\n",
"cifar_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']"
],
"execution_count": null,
"outputs": []
},
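{
"cell_type": "markdown",
"metadata": {},
"source": [
"`to_categorical` turns each integer label into a one-hot row. A NumPy-only sketch of the same idea, assuming integer labels in `0..nb_classes-1`, is to pick rows of an identity matrix; the labels below are made up for illustration."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"nb_classes = 10\n",
"labels = np.array([3, 0, 9])   # example integer labels (illustration only)\n",
"\n",
"# Row k of the identity matrix is the one-hot vector for class k\n",
"one_hot = np.eye(nb_classes, dtype='float32')[labels]\n",
"\n",
"print(one_hot.shape)             # (3, 10)\n",
"print(one_hot.argmax(axis=1))    # recovers the original labels"
],
"execution_count": null,
"outputs": []
},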
{
"cell_type": "markdown",
"metadata": {
"id": "Q2LGp6AVzzqC"
},
"source": [
"## Training set preview"
]
},
{
"cell_type": "code",
"metadata": {
"id": "o6OJ7XPdxe1i"
},
"source": [
"# Build a 10x10 grid: column i holds the first 10 training images of class i\n",
"for i in range(0, 10):\n",
"    img_batch = X_train[y_train == i][0:10]\n",
"    # (10, 32, 32, 3) -> (320, 32, 3): stack the 10 images vertically into one column\n",
"    img_batch = np.reshape(img_batch, (img_batch.shape[0]*img_batch.shape[1], img_batch.shape[2], img_batch.shape[3]))\n",
"    if i > 0:\n",
"        img = np.concatenate([img, img_batch], axis = 1)\n",
"    else:\n",
"        img = img_batch\n",
"plt.figure(figsize=(10,20))\n",
"plt.axis('off')\n",
"plt.imshow(img) # RGB image, so no colormap is needed\n"
],
"execution_count": null,
"outputs": []
},
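{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `np.reshape` call in the preview loop stacks a batch of images vertically. A tiny sketch with three made-up 2x2 RGB images, each filled with its own index, shows that reshaping `(N, H, W, C)` to `(N*H, W, C)` gives the same array as concatenating the images along the row axis."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"# Three tiny 2x2 RGB images; image k is filled with the value k\n",
"batch = np.stack([np.full((2, 2, 3), k) for k in range(3)])   # shape (3, 2, 2, 3)\n",
"\n",
"column = np.reshape(batch, (batch.shape[0] * batch.shape[1], batch.shape[2], batch.shape[3]))\n",
"\n",
"print(column.shape)            # (6, 2, 3)\n",
"print(column[:, 0, 0])         # which image each row came from\n",
"print(np.array_equal(column, np.concatenate(batch, axis=0)))"
],
"execution_count": null,
"outputs": []
},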
{
"cell_type": "markdown",
"metadata": {
"id": "6EF56o2gz3Mz"
},
"source": [
"## Preparing the model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "6UvabhsguqWc"
},
"source": [
"def generate_model():\n",
"    model = Sequential() # Linear stacking of layers\n",
"\n",
"    # Convolution Layer 1\n",
"    model.add(Conv2D(16, (3, 3), input_shape=(32,32,3)))\n",
"    model.add(Activation('relu'))\n",
"\n",
"    # ...\n",
"\n",
"    model.add(Flatten()) # Flatten final output matrix into a vector\n",
"\n",
"    # ...\n",
"\n",
"    # Fully Connected Layer\n",
"    model.add(Dense(10)) # final 10 FC nodes\n",
"    model.add(Activation('softmax')) # softmax activation\n",
"\n",
"    model.summary()\n",
"\n",
"    adam = tf.optimizers.Adam(learning_rate=0.001)\n",
"    model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])\n",
"\n",
"    return model"
],
"execution_count": null,
"outputs": []
},
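{
"cell_type": "markdown",
"metadata": {},
"source": [
"A rough sketch of the arithmetic for the template as it stands, assuming Keras defaults (`padding='valid'`, stride 1): a single 3x3 convolution with 16 filters on a 32x32x3 input leaves a 30x30x16 feature map, so `Flatten` feeds 14,400 values into `Dense(10)`. The last lines show, for one hypothetical 2x2 max-pooling layer, how quickly pooling shrinks that number; which layers to add in the `# ...` parts is left open."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Shape and parameter arithmetic for the generate_model() template (illustration only)\n",
"side = 32 - 3 + 1                     # 'valid' 3x3 convolution: 32 -> 30\n",
"conv_params = 3 * 3 * 3 * 16 + 16     # kernel weights + biases\n",
"flat = side * side * 16               # values after Flatten\n",
"dense_params = flat * 10 + 10         # Dense(10) weights + biases\n",
"print('feature map:', (side, side, 16), 'flatten:', flat)\n",
"print('conv params:', conv_params, 'dense params:', dense_params)\n",
"\n",
"# One hypothetical 2x2 max-pooling layer after the convolution would give:\n",
"pooled = side // 2                    # 30 -> 15\n",
"print('with one 2x2 pool, flatten:', pooled * pooled * 16)"
],
"execution_count": null,
"outputs": []
},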
{
"cell_type": "code",
"metadata": {
"id": "PSQkAq9IvJsS"
},
"source": [
"model = generate_model()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "YMtJVU_oz55g"
},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"metadata": {
"id": "sjZaiBLJvQkP"
},
"source": [
"# Augment only the training subset; the validation subset comes from the same\n",
"# deterministic split but is left unchanged\n",
"gen = ImageDataGenerator(rotation_range=8, width_shift_range=0.08, shear_range=0.3,\n",
"                         height_shift_range=0.08, zoom_range=0.08, validation_split=0.2)\n",
"val_gen = ImageDataGenerator(validation_split=0.2)\n",
"\n",
"train_generator = gen.flow(X_train, Y_train, batch_size=128, subset='training')\n",
"valid_generator = val_gen.flow(X_train, Y_train, batch_size=128, subset='validation')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "MlMHmbOsvbWP"
},
"source": [
"model.fit(train_generator, steps_per_epoch=40000//128, epochs=2, verbose=1, validation_data=valid_generator, validation_steps = 10000 // 128)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "mN5zKMDNz8Jp"
},
"source": [
"## Test"
]
},
{
"cell_type": "code",
"metadata": {
"id": "lLZORONWvqex"
},
"source": [
"score = model.evaluate(X_test, Y_test) # returns [loss, accuracy]\n",
"print('Test loss:', score[0])\n",
"print('Test accuracy:', score[1])\n",
"\n",
"# model.predict returns a probability for each class; np.argmax picks the\n",
"# highest-probability class for each input example.\n",
"predicted = model.predict(X_test)\n",
"predicted_classes = np.argmax(predicted, axis=1)\n",
"\n",
"# Check which items we got right / wrong\n",
"correct_indices = np.nonzero(predicted_classes == y_test)[0]\n",
"incorrect_indices = np.nonzero(predicted_classes != y_test)[0]\n",
"\n",
"cnf_matrix = confusion_matrix(y_test, predicted_classes)\n",
"\n",
"# Plot non-normalized confusion matrix, labelled with the CIFAR-10 class names\n",
"plt.figure()\n",
"plot_confusion_matrix(cnf_matrix, classes=cifar_names,\n",
"                      title='Confusion matrix, without normalization')\n",
"\n",
"plt.show()"
],
"execution_count": null,
"outputs": []
},
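{
"cell_type": "markdown",
"metadata": {},
"source": [
"`model.evaluate` reports the loss and the overall accuracy. Per-class accuracy (recall) can be read off a confusion matrix: the diagonal counts correct predictions and each row sums to the number of test images of that class. A sketch on a made-up 3-class matrix; the same two NumPy expressions could be applied to `cnf_matrix`."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"# Made-up confusion matrix: rows = true class, columns = predicted class\n",
"cm = np.array([[50,  3,  7],\n",
"               [ 4, 40,  6],\n",
"               [10,  2, 38]])\n",
"\n",
"overall_accuracy = np.trace(cm) / cm.sum()        # correct predictions / all predictions\n",
"per_class_recall = np.diag(cm) / cm.sum(axis=1)    # correct predictions per true class\n",
"\n",
"print('overall accuracy:', round(float(overall_accuracy), 3))\n",
"for k, r in enumerate(per_class_recall):\n",
"    print('class', k, 'recall:', round(float(r), 3))"
],
"execution_count": null,
"outputs": []
},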
{
"cell_type": "code",
"metadata": {
"id": "ktfHkYc1zHsY"
},
"source": [
"def show_samples_rgb(indices, preds, images, labels, count=3, names=None):\n",
"    # Show up to count x count images in a grid. Each title gives the predicted class (P)\n",
"    # and the true class (E), each with the probability the model assigned to it.\n",
"    plt.figure()\n",
"    for i, sample in enumerate(indices[:count**2]):\n",
"        pred_id = int(np.argmax(preds[sample]))\n",
"        real_id = int(labels[sample])\n",
"        pred_score = preds[sample][pred_id]\n",
"        real_score = preds[sample][real_id]\n",
"        plt.subplot(count,count,i+1)\n",
"        plt.imshow(images[sample], interpolation='none')\n",
"        plt.axis('off')\n",
"        if names:\n",
"            plt.title(\"P: {} ({:.2f})\\nE: {} ({:.2f})\".format(names[pred_id], pred_score, names[real_id], real_score))\n",
"        else:\n",
"            plt.title(\"P: {} ({:.2f})\\nE: {} ({:.2f})\".format(pred_id, pred_score, real_id, real_score))\n",
"\n",
"    plt.tight_layout()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "F8hS785GzoGJ"
},
"source": [
"## Correct classifications"
]
},
{
"cell_type": "code",
"metadata": {
"id": "IUcJVbOlzm5z"
},
"source": [
"show_samples_rgb(correct_indices, predicted, X_test, y_test, 5, cifar_names)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "PiyibL4yzpup"
},
"source": [
"## Incorrect classifications"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ECh_2RW6zgKB"
},
"source": [
"show_samples_rgb(incorrect_indices, predicted, X_test, y_test, 5, cifar_names)"
],
"execution_count": null,
"outputs": []
}
]
}