mirror of
https://github.com/kuhyx/WUT_Computer_Science.git
synced 2026-07-04 16:43:12 +02:00
3470 lines
3.9 MiB
Plaintext
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"colab": {
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
},
"accelerator": "GPU",
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/kuhyx/twm_4/blob/main/TWM_KerasIntro.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "c9tFwxET8YY5"
},
"source": [
"%matplotlib inline"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "M9PoJm628YY_"
},
"source": [
"# Introduction to Deep Learning with Keras and TensorFlow\n",
"\n",
"Based on excellent work by **[Daniel Moser (UT Southwestern Medical Center)](https://github.com/AviatorMoser/keras-mnist-tutorial)**, Resources: **[Xavier Snelgrove](https://github.com/wxs/keras-mnist-tutorial), [Yash Katariya](https://github.com/yashk2810/MNIST-Keras)**"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lh0dsQg08YZA"
},
"source": [
"To help you understand the fundamentals of deep learning, this demo walks through the basic steps of building two toy models that classify handwritten digits with accuracies above 95%. The first model is a basic fully connected neural network. The second is a deeper network that introduces the concepts of convolution and pooling."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LJVXVJ6_8YZB"
},
"source": [
"## The Task for the AI\n",
"\n",
"Our goal is to construct and train an artificial neural network on thousands of images of handwritten digits so that it can correctly identify new digits it has not seen before. We will use the MNIST database, which contains 60,000 training images and 10,000 test images. We will build the network with the Keras Python API, using TensorFlow as the backend."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3eVu55U98YZC"
},
"source": [
"<img src=\"https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/mnist.png\" >"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ms_a4hzY8YZC"
},
"source": [
"## Prerequisite Python Modules\n",
"\n",
"First, some software needs to be loaded into the Python environment."
]
},
{
"cell_type": "code",
"metadata": {
"id": "D7Bns2Lo8YZD",
"outputId": "3808fa0b-ff68-4786-c7f7-86195e55179b",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"import numpy as np # advanced math library\n",
"import matplotlib.pyplot as plt # MATLAB-like plotting routines\n",
"import random # for generating random numbers\n",
"\n",
"import tensorflow as tf\n",
"\n",
"from keras.datasets import mnist # MNIST dataset is included in Keras\n",
"from keras.models import Sequential # Model type to be used\n",
"\n",
"from keras.layers import Dense, Dropout, Activation # Types of layers to be used in our model\n",
"from keras.utils import to_categorical # one-hot encoding of class labels\n",
"\n",
"from keras import optimizers\n",
"\n",
"\n",
"from sklearn.metrics import confusion_matrix\n",
"import itertools\n",
"\n",
"print(tf.__version__)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2.15.0\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QygDQ7Ch8YZG"
},
"source": [
"## Loading Training Data\n",
"\n",
"The MNIST dataset is conveniently bundled within Keras, and we can easily analyze some of its features in Python."
]
},
{
"cell_type": "code",
"metadata": {
"id": "SdZZph6i8YZG",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "a956082b-3f72-41e0-c83f-39af48c6ed99"
},
"source": [
"# The MNIST data is split between 60,000 28 x 28 pixel training images and 10,000 28 x 28 pixel test images\n",
"(X_train, y_train), (X_test, y_test) = mnist.load_data()\n",
"\n",
"print(\"X_train shape\", X_train.shape)\n",
"print(\"y_train shape\", y_train.shape)\n",
"print(\"X_test shape\", X_test.shape)\n",
"print(\"y_test shape\", y_test.shape)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"X_train shape (60000, 28, 28)\n",
"y_train shape (60000,)\n",
"X_test shape (10000, 28, 28)\n",
"y_test shape (10000,)\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9v5sjP468YZJ"
},
"source": [
"Using matplotlib, we can plot some sample images from the training set directly into this Jupyter Notebook. This lets us examine the intra-class variability: how many different ways there are of writing the same digit!"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Ov1T_A6X8YZJ",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 811
},
"outputId": "11cc6494-b79c-4225-a5c5-86b831ef354d"
},
"source": [
"plt.rcParams['figure.figsize'] = (9,9) # Make the figures a bit bigger\n",
"\n",
"def visualize_classes(X, y):\n",
"    # Stack the first 10 examples of each digit vertically, then place the 10 digit columns side by side\n",
"    for i in range(0, 10):\n",
"        img_batch = X[y == i][0:10]\n",
"        img_batch = np.reshape(img_batch, (img_batch.shape[0]*img_batch.shape[1], img_batch.shape[2]))\n",
"        if i > 0:\n",
"            img = np.concatenate([img, img_batch], axis = 1)\n",
"        else:\n",
"            img = img_batch\n",
"    plt.figure(figsize=(10,20))\n",
"    plt.axis('off')\n",
"    plt.imshow(img, cmap='gray')\n",
"\n",
"\n",
"visualize_classes(X_train, y_train)\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x2000 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxoAAAMaCAYAAAABQDBSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz953Ncd3oljp/OOeeIHAgQzFSi5OFIMx55XF6vvVu1frP/nqtc663yfj0O47FHM8pDimIASRCxge5G5+7bOYffC/2eR7dBgCIlEOhu3lOFYgLA7otPeMJ5zpENBoMBJEiQIEGCBAkSJEiQIOEUIT/vFyBBggQJEiRIkCBBgoTJg5RoSJAgQYIECRIkSJAg4dQhJRoSJEiQIEGCBAkSJEg4dUiJhgQJEiRIkCBBggQJEk4dUqIhQYIECRIkSJAgQYKEU4eUaEiQIEGCBAkSJEiQIOHUISUaEiRIkCBBggQJEiRIOHVIiYYECRIkSJAgQYIECRJOHcqX/USZTPY6X4cECRIkSJAgQYIECRLGBC/j+S11NCRIkCBBggQJEiRIkHDqkBINCRIkSJAgQYIECRIknDqkREOCBAkSJEiQIEGCBAmnjpee0ZAgQYIECRIkvDmQyWSQy+VQKpWQy7+vSw4GA/R6PXQ6nXN8dRIkSBgHSImGBAkSJEiQIIEhl8uh0+mg0Whgt9tx9epVeDweFoXpdrtYX1/H3bt30Wq1zvnVSpAgYZQhJRoSJEiQIEGCBIZcLofBYIDJZMLCwgL+5m/+BhcvXuQOR7vdxt///d/j0aNHUqIhQYKEF0JKNCRIkCBBggQJ0Gg00Gq1UKvVcLlcsFqtcLvdsNlssFqt6Pf76PV6kMlkUCql8EGCBAk/DOmkkCBBggQJEt5wKJVKLCwsYGVlBQaDAfPz8/D5fLDb7VhYWIDFYkGpVEImk0GlUkG5XH4pDX0JEiS82ZASDQkSJEiQIOENh1wuRzAYxM2bN2Gz2XDx4kVMT09DpVJBr9dDpVKhXC6jXC5DEATU63Up0ZAgQcIPYqISDZVKBaVSCYVCAaPRCK1WC6VSyYdkq9VCo9FAt9tFo9FArVZDv99Hp9NBt9s975cvQYIECRIkvHbQrIVMJoPJZILdbodOp8PU1BQ8Hg/MZjMMBgNUKhXkcjnfkcViEbFYDPl8HoVCAb1e77zfigQJEkYcE5NoiA9Mg8GACxcuIBQKwWq1YmFhAXa7Hel0Gtvb26hUKtjb28Pm5ibq9ToEQUCpVDrvtyBBggQJEiS8digUCuh0OiiVSly+fBkffvghnE4n5ufnMT8/D5VKBZPJBK1Wi3a7jVKphGaziSdPnuA3v/kNkskkEomEJG8rQYKEH8REJRoajQZGoxEWiwVTU1NYWlqCy+XC9evX4fF4EI1GodVqIQgCOp0Okskk5HI5qtXqeb98CRIkSJAg4Uwgk8mgUqmgVqvh8Xhw6dIleL1e+Hw++P1+9swYDAbodDpoNpuo1WrIZDLY3t5GPB5Hs9mUOhoSJEj4QYx9oqFSqVglY2FhAaurq7BYLFheXkYwGITFYoFGowEA6HQ6+Hw+mM1m9Ho9qFQqlEolfP3118jn8xLf9BgolUo4nU5YLBZotVp4PB4YDAb+98FggGQyiWg0ilarhXq9jmazeY6v+GxAtAOiIMjlcq4CqtVqNrkCgGKxiEKhgH6/j36/f86v/HShUCigVquhUCgAgJ+JXq+HVquFRqOBz+eDzWb70f9Hs9lEpVJBu91GJpPhSqq0X0+GTCaDwWCATqeDTqdDKBSC2WyGIAiIRCKo1WrodDpot9tj/RzF++8kGI1G2O12pgHRGtVoNFCpVGg0GkgkEqhUKjCbzXC73dDr9TAYDDCbzZDL5RgMBhgMBmi328hms6hWqxx4j4u8K71npVIJs9mMmZkZmM1mLC4usrKUTqeDTCZDp9NBJpPhWYxkMolqtYpIJIJqtYp2u41erzfWa0fC+I
D2uUqlwtTUFHw+H2QyGd+1hUIBe3t7qNVq6Ha7aLfb5/2SJYgw9omGTqeDy+WCyWTCn/3Zn+Gv/uqvYDQamWOqUCig1WoBAFarFSsrK+j1elhcXMS7776LXC6HcrmMjY0NqTpzDNRqNZaWlrC8vAyPx4Nbt24hFAoNGTf9/ve/xz/90z9BEAQkk0m0Wq2Jv4AUCgUHLjQbZDKZMDs7C6vVCpVKxQnZxsYGHj58yAHJJCUbKpUKVquVk3makfJ6vfB4PHA4HPjwww9x8eJFAPhR6yKbzWJ3dxflchlffPEFfve736FarU5k4nZaUCgUcDqd8Hg88Pv9+PWvf43l5WU8evQI//AP/4BoNIpKpYJisTi25x4lDWq1GiqVis+kowiFQrhy5QpMJhMUCgXvXZvNBqPRiHQ6jf/8z//E3t4e5ufn8bOf/QwulwtTU1NYWFiASqVCt9tFv9+HIAi4e/cu9vf3EYvF8OWXX45NoqFQKGAymWAwGBAMBvHRRx8hGAxienoac3Nz0Ov1UKvVkMlkaDabWF9fx9OnT1GtVnF4eIhKpYJ4PI58Pi8Ngks4UygUCiiVSlgsFvz85z/Hz3/+c2g0GhgMBmg0Gjx48AB///d/j2g0ysmGdDeMDsY20aBLRa1Ww2AwwGg0wuVyIRQKQa/Xc7WKMBgMoFQqYTQaAXyXoJjNZmg0GphMJsjlcvT7fenw/P+Dqn4qlQoWi4UDltnZWczOzvK/d7tdbG5uwmQyodVqTbS2urgaqlaruZpCa81oNMJms8HhcECj0UCv1wMADg8PoVaruQI4zgcg7Tvxc6DuhUwm4wvBarXC4XDA7XZjenoay8vLAMCV4VeB1WpFt9uFIAj8bFutFjqdzpk+S3EgS7//Me/nLCCXy6HVamE0GmG1WhEMBjE/P498Pg+DwTDUhRpl0Do7+mfxMLNGo4FGozkx0TCbzXC5XLBYLFAqlVwYcDgcMJvNAMDzCGazGT6fDz6fDzMzM1hcXIRarUa320W320Uul0MsFkO5XEaxWByr8468L9RqNUwmE9xuN/x+P5xOJ/R6PZ9XwHfFkHK5jEwmg2q1ilQqhUqlAkEQ0G63x/oMkzB6OGnvDgYD3utKpRIajQYulwvT09PQarUwmUzQaDTIZDLQ6XRQqVRjca69aRifU1IEMUVlcXERN2/ehN1ux4ULF6DT6YYu0cFgwAEe/Z6gVCqh1Wphs9ng9XrRbDZRrVa5Ij+KAcRZQCaTweFwwOl0wmq14tKlS7h27RqsVitMJhOA7yvT9JzGPYA+CdQRo8DE7/dDq9XCYrHAbDbzpa3X66HT6eB2u2EwGPhCHwwG0Ol0UCgUqFQq2N3dxeHh4ditLZlMBp1Ox3QoUqaxWCyYnp7mBJ4qxk6nE263G0ajEW63e2i9vCp0Oh0CgQBsNhsuXbqEXC6HYrGISCSCeDyOfr+Pbrf7Wp8pJdwajYaTKZlMhlKpxLS4UYJMJuNnb7FY0Gw2kclkUCwWR1ZhT6VSccIgDi5IQZD2HSUK9OF0OuFyuU6kTzmdTg5M5HI5FAoFF1Ho99evX4fb7cbc3ByWl5fhdDpht9vR7/fRbrdRqVRQr9eRzWZ5EDqfz4/VMDS9ZwrQ/H4/wuEwJ2BiDAYDtFotVKtVFItFpFIpFItFVKvVkV0/EsYLNCMk3oti0LkOABaLBTabDXa7HYFAAG63GwqFgvcnfTSbzbHak28KxjLRoKzWaDTiypUr+O///b/D5XJxK1ycHdMwGyUb9Cs5oOp0OjgcDgQCAaZjUNttXCkFPxVyuRxutxsrKytwOp1466238M477/AlJQ7ojgZ34xZA/xCIEqXRaLCwsIB33nkHVquVByc1Gg0Hc1TRp4BHJpOh1+vBbrfDZDIhn8+j1WohkUiM3XOiwNVms8FkMuHKlSsIhUJwuVxYW1uD1WoFAObLE32ROj7Aj18bRqMROp0OvV4P/X4fKpUKhUIBv//975HP53nO4HUGQGq1Gm
63m7ugxGWPRCIolUojl2goFApepxaLBdVqFclkEoVCYWQDRUraKVmVyWTQarUIh8Ow2WywWCwIhUIwGo1Qq
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uGjX5ZD68YZM"
},
"source": [
"Let's examine a single digit a little closer and print out the array of pixel values for the last image in the training set."
]
},
{
"cell_type": "code",
"metadata": {
"id": "wW_rrJfj8YZN",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "7a6fb008-9484-4e97-fcab-e53415548b8b"
},
"source": [
"# just a little function for pretty printing a matrix\n",
"def matprint(mat, fmt=\"g\"):\n",
"    col_maxes = [max([len((\"{:\"+fmt+\"}\").format(x)) for x in col]) for col in mat.T]\n",
"    for x in mat:\n",
"        for i, y in enumerate(x):\n",
"            print((\"{:\"+str(col_maxes[i])+fmt+\"}\").format(y), end=\" \")\n",
"        print(\"\")\n",
"\n",
"# now print!\n",
"matprint(X_train[-1])"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 38 48 48 22 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 62 97 198 243 254 254 212 27 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 67 172 254 254 225 218 218 237 248 40 0 21 164 187 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 89 219 254 97 67 14 0 0 92 231 122 23 203 236 59 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 25 217 242 92 4 0 0 0 0 4 147 253 240 232 92 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 101 255 92 0 0 0 0 0 0 105 254 254 177 11 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 167 244 41 0 0 0 7 76 199 238 239 94 10 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 192 121 0 0 2 63 180 254 233 126 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 190 196 14 2 97 254 252 146 52 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 130 225 71 180 232 181 60 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 130 254 254 230 46 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 6 77 244 254 162 4 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 110 254 218 254 116 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 131 254 154 28 213 86 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 66 209 153 19 19 233 60 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 142 254 165 0 14 216 167 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 90 254 175 0 18 229 92 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 26 229 249 176 222 244 44 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 73 193 197 134 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "s3koPAgMdzoG"
},
"source": [],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "vQPPE2Bq8YZP"
},
"source": [
"Each pixel is an 8-bit integer from 0 to 255, where 0 is full black and 255 is full white. A pixel described by a single intensity value like this is called a single-channel pixel, and an image made of such pixels is monochrome (grayscale).\n",
"\n",
"*Fun fact! Your computer screen has three channels for each pixel: red, green and blue. Each of these channels also likely takes an 8-bit integer, so 3 channels give 24 bits in total and 2^24 = 16,777,216 possible colors!*"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z09IKy7PR9M9"
},
"source": [
"## More input data analysis\n",
"\n",
"\n",
"Visualizing high-dimensional data by projecting it into a low-dimensional space is a classic operation that almost anyone working with data has done at least once. There is a huge variety of dimensionality-reduction methods, but one very popular method is t-SNE, proposed by Geoffrey Hinton’s group in 2008. [ _more..._ ](https://mlexplained.com/2018/09/14/paper-dissected-visualizing-data-using-t-sne-explained/)\n",
"\n",
"Code by [Zaid Alyafeai](https://github.com/zaidalyafeai/Notebooks) [MIT license]"
]
},
{
"cell_type": "code",
"metadata": {
"id": "59m4-pjCSBgM"
},
"source": [
"from sklearn.manifold import TSNE\n",
"import matplotlib.patheffects as PathEffects\n",
"import seaborn as sns\n",
"\n",
"RS=19238\n",
"\n",
"class_names = [ str(clid) for clid in range(10) ]\n",
"\n",
"def scatter(x, colors):\n",
"    # We choose a color palette with seaborn.\n",
"    palette = np.array(sns.color_palette(\"hls\", 10))\n",
"\n",
"    # We create a scatter plot.\n",
"    f = plt.figure(figsize=(8, 8))\n",
"    ax = plt.subplot(aspect='equal')\n",
"    sc = ax.scatter(x[:,0], x[:,1], lw=0, s=40,\n",
"                    c=palette[colors.astype(\"int\")])\n",
"    plt.xlim(-25, 25)\n",
"    plt.ylim(-25, 25)\n",
"    ax.axis('off')\n",
"    ax.axis('tight')\n",
"\n",
"    # We add the labels for each digit.\n",
"    txts = []\n",
"    for i in range(10):\n",
"        # Position of each label.\n",
"        xtext, ytext = np.median(x[colors == i, :], axis=0)\n",
"        txt = ax.text(xtext, ytext, class_names[i], fontsize=15)\n",
"        txt.set_path_effects([\n",
"            PathEffects.Stroke(linewidth=5, foreground=\"w\"),\n",
"            PathEffects.Normal()])\n",
"        txts.append(txt)\n",
"\n",
"def plot_tsne(X, y):\n",
"    print('calculating tsne ...')\n",
"    proj = TSNE(random_state=RS, learning_rate=200, init='random').fit_transform(X)\n",
"    scatter(proj, y)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "1_gbsFQ3Srw_",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 670
},
"outputId": "31984728-ddd1-4338-cedf-eaa46f90ea13"
},
"source": [
"X = np.reshape(X_train, (X_train.shape[0], 28 * 28))[0:2000]\n",
"y = y_train[0:2000]\n",
"plot_tsne(X, y)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"calculating tsne ...\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 800x800 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAoAAAAJ8CAYAAABunRBBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd5hcVfnA8e+d3me29+ym995IbxB67yCgCEhRFEH5ISqCCIoICgIqglSR3msCoSQhPaSH9GzvO73P3N8fm93sZKdtSyHn8zw8D3vvufee2ezuvHPK+0qyLMsIgiAIgiAIxw3Fke6AIAiCIAiCcHiJAFAQBEEQBOE4IwJAQRAEQRCE44wIAAVBEARBEI4zIgAUBEEQBEE4zogAUBAEQRAE4TgjAkBBEARBEITjjAgABUEQBEEQjjMiABQEQRAEQTjOiABQEARBEAThOCMCQEEQBEEQhOOMCAAFQRAEQRCOMyIAFARBEARBOM6IAFAQBEEQBOE4IwJAQRAEQRCE44wIAAVBEARBEI4zIgAUBEEQBEE4zogAUBAEQRAE4TgjAkBBEARBEITjjAgABUEQBEEQjjMiABQEQRAEQTjOiABQEARBEAThOCMCQEEQBEEQhOOMCAAFQRAEQRCOMyIAFARBEARBOM6IAFAQBEEQBOE4IwJAQRAEQRCE44wIAAVBEARBEI4zIgAUBEEQBEE4zogAUBAEQRAE4TgjAkBBEARBEITjjAgABUEQBEEQjjOqI90BQRCE44l/906CNdUoTWYMo8YgqcSfYUEQDj/xl0cQBOEwCNZWU/fEowT2720/prTayPneDzBNnnoEeyYIwvFIkmVZPtKdEARB+C6LeL2U33kbkZbmzicVCopu/w36ocMPf8cEQThuiTWAgiAIfazh+afjB38A0Sgt779zeDskCMJxTwSAgiAIfajhv8/i/npp0jbezRsQkzGCIBxOIgAUBEHoI/59e3F88uGR7oYgCEInIgAUBEHoA/49u6j+8x/SamsYNRZJktK+d9TnJdRQTzQQ6G73BEE4zoldwIIgCL0sWFdL9Z//QNTnS91Yksg4/ay07htqaqTp1f/iXrMKwmEknQ7ztJlkXXApSqOxh70WBOF4IgJAQRCEXub45IO0g7+8629Oawdw2GGn6r7fEW5qbD8m+/04lywmsGcXRb+6G4VW25NuC4JwHBFTwIIgCL0kGvDj+nopruVfpdXedvLpmKdOS6ut/eP3Y4K/jgL79+Fa9mXa/RQEQRAjgIIgHFfkcBjf9i1EvF60/UrR5Bf2yn2dy76k4bmnkQP+tNrrhg4n87yL0mobamrE8fmnSdu4Vi7HOv+ktO4nCIIgAkBBEL5zwk4HUY8HVVY2Co2m/bjr66U0vvwiEXtL+zHDqLHkXnsDKqut28/zbt1E/ZOPp93eMGosBT/7RacycKGmRrwb1iFHIuiHjUSVnU3VH+8huH9fynvK/jSmnAVBEA4QAaAgCN8ZgYr9NL38X7xbNoIso9DrMc+aS9b5l+DbtoW6fz0Gh+Tb827eQPWD91Hyu/uRlMpuPbcriZwlnY6cK38QE/zJkQgNzz+N88slEI0ebKtWI4dCad1XWzogrXahhnocSxbh37UDSa3BNGkq5umzxPpBQTjOiABQEITvhGB1FVX33U3U520/FvX5cHzyIcGKciJ+f6fgr/3ainI861ZjmnxCl58ryzK+bVvSaqvKzCLvRz9GnZsfc7zp1f/ijDPFm27wh0KBdcHClM28mzdQ88hDyMGD6WN8Wzbh+PRjin75G5QWS3rPEwThmCc2gQiC8J3Q/NZrMcFfR75tWwju3Z30es8367r13KjLlV5DtYbSBx/ttOM34vXi+Gxxt54NgFJJ7tU/QltalrRZNOCn9vFHYoK/NsHKChpefKb7fRAE4ZgjRgAFQTjmyeEw7jUre3iTrpdiCzU2UHX/3THTtokYR41GUnT+zB3YszNuUJYOw4RJ5F
75Q1S2jJRt3Su/Jur1JD6/ZiURp1OMAgrCcUKMAAqCcMyTI+HUQZgq+eddw6gxXX5uwwvPJEzNEkOSsJ1yRvxzcYLCdOX94Lq0gj+AYG118gaRCKHG+m73RRCEY4sYARQE4Zin0OpQ6PREk+yElRQKEo3xqfPyk67/827ZhGPxxwQqy1EaTZinzcAwbiLeDamnjSWdjtwrf5gw2bNu0FAURhNRjzvlvWKvG4LSnN5oXcTpJFidIgAElBZrl/ogCMKxS5Llbsx7CIIgHGX23/4zQnW1SdsYxk/Cu3E9RCLtx7T9B2I75QyCVRUgSRjHTkA3YGD7+eZ336T59Zc73UtdUEioJnlQZRgzjvybfoZCq0varuXDd2l6+cWkbTqSDAbKHnocpS75feFAfsL/PIkcTr2hxDhxMjnf+wGqjMy0+yIIwrFJBICCIHwn1D31D1xffd6la0wnzCCwf2+nQE43dDiFt9xOqLGBil//ott9yv7e97GdeErC81G/H9fyr/Bt30KwtoZQXS1yoHU9oKRWo8rKJlRbE/fajDPOIeuCS5I+379vL5X33JnWGsU26tw8iu+6L2ltYTkcxrt5AxGXC01hEbqBg9O+vyAIRwcxBSwIwneCde6CLgeA7hXL4h73f7uN2n88gjo3L/kNlMqY0cSOJI0G8wkzE14arK6i6s9/INLSHHNclZNH1nkXoi4opPLuOxNeb1/8ERmnn4VCbwAg4nbhWLIYz7rVyKEw+qHDCLe0dCn4AwjV1+H88jMyTj0z7nn36hU0PP8fIk5H+zFtaX/ybvhJr1VVEQSh74kRQEEQjgg5Go27K7YnKn53B4F9e3vtfvpRY/Bt3pi0jdJmI2K3xx5UKMj94fVYZsyOe40sy5T/6taEU8jGCZMxjBpNw3NPJ312wc9+iXHcBEIN9VT98Z70NqSkQdJoUOfkoh86AutJp6ApaA3sfNu3UvXAvXGDSlVmJiX3PojSYOiVPgiC0LfECKAgCIdNoKKclnffxLNuDXIkjG7wEDJOPRPj+EndvqccjRLYtwc5HCbU0pL6gi4I1iZfU6gwGCn+3f04P1uEe81K5GAQ3aAh2Baeim7AoITX+bZsSrp+0LN+Ddr+qSt7RP2tdYcbnn2q14I/ADkYJFhVSbCqEudXS8i/+TaMo8fS8t7bCUcUw83N7L/9Z2hLy7DOno/CasW9/CsibheagkIscxagzsnttT4KgtAzYgRQEITDwr9rJ1V/vrd9jVtH2ZddiW3haV2+p/PLJTS/9Rrh5qbe6GKXWReeSs5lV3X5unQ2fZhnzcP11ZKkbdSFReRf/xMq7rqjW3kM06Uwmyl78O/suf773X/OgYTViUZFBUE4vEQeQEEQDouGF/8TN/gDaHrlJSIuZ5fu5/jiM+qf/ucRC/40xSVknn0B0Fplw7t1M95tW4gGgymvVRoSb7Bo41m7KmWbUHUVLR+936fBH7RWO3GvW9Ozm0Qi1D/9z9T5CAVBOCzEFLAgCH0uWF1FYO+ehOflcAjXyq+xnXhyWveTIxGa33y1t7rXLabps1Ho9TS9/jKOxR8R9bXmIFSYzGScdiYZp52V8FrjpClILz6DnCBYVFptRBz2tPrh27q5y33vjnBTA/rhI3v2vEgE55JPyb70it7rmCAI3SJGAAVB6HPpjO51ZQTQv3snEXsP1vtJCozjJ5F92ZUgSd26hezz0vjyC7S8+2Z78AcQdbtoeuW/NL/zZsJrlUYTmeddnOCkEtPU6Wn3I+p2oxsyLO328aiyslO2cX+9rFeCzUBleczXcjjc43sKgtB1IgAUBKHPqfPyW1OmJOHdvBH7x+8TSaMiRjpJjSWNJv5xixXrSScTdjlbg7S26VOVClVuHvo0S8IpLVYciz9OeL7lg3faN2nEk3HK6eTdcDPa0rL2Y/oRoyi87VdYZs5Jqw/QmpA658ofojCZ077mUJqSUlAk//cJVlV0+/4dKQ1G5HCY5vfeYt+tP2b3Nd9jz0
3X0PDSc0Tcrl55hiAIqYlNIIIgHBY1f38Yz5qVKdspDEYKbvkl+sFDE7aJeDzsu+VG5GD8NYU9YVmwEOenn
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UFWr5nQo8YZQ"
},
"source": [
"## Formatting the input data layer\n",
"\n",
"Instead of a 28 x 28 matrix, we build our network to accept a 784-length vector.\n",
"\n",
"Each image therefore needs to be reshaped (or flattened) into a vector. We'll also normalize the inputs to the range [0, 1] rather than [0, 255]. Normalizing inputs is generally recommended because it keeps all input features on the same small scale, which tends to make gradient-based training more stable."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B-b0XGKd8YZQ"
},
"source": [
"<img src='https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/flatten.png' >"
]
},
{
"cell_type": "code",
"metadata": {
"id": "0ge5s-pT8YZR",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "7f0549bf-47ae-415e-daf3-28021468ccb1"
},
"source": [
"X_train = X_train.reshape(60000, 784) # reshape 60,000 28 x 28 matrices into 60,000 784-length vectors.\n",
"X_test = X_test.reshape(10000, 784) # reshape 10,000 28 x 28 matrices into 10,000 784-length vectors.\n",
"\n",
"X_train = X_train.astype('float32') # change integers to 32-bit floating point numbers\n",
"X_test = X_test.astype('float32')\n",
"\n",
"X_train /= 255 # scale every pixel value from [0, 255] to [0, 1]\n",
"X_test /= 255\n",
"\n",
"print(\"Training matrix shape\", X_train.shape)\n",
"print(\"Testing matrix shape\", X_test.shape)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Training matrix shape (60000, 784)\n",
"Testing matrix shape (10000, 784)\n"
]
}
]
},
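{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Optional check (an illustrative addition; the toy `images` array is hypothetical).* The reshape above flattens each image row by row. This small NumPy sketch checks that pixel (row r, column c) ends up at index r * 28 + c of the flattened vector. It also shows how dividing by 255 maps 8-bit pixel values into [0, 1]."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical batch of two 28 x 28 'images' with unique integer pixel values\n",
"images = np.arange(2 * 28 * 28).reshape(2, 28, 28)\n",
"\n",
"# Flatten as in the cell above: one 784-length vector per image\n",
"flat = images.reshape(2, 784)\n",
"\n",
"# NumPy reshapes in row-major (C) order, so pixel (r, c) lands at index r * 28 + c\n",
"r, c = 5, 17\n",
"assert flat[1, r * 28 + c] == images[1, r, c]\n",
"\n",
"# Dividing by 255 maps the 8-bit range [0, 255] onto [0, 1]\n",
"pixels = np.array([0, 51, 255], dtype='float32') / 255\n",
"print(pixels.min(), pixels.max())"
],
"execution_count": null,
"outputs": []
},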
{
"cell_type": "markdown",
"metadata": {
"id": "sS8jHU6V8YZT"
},
"source": [
"We then convert our classes (the unique digits) to the one-hot (categorical) format, i.e.\n",
"\n",
"```\n",
"0 -> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
"1 -> [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]\n",
"2 -> [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]\n",
"etc.\n",
"```\n",
"\n",
"Each vector has 10 entries, one per digit class. If the final output of our network is very close to one of these vectors, the image most likely belongs to that class. For example, if the final output is:\n",
"\n",
"```\n",
"[0, 0.94, 0, 0, 0, 0, 0.06, 0, 0, 0]\n",
"```\n",
"\n",
"then it is most probable that the image is that of the digit `1`."
]
},
{
"cell_type": "code",
"metadata": {
"id": "AFvPkCFw8YZU"
},
"source": [
"nb_classes = 10 # number of unique digits\n",
"\n",
"Y_train = to_categorical(y_train, nb_classes)\n",
"Y_test = to_categorical(y_test, nb_classes)"
],
"execution_count": null,
"outputs": []
},
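{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Optional illustration (an addition, not part of the original tutorial).* This sketch builds the same kind of one-hot matrix that `to_categorical` produces, using plain NumPy to index rows of the identity matrix. It is an equivalent way to get the encoding, not the Keras implementation. It then uses `np.argmax` to go back from a one-hot or softmax-style vector to a digit label. The `labels` and `prediction` arrays are made-up examples."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"nb_classes = 10\n",
"labels = np.array([0, 1, 2, 7])  # made-up example labels\n",
"\n",
"# Row i of the 10 x 10 identity matrix is the one-hot vector for digit i\n",
"one_hot = np.eye(nb_classes, dtype='float32')[labels]\n",
"print(one_hot.shape)  # (4, 10)\n",
"\n",
"# argmax undoes the encoding and also picks the most probable digit from a network output\n",
"print(np.argmax(one_hot, axis=1))  # [0 1 2 7]\n",
"prediction = np.array([0, 0.94, 0, 0, 0, 0, 0.06, 0, 0, 0])\n",
"print(np.argmax(prediction))  # 1"
],
"execution_count": null,
"outputs": []
},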
{
"cell_type": "markdown",
"metadata": {
"id": "q_B8ZH5G8YZX"
},
"source": [
"# Building a 3-layer fully connected network\n",
"\n",
"<img src=\"https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/figure.png\" />"
]
},
{
"cell_type": "code",
"metadata": {
"id": "M-uOWlRv8YZX"
},
"source": [
"# The Sequential model is a linear stack of layers and is very common.\n",
"\n",
"model = Sequential()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "iyoUd9WE8YZa"
},
"source": [
"## The first hidden layer"
]
},
{
"cell_type": "code",
"metadata": {
"id": "IdgP4lai8YZb"
},
"source": [
"# The first hidden layer is a set of 512 nodes (artificial neurons).\n",
"# Each node receives every element of the 784-length input vector, multiplies each element by its own weight, sums the results and adds a bias.\n",
"\n",
"model.add(Dense(512, input_shape=(784,))) #(784,) is not a typo -- that represents a 784 length vector!"
],
"execution_count": null,
"outputs": []
},
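{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Optional illustration (an addition).* This NumPy sketch makes the comment above concrete by showing roughly what one `Dense(512)` layer computes for a batch of inputs. The random weights `W` and bias `b` are hypothetical stand-ins for the values Keras would learn. Their shapes follow the Keras convention of a (784, 512) kernel plus a 512-element bias, so the parameter count should match the 401920 that `model.summary()` reports for this layer."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"# Hypothetical parameters with the shapes of Dense(512, input_shape=(784,))\n",
"W = rng.normal(scale=0.01, size=(784, 512)).astype('float32')  # one weight per (input, node) pair\n",
"b = np.zeros(512, dtype='float32')  # one bias per node\n",
"\n",
"# A made-up batch of 3 flattened, normalized images\n",
"x = rng.random((3, 784), dtype=np.float32)\n",
"\n",
"# Every node takes a weighted sum of all 784 inputs and adds its bias\n",
"z = x @ W + b\n",
"print(z.shape)  # (3, 512)\n",
"print(W.size + b.size)  # 401920 trainable parameters"
],
"execution_count": null,
"outputs": []
},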
{
"cell_type": "code",
"metadata": {
"id": "c1MiY3u68YZd"
},
"source": [
"# An \"activation\" is a non-linear function applied to the output of the layer above.\n",
"# It checks the new value of the node, and decides whether that artificial neuron has fired.\n",
"# The Rectified Linear Unit (ReLU) sets every negative output of the layer above to zero.\n",
"# Those nodes are then considered not to have fired.\n",
"# Positive values of a node are unchanged.\n",
"\n",
"model.add(Activation('relu'))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "tQCRmiMI8YZg"
},
"source": [
"$$f(x) = \\max(0, x)$$\n",
"<img src = 'https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/relu.jpg' >"
]
},
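{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Optional illustration (an addition).* The ReLU formula above can be written as a one-line NumPy function. This is only a sketch: `relu` is a hypothetical helper, not the Keras implementation."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"def relu(x):\n",
"    # Element-wise max(0, x): negative values become 0, positive values pass through\n",
"    return np.maximum(0, x)\n",
"\n",
"print(relu(np.array([-2.0, -0.5, 0.0, 1.5, 3.0])).tolist())  # [0.0, 0.0, 0.0, 1.5, 3.0]"
],
"execution_count": null,
"outputs": []
},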
{
"cell_type": "code",
"metadata": {
"id": "kudeX_kN8YZg"
},
"source": [
"# During training, Dropout randomly zeroes a fraction of the outputs (here 20%), i.e. disables their activation.\n",
"# Dropout helps protect the model from memorizing or \"overfitting\" the training data.\n",
"model.add(Dropout(0.2))"
],
"execution_count": null,
"outputs": []
},
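{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Optional illustration (an addition).* This is a minimal NumPy sketch of \"inverted\" dropout. It assumes the behavior described in the Keras documentation. During training, each value is zeroed with probability `rate` and the surviving values are scaled by `1 / (1 - rate)`, so the expected sum stays unchanged. At inference time the input passes through untouched. The `dropout` function and the test data are hypothetical."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"def dropout(x, rate, training, rng):\n",
"    # At inference time dropout does nothing\n",
"    if not training:\n",
"        return x\n",
"    # Keep each value with probability 1 - rate ...\n",
"    keep = rng.random(x.shape) >= rate\n",
"    # ... and scale the kept values so the expected output equals the input\n",
"    return np.where(keep, x / (1.0 - rate), 0.0)\n",
"\n",
"rng = np.random.default_rng(42)\n",
"activations = np.ones(10000)\n",
"out = dropout(activations, rate=0.2, training=True, rng=rng)\n",
"print((out == 0).mean())  # close to 0.2\n",
"print(out.mean())  # close to 1.0\n",
"print(np.array_equal(dropout(activations, 0.2, False, rng), activations))  # True"
],
"execution_count": null,
"outputs": []
},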
{
"cell_type": "markdown",
"metadata": {
"id": "PY5XiXYc8YZj"
},
"source": [
"## Adding the second hidden layer"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Tnz-mpIS8YZj"
},
"source": [
"# The second hidden layer appears identical to our first layer.\n",
"# However, instead of each of the 512 nodes receiving 784 inputs from the input image data,\n",
"# they receive 512 inputs from the outputs of the first 512-node layer.\n",
"\n",
"model.add(Dense(512))\n",
"model.add(Activation('relu'))\n",
"model.add(Dropout(0.2))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "vkeHSS0h8YZm"
},
"source": [
"## The Final Output Layer"
]
},
{
"cell_type": "code",
"metadata": {
"id": "cndhIgBK8YZm"
},
"source": [
"# The final layer of 10 neurons is fully connected to the previous 512-node layer.\n",
"# The number of nodes in the final layer of a fully connected network should equal the number of desired classes (10 in this case).\n",
"model.add(Dense(10))"
],
"execution_count": null,
"outputs": []
},
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "02Zk2VWL8YZp"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# The \"softmax\" activation represents a probability distribution over K different possible outcomes.\n",
|
|||
|
|
"# Its values are all non-negative and sum to 1.\n",
|
|||
|
|
"\n",
|
|||
|
|
"model.add(Activation('softmax'))"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": []
|
|||
|
|
},
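{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal NumPy sketch of softmax for one vector of raw scores (logits). It exponentiates each score and divides by the sum of the exponentials. Subtracting the maximum score first avoids overflow and does not change the result. The helper `softmax` and the example logits are illustrative assumptions, not the Keras implementation.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def softmax(logits):\n",
"    # Shift by the max for numerical stability; softmax is invariant to this shift\n",
"    z = logits - np.max(logits)\n",
"    e = np.exp(z)\n",
"    return e / e.sum()\n",
"\n",
"logits = np.array([2.0, 1.0, 0.1, -1.0])\n",
"probs = softmax(logits)\n",
"print(probs)        # all entries are non-negative\n",
"print(probs.sum())  # sums to 1 (up to floating-point rounding)\n",
"```\n",
"\n",
"The largest logit always receives the largest probability. That is why `np.argmax` on the softmax output gives the predicted class."
]
},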
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "YReJLqWL8YZr",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"outputId": "a16ece56-114d-41ff-9086-87549cd4a997"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# Summarize the built model\n",
|
|||
|
|
"\n",
|
|||
|
|
"model.summary()"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"Model: \"sequential\"\n",
|
|||
|
|
"_________________________________________________________________\n",
|
|||
|
|
" Layer (type) Output Shape Param # \n",
|
|||
|
|
"=================================================================\n",
|
|||
|
|
" dense (Dense) (None, 512) 401920 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation (Activation) (None, 512) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" dropout (Dropout) (None, 512) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" dense_1 (Dense) (None, 512) 262656 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_1 (Activation) (None, 512) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" dropout_1 (Dropout) (None, 512) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" dense_2 (Dense) (None, 10) 5130 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_2 (Activation) (None, 10) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
"=================================================================\n",
|
|||
|
|
"Total params: 669706 (2.55 MB)\n",
|
|||
|
|
"Trainable params: 669706 (2.55 MB)\n",
|
|||
|
|
"Non-trainable params: 0 (0.00 Byte)\n",
|
|||
|
|
"_________________________________________________________________\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
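{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can check the `Param #` column by hand. A `Dense` layer with `n_in` inputs and `n_out` nodes has one weight per input-node pair plus one bias per node, giving `n_in * n_out + n_out` parameters. `Activation` and `Dropout` layers have no parameters. The helper name `dense_params` in this small sketch is our own.\n",
"\n",
"```python\n",
"def dense_params(n_in, n_out):\n",
"    # weights (n_in x n_out) plus one bias per output node\n",
"    return n_in * n_out + n_out\n",
"\n",
"layers = [(784, 512), (512, 512), (512, 10)]\n",
"counts = [dense_params(n_in, n_out) for n_in, n_out in layers]\n",
"print(counts)       # [401920, 262656, 5130]\n",
"print(sum(counts))  # 669706, matching Total params in the summary\n",
"# 4 bytes per float32 parameter, expressed in units of 2**20 bytes\n",
"print(round(sum(counts) * 4 / 2**20, 2))  # 2.55\n",
"```"
]
},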
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "UcndsYyk8YZv"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## Compiling the model\n",
|
|||
|
|
"\n",
|
|||
|
|
"Keras runs on top of a numerical backend. In this notebook that backend is TensorFlow (older Keras versions also supported Theano). The backend lets you define a *computation graph* in Python, which is then compiled and run efficiently on the CPU or GPU without the overhead of the Python interpreter.\n",
|
|||
|
|
"\n",
|
|||
|
|
"When compiling a model, Keras asks you to specify your **loss function** and your **optimizer**. The loss function we'll use here is called *categorical cross-entropy*. It is well suited to comparing two probability distributions.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Our predictions are probability distributions across the ten different digits (e.g. \"we're 80% confident this image is a 3, 10% sure it's an 8, 5% it's a 2, etc.\"), and the target is a probability distribution with 100% for the correct category, and 0 for everything else. The cross-entropy is a measure of how different your predicted distribution is from the target distribution. [More detail at Wikipedia](https://en.wikipedia.org/wiki/Cross_entropy)\n",
|
|||
|
|
"\n",
|
|||
|
|
"The optimizer determines how the model's weights are updated using **gradient descent**. At each step, the weights move a small amount in the direction that decreases the loss. The size of that step is set by the **learning rate**."
|
|||
|
|
]
|
|||
|
|
},
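{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see the loss in numbers, here is a NumPy sketch of categorical cross-entropy for a single example. It uses the illustrative prediction from the text: 80% on a 3, 10% on an 8 and 5% on a 2. We assume the remaining 5% is spread evenly over the other seven digits so that the values sum to 1. With a one-hot target, the loss reduces to minus the log of the probability assigned to the correct class. The helper `categorical_crossentropy` is our own, not the Keras function.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def categorical_crossentropy(target, predicted, eps=1e-7):\n",
"    # Clip to avoid log(0); sum over classes of -target * log(predicted)\n",
"    predicted = np.clip(predicted, eps, 1.0)\n",
"    return -np.sum(target * np.log(predicted))\n",
"\n",
"predicted = np.full(10, 0.05 / 7)  # remaining 5% shared by the other seven digits\n",
"predicted[3], predicted[8], predicted[2] = 0.80, 0.10, 0.05\n",
"\n",
"target = np.zeros(10)\n",
"target[3] = 1.0  # one-hot: the image really is a 3\n",
"print(categorical_crossentropy(target, predicted))  # -log(0.8), about 0.223\n",
"\n",
"target_wrong = np.zeros(10)\n",
"target_wrong[8] = 1.0  # if the image were actually an 8\n",
"print(categorical_crossentropy(target_wrong, predicted))  # -log(0.1), about 2.303\n",
"```\n",
"\n",
"A confident, correct prediction gives a small loss. The same prediction scored against a different true label gives a much larger loss."
]
},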
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "Wsv1Z9LC8YZw"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"<img src = \"https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/gradient_descent.png\" >"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "vZWeBx-58YZx"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"<img src = \"https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/learning_rate.png\" >"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "ptVA6aSg8YZy"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"So are smaller learning rates better? Not quite! A very small learning rate makes training slow. The optimizer can also get stuck in a local minimum instead of finding the global minimum of the loss function. Sometimes a larger learning rate is needed to jump out of a local minimum."
|
|||
|
|
]
|
|||
|
|
},
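{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a toy sketch of the learning-rate trade-off. Instead of the real network loss, it assumes a one-parameter loss `L(w) = w**2`, with gradient `2 * w` and its minimum at `w = 0`. Plain gradient descent repeats `w <- w - learning_rate * grad(w)`. The Adam optimizer used later in this notebook adapts the step size per parameter, so this only illustrates the basic idea. The function names are hypothetical.\n",
"\n",
"```python\n",
"def gradient_descent(grad, w0, learning_rate, steps):\n",
"    # Repeatedly step against the gradient: w <- w - learning_rate * grad(w)\n",
"    w = w0\n",
"    for _ in range(steps):\n",
"        w = w - learning_rate * grad(w)\n",
"    return w\n",
"\n",
"def grad(w):\n",
"    # Gradient of the toy loss L(w) = w**2\n",
"    return 2 * w\n",
"\n",
"for lr in (0.01, 0.1, 1.1):\n",
"    print(lr, gradient_descent(grad, w0=5.0, learning_rate=lr, steps=50))\n",
"```\n",
"\n",
"With `0.01`, the parameter is still far from the minimum after 50 steps. With `0.1`, it gets very close. With `1.1`, every step overshoots further than the last, and the process diverges."
]
},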
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "zSGiajLT8YZ1"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"<img src = 'https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/complicated_loss_function.png' >"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "7RUmQxfp8YZ2"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# Let's use the Adam optimizer for learning\n",
|
|||
|
|
"adam = tf.optimizers.Adam(learning_rate=0.001)\n",
|
|||
|
|
"model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "d-u8EqW48YZ3"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## Train the model!\n",
|
|||
|
|
"This is the fun part!"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "rTynqX-H8YZ4"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"The batch size determines how many training examples are used in each step to compute the loss function, its gradients (via backpropagation) and the resulting weight update. Larger batch sizes let the network finish each epoch faster; however, there are other factors beyond training speed to consider.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Too large a batch size smooths out the noise in the loss function. The optimizer can then settle into a local minimum as though it had found the global minimum.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Too small a batch size makes the loss function very noisy from step to step, and the optimizer may never find the global minimum.\n",
|
|||
|
|
"\n",
|
|||
|
|
"So a good batch size may take some trial and error to find!"
|
|||
|
|
]
|
|||
|
|
},
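{
"cell_type": "markdown",
"metadata": {},
"source": [
"Keras performs one optimizer step per batch. The number of steps per epoch is therefore the training-set size divided by the batch size, rounded up. Here is a quick check, assuming the 60,000 MNIST training images. It explains the `469/469` in the training log of the next code cell. The shuffling sketch at the end is only an illustration of how an epoch can be split into mini-batches.\n",
"\n",
"```python\n",
"import math\n",
"import numpy as np\n",
"\n",
"n_train = 60000  # number of MNIST training images\n",
"\n",
"for batch_size in (32, 128, 10000):\n",
"    # One optimizer step per batch; the last batch may be smaller\n",
"    steps_per_epoch = math.ceil(n_train / batch_size)\n",
"    print(batch_size, steps_per_epoch)\n",
"\n",
"# Sketch of one epoch split into shuffled mini-batches of 128 images\n",
"rng = np.random.default_rng(0)\n",
"indices = rng.permutation(n_train)\n",
"batches = [indices[i:i + 128] for i in range(0, n_train, 128)]\n",
"print(len(batches), len(batches[-1]))  # 469 batches, the last holding 96 images\n",
"```"
]
},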
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "MxhTSlBI8YZ5",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"outputId": "7b40631d-07d2-4eea-e8b7-ded474fb8322"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"model.fit(X_train, Y_train,\n",
|
|||
|
|
" batch_size=128, epochs=5,\n",
|
|||
|
|
" verbose=1)"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"Epoch 1/5\n",
|
|||
|
|
"469/469 [==============================] - 4s 3ms/step - loss: 0.2509 - accuracy: 0.9247\n",
|
|||
|
|
"Epoch 2/5\n",
|
|||
|
|
"469/469 [==============================] - 2s 3ms/step - loss: 0.1032 - accuracy: 0.9679\n",
|
|||
|
|
"Epoch 3/5\n",
|
|||
|
|
"469/469 [==============================] - 2s 3ms/step - loss: 0.0709 - accuracy: 0.9773\n",
|
|||
|
|
"Epoch 4/5\n",
|
|||
|
|
"469/469 [==============================] - 2s 4ms/step - loss: 0.0559 - accuracy: 0.9818\n",
|
|||
|
|
"Epoch 5/5\n",
|
|||
|
|
"469/469 [==============================] - 2s 4ms/step - loss: 0.0451 - accuracy: 0.9855\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "execute_result",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<keras.src.callbacks.History at 0x7ac89b4ee020>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"execution_count": 19
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "YufwyCC38YZ6"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"The two numbers reported for each epoch are, in order, the network's loss on the training set and its overall accuracy on the training data. But how does it do on data it did not train on?"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "wQORnbe08YZ7"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## Evaluate Model's Accuracy on Test Data"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "SX3n6p108YZ7",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"outputId": "0b1db3cb-ed8f-4422-fcc7-1b0394667dd1"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
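"# evaluate() returns [loss, accuracy]; the 'Test score' printed below is the loss on the test set\n",
"score = model.evaluate(X_test, Y_test)\n",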
|
|||
|
|
"print('Test score:', score[0])\n",
|
|||
|
|
"print('Test accuracy:', score[1])"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"313/313 [==============================] - 1s 2ms/step - loss: 0.0772 - accuracy: 0.9774\n",
|
|||
|
|
"Test score: 0.07723397016525269\n",
|
|||
|
|
"Test accuracy: 0.977400004863739\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "NYD0_NjusHRL"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"def plot_confusion_matrix(cm, classes,\n",
|
|||
|
|
" normalize=False,\n",
|
|||
|
|
" title='Confusion matrix',\n",
|
|||
|
|
" cmap=plt.cm.Blues):\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" This function prints and plots the confusion matrix.\n",
|
|||
|
|
" Normalization can be applied by setting `normalize=True`.\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" if normalize:\n",
|
|||
|
|
" cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
|
|||
|
|
" print(\"Normalized confusion matrix\")\n",
|
|||
|
|
" else:\n",
|
|||
|
|
" print('Confusion matrix, without normalization')\n",
|
|||
|
|
"\n",
|
|||
|
|
" print(cm)\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
" plt.imshow(cm, interpolation='nearest', cmap=cmap)\n",
|
|||
|
|
" plt.title(title)\n",
|
|||
|
|
" plt.colorbar()\n",
|
|||
|
|
" tick_marks = np.arange(len(classes))\n",
|
|||
|
|
" plt.xticks(tick_marks, classes, rotation=45)\n",
|
|||
|
|
" plt.yticks(tick_marks, classes)\n",
|
|||
|
|
"\n",
|
|||
|
|
" fmt = '.2f' if normalize else 'd'\n",
|
|||
|
|
" thresh = cm.max() / 2.\n",
|
|||
|
|
" for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n",
|
|||
|
|
" plt.text(j, i, format(cm[i, j], fmt),\n",
|
|||
|
|
" horizontalalignment=\"center\",\n",
|
|||
|
|
" color=\"white\" if cm[i, j] > thresh else \"black\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" plt.ylabel('True label')\n",
|
|||
|
|
" plt.xlabel('Predicted label')\n",
|
|||
|
|
" plt.tight_layout()"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "yWaT13H78YZ9"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"### Inspecting the output\n",
|
|||
|
|
"\n",
|
|||
|
|
"It's always a good idea to inspect the output and make sure everything looks sane. Here we'll look at some examples it gets right, and some examples it gets wrong."
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "LFECu0nE8YZ-",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 1000
|
|||
|
|
},
|
|||
|
|
"outputId": "869fc07a-0198-4d53-9257-88ac9b1ab938"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# model.predict outputs the class probabilities for each test example;\n",
"# np.argmax then picks the highest-probability class (predict_classes is no longer available in recent Keras versions).\n",
|
|||
|
|
"predict = model.predict(X_test)\n",
|
|||
|
|
"predicted_classes = np.argmax(predict,axis=1)\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Check which items we got right / wrong\n",
|
|||
|
|
"correct_indices = np.nonzero(predicted_classes == y_test)[0]\n",
|
|||
|
|
"incorrect_indices = np.nonzero(predicted_classes != y_test)[0]\n",
|
|||
|
|
"cnf_matrix = confusion_matrix(y_test, predicted_classes)\n",
|
|||
|
|
"\n",
|
|||
|
|
"class_names = [str(i) for i in range(10)]\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Plot non-normalized confusion matrix\n",
|
|||
|
|
"plt.figure()\n",
|
|||
|
|
"plot_confusion_matrix(cnf_matrix, classes=class_names,\n",
|
|||
|
|
" title='Confusion matrix, without normalization')\n",
|
|||
|
|
"\n",
|
|||
|
|
"plt.show()\n"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"313/313 [==============================] - 1s 2ms/step\n",
|
|||
|
|
"Confusion matrix, without normalization\n",
|
|||
|
|
"[[ 968 0 1 1 0 1 3 0 2 4]\n",
|
|||
|
|
" [ 0 1127 2 1 0 0 3 0 2 0]\n",
|
|||
|
|
" [ 1 4 1009 5 1 0 2 4 5 1]\n",
|
|||
|
|
" [ 0 0 5 996 0 3 0 1 1 4]\n",
|
|||
|
|
" [ 0 0 0 0 958 0 9 1 2 12]\n",
|
|||
|
|
" [ 2 0 0 15 1 863 1 1 4 5]\n",
|
|||
|
|
" [ 4 2 1 1 3 4 942 0 1 0]\n",
|
|||
|
|
" [ 0 13 11 2 1 0 0 982 4 15]\n",
|
|||
|
|
" [ 0 1 2 10 4 3 0 3 938 13]\n",
|
|||
|
|
" [ 0 4 0 3 5 2 1 1 2 991]]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 2 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1UAAAN6CAYAAABmBWMlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAC8H0lEQVR4nOzdd1yV9f/G8euAMgRZKiBO1HKkaM5w5ExyW2ZZWmhmy5GZZjZMzZz9yjRnQ600V5kj90Rzz0zNbZoKOEFwoHB+f/D16ElN8Abuc/D17HE/8tznPofrvs99DufNZ9wWq9VqFQAAAADgvriYHQAAAAAAnBlFFQAAAAAYQFEFAAAAAAZQVAEAAACAARRVAAAAAGAARRUAAAAAGEBRBQAAAAAGUFQBAAAAgAE5zA4AAAAA4P5cuXJFSUlJZsdIEzc3N3l4eJgdI1NQVAEAAABO6MqVK/LMnUe6fsnsKGkSHBysI0eOZMvCiqIKAAAAcEJJSUnS9UtyLxMpubqZHee/JScpes9kJSUlUVQBAAAAcDCubrI4eFFlNTtAJqOoAgAAAJyZxSV1cWSOns+g7L13AAAAAJDJKKoAAAAAwACKKgAAAAAwgDFVAAAAgDOzSLJYzE7x3xw8nlG0VAEAAACAARRVAAAAAGAA3f8AAAAAZ8aU6qbL3nsHAAAAAJmMogoAAAAADKD7HwAAAODMLBYnmP3PwfMZREsVAAAAABhAUQUAAAAABlBUAQAAAIABjKkCAAAAnBlTqpsue+8dAAAAAGQyiioAAAAAMIDufwAAAIAzY0p109FSBQAAAAAGUFQBAAAAgAEUVQAAAABgAGOqAAAAAKfmBFOqZ/O2nOy9dwAAAACQySiqAAAAAMAAuv8BAAAAzowp1U1HSxUAAAAAGEBRBQAAAAAG0P0PAAAAcGYWJ5j9z9HzGZS99w4AAAAAMhlFFQAAAAAYQFEFAAAAAAYwpgoAAABwZkypbjpaqgAAAADAAIoqAAAAADCA7n8AAACAM2NKddNl770DAAAAgExGUQUAAAAABtD9DwAAAHBmzP5nOlqqAAAAAMAAiioAAAAAMICiCgAAAAAMYEwVAAAA4MyYUt102XvvAAAAACCTUVQBAAAAgAF0/wMAAACcmcXi+N3rmFIdAAAAAHA3FFUAAAAAYADd/wAAAABn5mJJXRyZo+cziJYqAAAAADCAogoAAAAADKCoAgAAAAADGFMFAAAAODOLixNMqe7g+QzK3nsHAAAAAJmMogoAAAAADKD7HwAAAODMLJbUxZE5ej6DaKkCAAAAAAMoqgAAAADAAIoqAAAAADCAMVUAAACAM2NKddNl770DAAAAgExGUQUAAAAABtD9DwAAAHBmTKluOlqqAAAAAMAAiioAAAAAMIDufwAAAIAzY/Y/02XvvQMAAACATEZRBQAAAAAGUFQBAAAAgAGMqQIAAACcGVOqm46WKgAAAAAwgKIKAAAAAAyg+x8AAADgzJhS3XTZe+8AAAAAIJNRVAHZ2IEDB9SwYUP5+vrKYrHo119/zdDnP3r0qCwWiyZNmpShz5sdFC1aVO3btzc7xm3S85rd2Pazzz7L/GC4o379+snyr8HdZp1bjnpOA4AjoKgCMtmhQ4f02muvqVixYvLw8JCPj49q1KihL7/8UpcvX87Unx0ZGaldu3bp008/1Q8//KDKlStn6s/Ljvbs2aN+/frp6NGjZkfJNAsWLFC/fv3MjnGbQYMGZfgfAvDf1q1bp379+unChQtmRwGQHjdm/3P0JRtjTBWQiX777Te1bt1a7u7ueumll1S2bFklJSVp7dq16tWrl3bv3q0JEyZkys++fPmy1q9frw8++EBdunTJlJ9RpEgRXb58WTlz5syU53cEe/bsUf/+/VWnTh0VLVo0zY/bt2+fXFwc7+9Wd3rNFixYoNGjRztcYTVo0CA988wzatmypdlRHEpmnlvr1q1T//791b59e/n5+WXZzwUAZ0
dRBWSSI0eOqE2bNipSpIhWrFih/Pnz2+7r3LmzDh48qN9++y3Tfv7p06cl6bYvRhnJYrHIw8Mj057f2VitVl25ckWenp5yd3c3O84d8ZoZk5iYKC8vL1MzmHVuOeo5DQCOgD85AZlk2LBhSkhI0LfffmtXUN1QokQJvfXWW7bb169f1yeffKLixYvL3d1dRYsW1fvvv6+rV6/aPa5o0aJq2rSp1q5dq6pVq8rDw0PFihXT999/b9umX79+KlKkiCSpV69eslgstlaW9u3b37HF5U5jN5YuXaqaNWvKz89P3t7eKlmypN5//33b/Xcbn7NixQrVqlVLXl5e8vPzU4sWLbR37947/ryDBw/a/iru6+urDh066NKlS3c/sP9Tp04dlS1bVn/88Ydq166tXLlyqUSJEpo1a5YkafXq1apWrZo8PT1VsmRJLVu2zO7xf//9t958802VLFlSnp6eypMnj1q3bm3XzW/SpElq3bq1JKlu3bqyWCyyWCxatWqVpJuvxeLFi1W5cmV5enpq/PjxtvtujD+xWq2qW7eu8uXLp9jYWNvzJyUlqVy5cipevLgSExPvuc+36tGjh/LkySOr1Wpb17VrV1ksFo0cOdK2LiYmRhaLRWPHjpV0+2vWvn17jR49WpJs+/fv80CSJkyYYDs3q1Spos2bN9+2TVpe97SefxaLRYmJiZo8ebIt03+N51m1apUsFotmzJihTz/9VAULFpSHh4fq16+vgwcP3rb9zJkzValSJXl6eipv3rxq166dTpw4cVtWb29vHTp0SI0bN1bu3LnVtm1bW74uXbpo5syZKlOmjDw9PRUeHq5du3ZJksaPH68SJUrIw8NDderUua376Jo1a9S6dWsVLlxY7u7uKlSokN5+++00dQn+99imW1+3fy83fu4ff/yh9u3b27ohBwcH6+WXX9bZs2ftXoNevXpJkkJDQ297jjuNqTp8+LBat26tgIAA5cqVS4899thtfyxK72sDAM6Iliogk8ybN0/FihVT9erV07T9K6+8osmTJ+uZZ57RO++8o40bN2rw4MHau3evZs+ebbftwYMH9cwzz6hjx46KjIzUd999p/bt26tSpUp65JFH9PTTT8vPz09vv/22nn/+eTVu3Fje3t7pyr979241bdpUYWFhGjBggNzd3XXw4EH9/vvv//m4ZcuWqVGjRipWrJj69euny5cva9SoUapRo4a2bdt22xfqZ599VqGhoRo8eLC2bdumb775RoGBgRo6dOg9M54/f15NmzZVmzZt1Lp1a40dO1Zt2rTRlClT1L17d73++ut64YUXNHz4cD3zzDM6fvy4cufOLUnavHmz1q1bpzZt2qhgwYI6evSoxo4dqzp16mjPnj3KlSuXHn/8cXXr1k0jR47U+++/r9KlS0uS7f9Sapeo559/Xq+99po6deqkkiVL3pbTYrHou+++U1hYmF5//XX98ssvkqSPP/5Yu3fv1qpVq9Ld+lGrVi198cUX2r17t8qWLSsp9Yu6i4uL1qxZo27dutnWSdLjjz9+x+d57bXXdPLkSS1dulQ//PDDHbeZOnWqLl68qNdee00Wi0XDhg3T008/rcOHD9u6Eab3db+XH374Qa+88oqqVq2qV199VZJUvHjxez5uyJAhcnFxUc+ePRUXF6dhw4apbdu22rhxo22bSZMmqUOHDqpSpYoGDx6smJgYffnll/r999+1fft2u9bd69evKyIiQjVr1tRnn32mXLly2e5bs2aN5s6dq86dO0uSBg8erKZNm+rdd9/VmDFj9Oabb+r8+fMaNmyYXn75Za1YscL22JkzZ+rSpUt64403lCdPHm3atEmjRo3SP//8o5kzZ6b7WP3bhx9+qNjYWNv7funSpTp8+LA6dOig4OBgW9fj3bt3a8OGDbJYLHr66ae1f/9+/fTTT/riiy+UN29eSVK+fPnu+HNjYmJUvXp1Xbp0Sd26dVOePHk0efJkNW/eXLNmzdJTTz2V7tcGwP
1yginVs3tbjhVAhouLi7NKsrZo0SJN2+/YscMqyfrKK6/Yre/Zs6dVknXFihW2dUWKFLFKskZFRdnWxcbGW
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
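{
"cell_type": "markdown",
"metadata": {},
"source": [
"The confusion matrix in the output above can be summarised with a few NumPy one-liners. The sketch below copies its values by hand, so a re-run that produces different results will not match. The diagonal holds the correct predictions, so its sum divided by the total gives the test accuracy. Dividing each diagonal entry by its row sum gives the per-class recall.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Values copied from the printed confusion matrix (rows: true label, columns: predicted label)\n",
"cm = np.array([\n",
"    [968, 0, 1, 1, 0, 1, 3, 0, 2, 4],\n",
"    [0, 1127, 2, 1, 0, 0, 3, 0, 2, 0],\n",
"    [1, 4, 1009, 5, 1, 0, 2, 4, 5, 1],\n",
"    [0, 0, 5, 996, 0, 3, 0, 1, 1, 4],\n",
"    [0, 0, 0, 0, 958, 0, 9, 1, 2, 12],\n",
"    [2, 0, 0, 15, 1, 863, 1, 1, 4, 5],\n",
"    [4, 2, 1, 1, 3, 4, 942, 0, 1, 0],\n",
"    [0, 13, 11, 2, 1, 0, 0, 982, 4, 15],\n",
"    [0, 1, 2, 10, 4, 3, 0, 3, 938, 13],\n",
"    [0, 4, 0, 3, 5, 2, 1, 1, 2, 991],\n",
"])\n",
"\n",
"# Correct predictions lie on the diagonal\n",
"accuracy = np.trace(cm) / cm.sum()\n",
"# Recall per digit: correct predictions divided by the number of true examples of that digit\n",
"recall = np.diag(cm) / cm.sum(axis=1)\n",
"print(accuracy)             # 0.9774, the test accuracy reported by model.evaluate\n",
"print(np.round(recall, 3))\n",
"print(np.argmin(recall))    # 7: the digit recognized least reliably in this run\n",
"\n",
"# Most frequent confusion: zero the diagonal and find the largest remaining entry\n",
"off_diag = cm.copy()\n",
"np.fill_diagonal(off_diag, 0)\n",
"true_label, pred_label = np.unravel_index(np.argmax(off_diag), off_diag.shape)\n",
"print(true_label, pred_label)  # 5 3: true 5s predicted as 3 (tied with 7s predicted as 9)\n",
"```"
]
},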
|
|||
|
|
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "Ct14A0rsD-Ja"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"**Correct predictions**"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "ew6oWSoY8YaA",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 906
|
|||
|
|
},
|
|||
|
|
"outputId": "5dc0a7b5-0c31-48c8-acc3-0f0824f899f9"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"def show_samples(indices, preds, images, labels, count=3, names=()):\n",
|
|||
|
|
" plt.figure()\n",
|
|||
|
|
" for i, sample in enumerate(indices[:count**2]):\n",
|
|||
|
|
" pred_id = int(np.argmax(preds[sample]))\n",
|
|||
|
|
" real_id = int(labels[sample])\n",
|
|||
|
|
" pred_score = preds[sample][pred_id]\n",
|
|||
|
|
" real_score = preds[sample][real_id]\n",
|
|||
|
|
" plt.subplot(count,count,i+1)\n",
|
|||
|
|
" plt.imshow(images[sample].reshape(28,28), cmap='gray', interpolation='none')\n",
|
|||
|
|
" plt.axis('off')\n",
|
|||
|
|
" if len(names) > 0:\n",
|
|||
|
|
" plt.title(\"P: {} ({:.2f})\\nE: {} ({:.2f})\".format(names[pred_id], pred_score, names[real_id], real_score))\n",
|
|||
|
|
" else:\n",
|
|||
|
|
" plt.title(\"P: {} ({:.2f})\\nE: {} ({:.2f})\".format(pred_id, pred_score, real_id, real_score))\n",
|
|||
|
|
"\n",
|
|||
|
|
" plt.tight_layout()\n",
|
|||
|
|
"\n",
|
|||
|
|
"show_samples(correct_indices, predict, X_test, y_test, 5)"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 25 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA04AAAN5CAYAAAA/32uUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAADzNUlEQVR4nOzdd3hUVfrA8XdILwIhdAih944g1dBFmiIBUVBWYVkBGxaK4qIUV2AFy1KkCIrggiiCBRQQBURBmogQpQRCh1BCCZCQnN8fbubHuXcydyaZyWTg+3mePD7vmXPPvJE3N/dk7rnHppRSAgAAAADIVgFfJwAAAAAA+R0TJwAAAACwwMQJAAAAACwwcQIAAAAAC0ycAAAAAMACEycAAAAAsMDECQAAAAAsMHECAAAAAAtMnAAAAADAAhOn/xkyZIh06NDB12k4tWrVKomMjJQzZ874OhXkQ5MmTZLq1atLZmamr1PJ1p49eyQwMFB2797t61SQD1HD8HdcS8DfUcMWVB6aN2+eEhH7V0hIiKpSpYoaOnSoOnnyZI7GHDNmjDam8Wvjxo2WYxw8eFAFBQWp7777TmufPn26io+PVzExMUpEVP/+/d3KLSMjQ02cOFGVL19ehYSEqDp16qhFixY57Ltnzx51zz33qIiICBUVFaX69eunTp8+bepXr149NWzYMLfygOd4o4b37t2rXnzxRVWvXj0VGRmpSpYsqTp37qx++eUXl8dISUlRRYoUUe+//77W/t///lf17dtXVa5cWYmIiouLczu/OXPmqOrVq6uQkBBVuXJl9c477zjsd/ToUdWrVy9VqFAhdccdd6ju3burAwcOmPp1795d9ejRw+084BneqGGllBo/frzq1q2bKl68uBIRNWbMGLeOp4bhKm/VsDu/sx3hWgKu8lYN3+yjjz5SIqIiIiJcPoYatuaTidPYsWPVggUL1OzZs1X//v1VgQIFVIUKFdSVK1fcHvPXX39VCxYsMH3FxMSoqKgodf36dcsxnnnmGVW1alVTe2xsrCpSpIjq1KmTCgwMdLtQRo4cqURE/f3vf1ezZs1SXbp0USKiPv74Y63fkSNHVNGiRVWlSpXU22+/rSZMmKCioqJUvXr1TPlPnz5dhYeHq4sXL7qVCzzDGzX8/PPPq8KFC6sBAwao9957T02aNElVqlRJBQQEqNWrV7s0xtSpU1XBggXV1atXtfa4uDgVGRmp2rRpo6Kioty+6Jw5c6YSEdWzZ081a9Ys9cgjjygRUW+88YbW79KlS6pKlSqqePHiauLEiWrKlCkqJiZGlS1bViUnJ2t9v/76ayUiav/+/W7lAs/wRg0rpZSIqJIlS6p77rknRxMnahiu8lYNu/o7OztcS8BV3qrhLJcuXVKlS5dWERERbk2cqGFrPpk4Gf+S/txzzykRcesvO84kJSUpm82m/v73v1v2TUtLU0WLFlWjR482vXbo0CGVmZmplFIqIiLCrUI5evSoCgoKUkOHDrW3ZWZmqlatWqmyZcuqGzdu2NsHDx6swsLC1OHDh+1tq1evViKi3nvvPW3cU6dOqYCAADV37lyXc4HneKOGt27dqi5duqS1JScnq2LFiqkWLVq4NEbdunVVv379TO1JSUkqIyNDKaVUrVq13LroTE1NVdHR0apLly5ae9++fVVERIQ6d+6cvW3ixIlKRNSWLVvsbXv37lUBAQFq1KhR2vFpaWkqKipKvfLKKy7nAs/x1nk4MTFRKaXUmTNncjRxoobhKm/UsDu/sx3hWgLu8Pb18IgRI1S1atXs5zpXUMOuyRdrnNq2bSsiIomJifa2AwcOyIEDB3I03scffyxKKenbt69l340bN0pycrK0b9/e9FpsbKzYbLYc5bB8+XJJT0+XIUOG2NtsNpsMHjxYjh49Kj/99JO9/dNPP5WuXbtKuXLl7G3t27eXqlWrypIlS7RxixcvLnXr1pXly5fnKC94R25quFGjRhIZGa
m1RUdHS6tWrWTv3r2WxycmJsquXbsc1nBMTIwUKJCzH/N169bJ2bNntRoWERk6dKhcuXJFvvrqK3vb0qVLpXHjxtK4cWN7W/Xq1aVdu3amGg4KCpLWrVtTw/lMbs/D5cuXz/F7U8PwhNzUsDu/sx3hWgKe4Inr4X379snUqVNlypQpEhgY6PJx1LBr8sXEKasgoqOj7W3t2rWTdu3a5Wi8hQsXSkxMjNx9992WfTdt2iQ2m00aNGiQo/fKzo4dOyQiIkJq1KihtTdp0sT+uojIsWPH5PTp03LnnXeaxmjSpIm9380aNWokmzZt8mi+yB1P17CIyMmTJ6Vo0aKW/bJqoWHDhjl+L0eyas9Ym40aNZICBQrYX8/MzJRdu3ZlW8MHDhyQS5cumcbYvXu3XLx40aM5I+e8UcOuoobhCbmpYVd/Z2eHawl4gifOw88++6y0adNGOnfu7NZ7U8Ou8cnEKSUlRZKTk+Xo0aOyePFiGTt2rISFhUnXrl1zPfbvv/8uu3btkoceesil2XFCQoIUKVJEChYsmOv3vtmJEyekRIkSphxKlSolIiLHjx+397u53dj33Llzcv36da29YsWKkpycLKdPn/ZoznCdN2tYRGTDhg3y008/yYMPPmjZNyEhQUREKlSo4JH3znLixAkJCAiQ4sWLa+3BwcESHR1tr+GsGs2uhkX+v96zVKxYUTIzM+25I+95u4bdQQ0jJzxZw67+zs4O1xLICU+fh7/66iv59ttvZcqUKW4fSw27xvXP8DzI+DFgbGysLFy4UMqUKWNvO3ToUI7GXrhwoYiIS7fpiYicPXtWoqKicvRezly9elVCQkJM7aGhofbXb/6vVd+bX8/KNzk52XRBgLzhzRo+ffq0PPzww1KhQgUZPny4Zf+zZ89KYGCg6Xa/3Lp69aoEBwc7fC00NNTtGr7ZzTUM3/BmDbuLGkZOeLKGXf2dnR2uJZATnqzhtLQ0GTZsmDzxxBNSs2ZNt3Ohhl3jk4nTtGnTpGrVqhIYGCglSpSQatWq5fge9psppWTRokVSu3ZtqVu3rlvHeVpYWJhpZiwicu3aNfvrN//Xlb5ZsvLN6f2myD1v1fCVK1eka9eucunSJdm4caPHLyTdERYWJmlpaQ5fu3btGjXs57xVw/kJNXxr82QNu/o72xmuJeAuT9bw1KlTJTk5WV577bUc50MNW/PJxKlJkyYO72HMrR9//FEOHz4s//rXv1w+Jjo6Ws6fP+/xXEqVKiXr1q0TpZT2D5r1UWTp0qXt/W5uv9mJEyekSJEiptl3Vr6urH+Bd3ijhtPS0uSBBx6QXbt2yTfffCO1a9d26bjo6Gi5ceOGXLp0Se644w6P5VOqVCnJyMiQ06dPa3/JSUtLk7Nnz9prOKtGs6thkf+v9yzUsO956zycE9QwcsKTNezq7+zscC2BnPBUDaekpMj48eNlyJAhcvHiRfvay8uXL4tSSg4dOiTh4eFOP5Whhl1zS/15ceHChWKz2eThhx92+Zjq1avL+fPnJSUlxaO51K9fX1JTU01PRdu8ebP9dRGRMmXKSLFixWTr1q2mMbZs2WLvd7PExEQpWrSoFCtWzKM5w3cyMzPl0UcflbVr18qiRYskLi7O5WOrV68uIvpTeDwhq/aMtbl161bJzMy0v16gQAGpU6eOwxrevHmzVKxY0XQxnJiYKAUKFJCqVat6NGf4J2oYvubq7+zscC0BXzp//rxcvnxZJk2aJBUqVLB/ffrpp5KamioVKlSQQYMGOR2DGnZNvp04ufv4xfT0dPnkk0+kZcuW2mMMrTRr1kyUUrJt27acpCkif830ExIStGK77777JCgoSKZPn25vU0rJzJkzpUyZMtK8eXN7e8+ePeXLL7+UI0eO2NvWrl0rf/75p/Tq1cv0ftu2bZNmzZrlOF/kDXdq+KmnnpLFixfL9OnT5YEHHnDrfbJqwdHJxlWpqamSkJCgrddo27atFClSRGbMmKH1nTFjhoSHh0
uXLl3sbfHx8fLLL79oOfzxxx/y3XffZVvDtWrVkkKFCuU4Z3hfbraFcAc1DG9xtYbd+Z3tCNcS8BZXarh48
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "fhZp92ZxEGn4"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"**Wrong predictions**"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "FrAIwG1CEFxe",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 906
|
|||
|
|
},
|
|||
|
|
"outputId": "a022759d-a12d-4fa7-d021-26127237c618"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"show_samples(incorrect_indices, predict, X_test, y_test, 5)"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 25 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA04AAAN5CAYAAAA/32uUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeXxMV/8H8E/2VchCxBaxJHYiqlTt+1JaQqu02lpqqa0L1VJPEUo9FLXro61SWsuDtrSotrFUCYqS2vc1QcgiIfn+/ugv8zj3TuZOmDGJfN6vV16vfs+cc+4Z/ebmnpl77nESEQERERERERHlyNnRAyAiIiIiIsrrOHEiIiIiIiIywIkTERERERGRAU6ciIiIiIiIDHDiREREREREZIATJyIiIiIiIgOcOBERERERERngxImIiIiIiMgAJ05EREREREQGOHH6fwMHDkTLli0dPQyLNm7cCF9fX1y7ds3RQ6E8KD/k8Lx581CmTBmkp6c7eiiUB+WHHOZ5mCxhDlN+lx9y2KHXEvIILV68WACYfjw8PKRixYoyaNAguXz58gP3e+zYMenSpYsUKVJEvLy8pEGDBvLzzz9b3f7kyZPi5uZmts2iRYukUqVK4uHhIRUqVJCZM2fmamxxcXHyzDPPiL+/v3h5eUnVqlVlxowZSp0ff/xRXnvtNalatao4OztLaGhojv3VrFlThg8fnqsxkO3YK4cvXrwoffv2lbJly4qnp6eUK1dOhg8fLgkJCVa1d3QOZ2Zmyty5c6VmzZri4+MjxYoVkzZt2sj27duVemlpaRIcHKxrT49OQTsPHz16VJ5//nkpWbKkeHl5SUREhHz44YeSkpKi1IuJiZEnn3xSgoKCTMcZOnSoXL16Vdcnz8OOZa8cFhE5fvy4dO/eXYoWLSqenp5SoUIFee+996xq6+jzMHM4/7BHDo8dO1bpU/uzbds2wz7skcNbt27NcUw7d+5U6jZu3NhsvdatWyv1HHkt4Wr3mZkZ48aNQ1hYGO7cuYNt27Zh7ty5+OGHH3Do0CF4e3vnqq9z586hfv36cHFxwTvvvAMfHx8sXrwYrVq1wpYtW9CoUSPDPmbMmIGwsDA0bdpUKZ8/fz769++PLl264M0330RsbCyGDBmC1NRUjBw50rDfn376Cc888wwiIyMxZswY+Pr64sSJEzh//rxSb9myZVixYgVq166NEiVKWOzz9ddfx9tvv40PP/wQhQoVMhwD2Yctczg5ORn169dHSkoKBg4ciNKlS+PPP//Ep59+iq1btyIuLg7Ozpa/HHZ0Dr/zzjuYNm0aevbsiYEDB+LmzZuYP38+GjdujO3bt6Nu3boAAE9PT/Tq1QvTpk3D4MGD4eTklKt/K7KdgnAePnfuHOrWrYvChQvjjTfeQEBAAHbu3ImxY8ciLi4Oa9euNdWNi4tDrVq18MILL6BQoUI4cuQIFi5ciO+//x779++Hj4+PqS7Pw3mDLXMYAPbv348mTZqgZMmSeOuttxAYGIizZ8/i3LlzVrV39HmYOZz/2DKHO3fujAoVKujK33vvPSQnJ+OJJ54w7MNeOQwAQ4YM0Y3B3HhLlSqFSZMmKWXaa2OHXks8ylla9gx79+7dSvmbb74pAGTZsmW57nPgwIHi6uoq8fHxprKUlBQpXbq01K5d27B9RkaGBAUFyejRo5Xy1NRUCQwMlPbt2yvlPXr0EB8fH7l+/brFfpOSkiQ4OFiee+45yczMtFj3woULkpGRISIi7du3t/iN05UrV8TFxUU+++wzi32Sfdgjh5cuXSoA5LvvvlPKP/jgAwEge/futdje0Tl89+5d8fLykujoaKX85MmTAkCGDBmilO/Zs0cAyJYtWywen+yjIJ2HY2JiBIAcOnRIKX/55ZcFgGH7lStXCgD5+uuvlXKehx3LHjmcmZkp1apVkyeffFJSU1Nz3d7R5+GcMIfzJnvksDlnz54VJycn6du3r2Fde+Vw9jdO33
77reEYGjduLFWrVjWsJ+K4a4k8scapWbNmAIBTp06Zyk6cOIETJ04Yto2NjUVkZCQiIiJMZd7e3ujYsSP27t2LY8eOWWy/bds2JCQkoEWLFkr51q1bkZiYiIEDByrlgwYNQkpKCr7//nuL/S5btgxXrlxBTEwMnJ2dkZKSgqysLLN1S5QoATc3N4v9ZStWrBhq1KihfFJKjvcwOXzr1i0AQHBwsFIeEhICAPDy8rLY3tE5fPfuXaSlpenGX6xYMTg7O+vGHxUVhYCAAOZwHvM4noct/W45OzvD3d3dYvuyZcsCAG7evKmU8zycNz1MDv/00084dOgQxo4dCy8vL6SmpiIzM9PqYzv6PJwT5nD+8jA5bM7XX38NEUGPHj0M69orh+93+/Zt3Lt3z7DevXv3kJycbLGOo64l8sTEKTshAgMDTWXNmzdH8+bNDdump6ebvbDM/oozLi7OYvsdO3bAyckJkZGRSvm+ffsAAHXq1FHKo6Ki4OzsbHo9J5s3b4afnx8uXLiAiIgI+Pr6ws/PDwMGDMCdO3cM35clUVFR2LFjx0P1Qbb1MDncqFEjODs7Y+jQofj9999x/vx5/PDDD4iJicGzzz6LSpUqWWzv6Bz28vLCk08+ic8//xxLly7F2bNnceDAAbzyyivw9/dHv379dH3Xrl0b27dvt3h8erQex/NwkyZNAAC9e/fG/v37ce7cOaxYsQJz587FkCFDlFuXAEBEkJCQgMuXL5tuRXFxcTH1ox0Dz8N5y8Pk8ObNmwEAHh4eqFOnDnx8fODt7Y0XXngB169fN2zv6PNwNuZw/vYwOWzO0qVLUbp0aatul7ZXDmd79dVX4efnB09PTzRt2hR79uwxW+/o0aPw8fFBoUKFULx4cYwZMwZ37941W9cR1xIOWeOUlJSEhIQE3LlzB9u3b8e4cePg5eWFDh065LqviIgIxMbG4vbt28p9utu2bQMAXLhwwWL7+Ph4BAQEwM/PTym/dOkSXFxcUKxYMaXc3d0dgYGBuHjxosV+jx07hnv37qFTp07o3bs3Jk2ahF9++QWzZs3CzZs38fXXX+fmbSrKlSuHhIQEXL16VTc+ejRsmcNVqlTBggUL8Pbbb6N+/fqm8l69emHRokWG7fNCDn/11Vd4/vnn0bNnT1NZuXLlsH37dpQrV07Xd7ly5bBkyRLD90b2UxDOw23atMH48eMxceJErFu3zlT+/vvvY8KECbr6V65cMX3TC/xzr/2yZcvMfnjB87Dj2TKHs78V7datG9q0aYNRo0bhzz//xKRJk3Du3Dls27bN4jqKvHAeBpjD+Y0tc1jrr7/+woEDBzBixAir1gDZK4fd3d3RpUsXtGvXDkFBQTh8+DCmTp2Khg0bYseOHcpErXz58mjatCmqV6+OlJQUrFy5EhMmTMDRo0exYsUKXd+OuJZwyMRJ+zVgaGgoli5dipIlS5rKTp8+bVVfAwYMwPr16/H8888jJiYGPj4+mDNnjmkmm5aWZrF9YmIi/P39deVpaWk53sbh6elp2G9ycjJSU1PRv39/zJw5E8A/C/cyMjIwf/58jBs3DhUrVrTmLepkjzchIYEnOwexZQ4DQMmSJVG3bl20a9cOoaGhiI2NxcyZMxEUFISpU6dabJsXcrhQoUKoWrUq6tevj+bNm+Py5cv46KOP8OyzzyI2NhZBQUFK3/7+/khLS0NqauoDLeKmh1cQzsPAP7cqNWrUCF26dEFgYCC+//57TJw4EcWLF8cbb7yh1A0ICMCmTZtw584d7Nu3D6tXr87xdhGehx3Pljmc/f/5iSeewFdffQUA6NKlC7y9vTFq1Chs2bJFd7z75YXzMMAczm9sfS1xv6VLlwKAVbfpAfbL4aeeegpPPfWUKe7YsSOio6NRo0YNjBo1Chs3bjS99tlnnyltX3rpJfTr1w8LFy7E8OHDUa9ePeV1R1xLOGTiNHv2bISHh8PV1RXBwcGIiIgwfGpYTtq2bYtZs2bh3XffRe
3atQH885SOmJgYjBgxAr6+voZ9iIiuzMvLCxkZGWbr37lzx3DdSfbr3bt3V8pffPFFzJ8/Hzt37nzgiVP2e
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "6ZzALozi8YaD"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# Try experimenting with the batch size!\n",
|
|||
|
|
"\n",
|
|||
|
|
" * How does increasing the batch size to 10,000 affect the training time and test accuracy?\n",
|
|||
|
|
" * How about a batch size of 32?\n",
|
|||
|
|
" * Is there any difference in results between the students? If so, why?\n",
|
|||
|
|
" * Experiment with the learning rate in the optimizer"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "uPUgO1BY8YaD"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# Introducing Convolution! What is it?"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "vs75ypBn8YaE"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"Previously, we built a network that accepts the normalized pixel values of each image and operates solely on those values. What if we could instead feed different features (e.g. **curvature, edges**) of each image into a network, and have the network learn which features are important for classifying an image?\n",
|
|||
|
|
"\n",
|
|||
|
|
"This is possible through convolution! Convolution applies **kernels** (filters) that slide across each image and generate **feature maps**."
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "eVIKBobD8YaF"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"<img src = 'https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/convolution.gif' >"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "gYJDfK4p8YaG"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"In the above example, the image is a 5 x 5 matrix and the kernel sliding over it is a 3 x 3 matrix. At each position, the kernel is multiplied element-wise with the 3 x 3 patch of the image beneath it, and the products are summed (a dot product). This gives one value of the convolved feature. Each kernel in a CNN learns to detect a different characteristic of an image.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Kernels are often used in photo-editing software to apply blurring, edge detection, sharpening, etc."
|
|||
|
|
]
|
|||
|
|
},
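{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal NumPy sketch of the sliding-window computation. It assumes stride 1 and no padding, which Keras calls `padding='valid'`. Like the convolution layers in deep learning libraries, it does not flip the kernel, so strictly speaking it computes a cross-correlation. The 5 x 5 image and 3 x 3 kernel are example values chosen for illustration and may differ from the animation. The helper `conv2d_valid` is our own.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def conv2d_valid(image, kernel):\n",
"    # Slide the kernel over every position where it fits entirely inside the image\n",
"    kh, kw = kernel.shape\n",
"    out_h = image.shape[0] - kh + 1\n",
"    out_w = image.shape[1] - kw + 1\n",
"    out = np.zeros((out_h, out_w), dtype=image.dtype)\n",
"    for i in range(out_h):\n",
"        for j in range(out_w):\n",
"            # Element-wise product of the patch and the kernel, then sum (a dot product)\n",
"            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)\n",
"    return out\n",
"\n",
"image = np.array([[1, 1, 1, 0, 0],\n",
"                  [0, 1, 1, 1, 0],\n",
"                  [0, 0, 1, 1, 1],\n",
"                  [0, 0, 1, 1, 0],\n",
"                  [0, 1, 1, 0, 0]])\n",
"kernel = np.array([[1, 0, 1],\n",
"                   [0, 1, 0],\n",
"                   [1, 0, 1]])\n",
"print(conv2d_valid(image, kernel))\n",
"# [[4 3 4]\n",
"#  [2 4 3]\n",
"#  [2 3 4]]\n",
"```\n",
"\n",
"A 5 x 5 input and a 3 x 3 kernel give a 3 x 3 feature map, because the kernel fits in only 3 positions along each axis. In a CNN, the kernel values are not fixed like this but are learned during training."
]
},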
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "qm9CKpnr8YaH"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"<img src = 'https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/kernels.png' >"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "QdHIoa-E8YaH"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"Kernels in deep learning networks are used in similar ways, i.e. to highlight some feature. They are often combined with an operation called **max pooling**, which keeps only the largest value in each small window of a feature map. Max pooling discards the weaker (non-highlighted) responses and leaves only the features of interest. It also shrinks the feature maps passed to later layers, which reduces the number of parameters those layers need and decreases the computational cost (e.g. system memory)."
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "ceCEMpHf8YaI"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"<img src = 'https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/max_pooling.png' >"
|
|||
|
|
]
|
|||
|
|
},
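{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal NumPy sketch of 2 x 2 max pooling with stride 2, the default window used by Keras' `MaxPooling2D`. Each non-overlapping 2 x 2 block of the feature map is replaced by its largest value, which halves the height and width. The helper `max_pool2d` and the example values are our own. Rows or columns that do not fill a whole window are dropped, mirroring `padding='valid'`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def max_pool2d(feature_map, size=2):\n",
"    # Non-overlapping size x size windows (stride == size); leftover rows/cols are dropped\n",
"    h = feature_map.shape[0] // size\n",
"    w = feature_map.shape[1] // size\n",
"    trimmed = feature_map[:h * size, :w * size]\n",
"    # Reshape into (h, size, w, size) blocks and take the max inside each block\n",
"    return trimmed.reshape(h, size, w, size).max(axis=(1, 3))\n",
"\n",
"fmap = np.array([[1, 1, 2, 4],\n",
"                 [5, 6, 7, 8],\n",
"                 [3, 2, 1, 0],\n",
"                 [1, 2, 3, 4]])\n",
"print(max_pool2d(fmap))\n",
"# [[6 8]\n",
"#  [3 4]]\n",
"```\n",
"\n",
"Pooling has no learned parameters. It only reduces the 16 values of this feature map to 4, which is where the savings in later layers come from."
]
},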
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "poGu943u8YaJ"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"We can also take convolutions of convolutions: we can stack as many convolutions as we want, as long as there are enough pixels left to fit a kernel.\n",
|
|||
|
|
"\n",
|
|||
|
|
"*Warning: What you may find down there in those deep convolutions may not appear recognizable to you.*"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "u7KmEU3m8YaJ"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"<img src = 'https://github.com/wut-mpg/keras-mnist-tutorial/raw/master/go_deeper.jpg' >"
|
|||
|
|
]
|
|||
|
|
},
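{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each 'valid' convolution (stride 1, no padding) with a k x k kernel shrinks an n x n feature map to (n - k + 1) x (n - k + 1), and 2 x 2 max pooling halves it (rounding down). The hypothetical helper below traces these sizes for a 28 x 28 MNIST image through two 5 x 5 convolution and 2 x 2 pooling stages, as in the model below, and raises an error once a kernel no longer fits."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def trace_shapes(size, layers):\n",
"    # size: height (= width) of a square input\n",
"    # layers: list of (kind, k) tuples, where kind is 'conv' or 'pool'\n",
"    for kind, k in layers:\n",
"        if kind == 'conv':\n",
"            # k x k kernel, stride 1, no padding: n -> n - k + 1\n",
"            if size < k:\n",
"                raise ValueError('feature map too small for a %d x %d kernel' % (k, k))\n",
"            size = size - k + 1\n",
"        elif kind == 'pool':\n",
"            # k x k window with stride k: n -> n // k\n",
"            size = size // k\n",
"        print(kind, k, '->', size, 'x', size)\n",
"    return size\n",
"\n",
"trace_shapes(28, [('conv', 5), ('pool', 2), ('conv', 5), ('pool', 2)])  # sizes 24, 12, 8, 4"
],
"execution_count": null,
"outputs": []
},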
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "h88R-Bsa8YaK"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## Building a \"Deep\" Convolutional Neural Network"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "qoXIbzAG8YaL"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# import some additional tools\n",
|
|||
|
|
"\n",
|
|||
|
|
"from keras.preprocessing.image import ImageDataGenerator\n",
|
|||
|
|
"from keras.layers import Conv2D, MaxPooling2D, ZeroPadding2D, GlobalAveragePooling2D, Flatten\n",
|
|||
|
|
"from keras.layers import BatchNormalization"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "w0MBTlCL8YaM"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# Reload the MNIST data\n",
|
|||
|
|
"(X_train, y_train), (X_test, y_test) = mnist.load_data()"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "TbBjjLil8YaO",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"outputId": "3384e000-0c87-4357-a3fb-d6e613e5a883"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# Again, do some formatting\n",
|
|||
|
|
"# Except we do not flatten each image into a 784-length vector because we want to perform convolutions first\n",
|
|||
|
|
"\n",
|
|||
|
|
"X_train = X_train.reshape(60000, 28, 28, 1) # add a channel dimension: grayscale images have a single channel\n",
|
|||
|
|
"X_test = X_test.reshape(10000, 28, 28, 1)\n",
|
|||
|
|
"\n",
|
|||
|
|
"X_train = X_train.astype('float32') # change integers to 32-bit floating point numbers\n",
|
|||
|
|
"X_test = X_test.astype('float32')\n",
|
|||
|
|
"\n",
|
|||
|
|
"X_train /= 255 # scale pixel values from [0, 255] to [0, 1]\n",
|
|||
|
|
"X_test /= 255\n",
|
|||
|
|
"\n",
|
|||
|
|
"print(\"Training matrix shape\", X_train.shape)\n",
|
|||
|
|
"print(\"Testing matrix shape\", X_test.shape)"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"Training matrix shape (60000, 28, 28, 1)\n",
|
|||
|
|
"Testing matrix shape (10000, 28, 28, 1)\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "aNNglEKV8YaR"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# one-hot format classes\n",
|
|||
|
|
"\n",
|
|||
|
|
"nb_classes = 10 # number of unique digits\n",
|
|||
|
|
"\n",
|
|||
|
|
"Y_train = to_categorical(y_train, nb_classes)\n",
|
|||
|
|
"Y_test = to_categorical(y_test, nb_classes)"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "buX5gLwP8YaT"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"model = Sequential() # Linear stacking of layers\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Convolution Layer 1\n",
|
|||
|
|
"model.add(Conv2D(16, (5, 5), input_shape=(28,28,1))) # 16 different 5x5 kernels -- so 16 feature maps\n",
|
|||
|
|
"model.add(Activation('relu')) # activation\n",
|
|||
|
|
"model.add(MaxPooling2D(pool_size=(2,2))) # keep the max value in each 2x2 window\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Convolution Layer 2\n",
|
|||
|
|
"model.add(Conv2D(32, (5, 5))) # 32 different 5x5 kernels -- so 32 feature maps\n",
|
|||
|
|
"model.add(Activation('relu')) # activation\n",
|
|||
|
|
"model.add(MaxPooling2D(pool_size=(2,2))) # keep the max value in each 2x2 window\n",
|
|||
|
|
"\n",
|
|||
|
|
"model.add(Flatten()) # Flatten final output matrix into a vector\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Fully Connected Layer\n",
|
|||
|
|
"model.add(Dense(128)) # 128 FC nodes\n",
|
|||
|
|
"model.add(Activation('relu')) # activation\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Fully Connected Layer\n",
|
|||
|
|
"model.add(Dense(10)) # final 10 FC nodes\n",
|
|||
|
|
"model.add(Activation('softmax')) # softmax activation"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "srtd-OZV8YaV",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"outputId": "f072bbeb-8979-4b0f-da8c-cb2fd0fffe36"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"model.summary()"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"Model: \"sequential_1\"\n",
|
|||
|
|
"_________________________________________________________________\n",
|
|||
|
|
" Layer (type) Output Shape Param # \n",
|
|||
|
|
"=================================================================\n",
|
|||
|
|
" conv2d (Conv2D) (None, 24, 24, 16) 416 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_3 (Activation) (None, 24, 24, 16) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" max_pooling2d (MaxPooling2 (None, 12, 12, 16) 0 \n",
|
|||
|
|
" D) \n",
|
|||
|
|
" \n",
|
|||
|
|
" conv2d_1 (Conv2D) (None, 8, 8, 32) 12832 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_4 (Activation) (None, 8, 8, 32) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" max_pooling2d_1 (MaxPoolin (None, 4, 4, 32) 0 \n",
|
|||
|
|
" g2D) \n",
|
|||
|
|
" \n",
|
|||
|
|
" flatten (Flatten) (None, 512) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" dense_3 (Dense) (None, 128) 65664 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_5 (Activation) (None, 128) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" dense_4 (Dense) (None, 10) 1290 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_6 (Activation) (None, 10) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
"=================================================================\n",
|
|||
|
|
"Total params: 80202 (313.29 KB)\n",
|
|||
|
|
"Trainable params: 80202 (313.29 KB)\n",
|
|||
|
|
"Non-trainable params: 0 (0.00 Byte)\n",
|
|||
|
|
"_________________________________________________________________\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
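{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Param #` column in the summary above can be checked by hand. A convolution layer with `filters` kernels of size k x k over `in_channels` input channels has (k * k * in_channels + 1) * filters parameters (one bias per filter), and a dense layer has (inputs + 1) * units. Activation, pooling and flatten layers have no parameters. The helper names below are hypothetical."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def conv_params(k, in_channels, filters):\n",
"    # each filter: k*k*in_channels weights plus one bias\n",
"    return (k * k * in_channels + 1) * filters\n",
"\n",
"def dense_params(inputs, units):\n",
"    # one weight per input per unit, plus one bias per unit\n",
"    return (inputs + 1) * units\n",
"\n",
"counts = [\n",
"    conv_params(5, 1, 16),          # conv2d\n",
"    conv_params(5, 16, 32),         # conv2d_1\n",
"    dense_params(4 * 4 * 32, 128),  # dense_3, fed by Flatten (4 x 4 x 32 = 512)\n",
"    dense_params(128, 10),          # dense_4\n",
"]\n",
"print(counts, sum(counts))  # [416, 12832, 65664, 1290] 80202"
],
"execution_count": null,
"outputs": []
},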
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "TJ5vnJet8YaX"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# we'll use the same optimizer\n",
|
|||
|
|
"adam = tf.optimizers.Adam(learning_rate=0.001)\n",
|
|||
|
|
"model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "WShqV9h28Yaa"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# data augmentation helps reduce overfitting by applying small random transformations to the training images\n",
|
|||
|
|
"# Keras has a great built-in feature to do automatic augmentation\n",
|
|||
|
|
"\n",
|
|||
|
|
"gen = ImageDataGenerator(rotation_range=8, width_shift_range=0.08, shear_range=0.3,\n",
|
|||
|
|
" height_shift_range=0.08, zoom_range=0.08, validation_split=0.2)\n",
|
|||
|
|
"\n",
|
|||
|
|
"test_gen = ImageDataGenerator()"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "VsGRzqqd8Yab"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# We can then feed our augmented data in batches\n",
|
|||
|
|
"# Note: X_train is already held in memory here, so flow() does not by itself reduce memory use;\n",
"# it generates randomly augmented versions of each batch on the fly, so the augmented images never need to be stored.\n",
"# Streaming batches from disk instead (e.g. with flow_from_directory) is what saves memory for large datasets.\n",
|
|||
|
|
"\n",
|
|||
|
|
"train_generator = gen.flow(X_train, Y_train, batch_size=128, subset='training')\n",
|
|||
|
|
"valid_generator = gen.flow(X_train, Y_train, batch_size=128, subset='validation')\n",
|
|||
|
|
"test_generator = test_gen.flow(X_test, Y_test, batch_size=128)"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "_DXSGa-z8Yae",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"outputId": "eeea17b5-1a06-4729-dfc2-ab5153acced8"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# We can now train our model which is fed data by our batch loader\n",
|
|||
|
|
"# steps_per_epoch is usually the number of training samples divided by the batch size (48000 // 128 = 375 here)\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Feeding batches from a generator matters most for large datasets that cannot be loaded into memory at once\n",
|
|||
|
|
"\n",
|
|||
|
|
"model.fit(train_generator, steps_per_epoch=48000//128, epochs=5, verbose=1, validation_data=valid_generator, validation_steps = 12000 // 128)"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"Epoch 1/5\n",
|
|||
|
|
"375/375 [==============================] - 20s 47ms/step - loss: 0.4392 - accuracy: 0.8652 - val_loss: 0.1599 - val_accuracy: 0.9488\n",
|
|||
|
|
"Epoch 2/5\n",
|
|||
|
|
"375/375 [==============================] - 19s 51ms/step - loss: 0.1279 - accuracy: 0.9613 - val_loss: 0.1044 - val_accuracy: 0.9666\n",
|
|||
|
|
"Epoch 3/5\n",
|
|||
|
|
"375/375 [==============================] - 17s 45ms/step - loss: 0.0938 - accuracy: 0.9706 - val_loss: 0.0757 - val_accuracy: 0.9782\n",
|
|||
|
|
"Epoch 4/5\n",
|
|||
|
|
"375/375 [==============================] - 20s 52ms/step - loss: 0.0749 - accuracy: 0.9767 - val_loss: 0.0734 - val_accuracy: 0.9763\n",
|
|||
|
|
"Epoch 5/5\n",
|
|||
|
|
"375/375 [==============================] - 17s 45ms/step - loss: 0.0602 - accuracy: 0.9817 - val_loss: 0.0637 - val_accuracy: 0.9788\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "execute_result",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<keras.src.callbacks.History at 0x7ac85c406650>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"execution_count": 34
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "6YZaV3U-8Yah",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"outputId": "85ade8d4-5e06-4f01-a3d5-91f242213bdb"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"score = model.evaluate(X_test, Y_test)\n",
|
|||
|
|
"print('Test score:', score[0])\n",
|
|||
|
|
"print('Test accuracy:', score[1])"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"313/313 [==============================] - 1s 3ms/step - loss: 0.0345 - accuracy: 0.9893\n",
|
|||
|
|
"Test score: 0.03448772057890892\n",
|
|||
|
|
"Test accuracy: 0.989300012588501\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "WqvF3eS2pnLp",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 1000
|
|||
|
|
},
|
|||
|
|
"outputId": "35271263-f64a-438c-d37e-f830a6a1a696"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# model.predict returns the predicted probability of each class; np.argmax picks the most\n",
"# probable class for each input example (predict_classes was removed in newer Keras versions).\n",
|
|||
|
|
"predict = model.predict(X_test)\n",
|
|||
|
|
"predicted_classes = np.argmax(predict,axis=1)\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Check which items we got right / wrong\n",
|
|||
|
|
"correct_indices = np.nonzero(predicted_classes == y_test)[0]\n",
|
|||
|
|
"\n",
|
|||
|
|
"incorrect_indices = np.nonzero(predicted_classes != y_test)[0]\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"cnf_matrix = confusion_matrix(y_test, predicted_classes)\n",
|
|||
|
|
"\n",
|
|||
|
|
"class_names = [str(i) for i in range(10)]\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Plot non-normalized confusion matrix\n",
|
|||
|
|
"plt.figure()\n",
|
|||
|
|
"plot_confusion_matrix(cnf_matrix, classes=class_names,\n",
|
|||
|
|
" title='Confusion matrix, without normalization')\n",
|
|||
|
|
"\n",
|
|||
|
|
"plt.show()"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"313/313 [==============================] - 1s 2ms/step\n",
|
|||
|
|
"Confusion matrix, without normalization\n",
|
|||
|
|
"[[ 975 1 1 0 0 0 2 1 0 0]\n",
|
|||
|
|
" [ 0 1134 0 0 0 0 0 1 0 0]\n",
|
|||
|
|
" [ 0 4 1024 1 0 0 0 3 0 0]\n",
|
|||
|
|
" [ 0 1 1 1004 0 3 0 0 1 0]\n",
|
|||
|
|
" [ 0 0 0 0 982 0 0 0 0 0]\n",
|
|||
|
|
" [ 1 1 0 5 0 882 1 0 0 2]\n",
|
|||
|
|
" [ 1 3 0 1 1 5 946 0 1 0]\n",
|
|||
|
|
" [ 0 5 10 0 3 0 0 1009 1 0]\n",
|
|||
|
|
" [ 0 0 3 0 2 1 0 0 966 2]\n",
|
|||
|
|
" [ 0 5 0 0 23 3 0 5 2 971]]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 2 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1UAAAN6CAYAAABmBWMlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAACtqElEQVR4nOzdd3gUVdvH8d8mkEIgDUhCqAGkCQGphiItgghSRBRFCUVQpIgoCBYERBH0VaQIoo+AitIUKVKkSZEOoghIR2oSakICJJDd94/IxiWgCZNkdsP381xzPe6Z2dl7zp5d9s4pY7HZbDYBAAAAAO6Im9kBAAAAAIArI6kCAAAAAANIqgAAAADAAJIqAAAAADCApAoAAAAADCCpAgAAAAADSKoAAAAAwACSKgAAAAAwII/ZAQAAAAC4M1evXlVycrLZYWSIh4eHvLy8zA4jW5BUAQAAAC7o6tWr8i5QULp+2exQMiQkJERHjhzJlYkVSRUAAADggpKTk6Xrl+VZKUpy9zA7nH+XkqzoPdOVnJxMUgUAAADAybh7yOLkSZXN7ACyGUkVAAAA4MosbqmbM3P2+AzK3VcHAAAAANmMpAoAAAAADCCpAgAAAAADmFMFAAAAuDKLJIvF7Cj+nZOHZxQ9VQAAAABgAEkVAAAAABjA8D8AAADAlbGkuuly99UBAAAAQDYjqQIAAAAAAxj+BwAAALgyi8UFVv9z8vgMoqcKAAAAAAwgqQIAAAAAA0iqAAAAAMAA5lQBAAAArowl1U2Xu68OAAAAALIZSRUAAAAAGMDwPwAAAMCVsaS66eipAgAAAAADSKoAAAAAwACSKgAAAAAwgDlVAAAAgEtzgSXVc3lfTu6+OgAAAADIZiRVAAAAAGAAw/8AAAAAV8aS6qajpwoAAAAADCCpAgAAAAADGP4HAAAAuDKLC6z+5+zxGZS7rw4AAAAAshlJFQAAAAAYQFIFAAAAAAYwpwoAAABwZSypbjp6qgAAAADAAJIqAAAAADCA4X8AAACAK2NJddPl7qsDAAAAgGxGUgUAAAAABjD8DwAAAHBlrP5nOnqqAAAAAMAAkioAAAAAMICkCgAAAAAMYE4VAAAA4MpYUt10ufvqAAAAACCbkVQBAAAAgAEM/wMAAABcmcXi/MPrWFIdAAAAAHA7JFUAAAAAYADD/wAAAABX5mZJ3ZyZs8dnED1VAAAAAGAASRUAAAAAGEBSBQAAAAAGMKcKAAAAcGUWNxdYUt3J4zMod18dAAAAAGQzkioAAAAAMIDhfwAAAIArs1hSN2fm7PEZRE8VAAAAABhAUgUAAAAABpBUAQAAAIABzKkCAAAAXBlLqpsud18dAAAAAGQzkioAAAAAMIDhfwAAAIArY0l109FTBQAAAAAGkFQBAAAAgAEM/wMAAABcGav/mS53Xx0AAAAAZDOSKgAAAAAwgKQKAAAAAAxgThUAAADgylhS3XT0VAEAAACAASRVAAAAAGAAw/8AAAAAV8aS6qbL3VcHAAAAANmMpArIxQ4cOKBmzZrJz89PFotFP/zwQ5ae/+jRo7JYLJo2bVqWnjc3KFWqlLp06WJ2GOlk5j27cewHH3yQ/YHhloYNGybLTZO7zWpbztqmAcAZkFQB2ezQoUN67rnnVLp0aXl5ecnX11f16tXTxx9/rCtXrmTra0dFRWnXrl1655139NVXX6lmzZrZ+nq50Z49ezRs2DAdPXrU7FCyzeLFizVs2DCzw0jn3XffzfI/BODfbdiwQcOGDdPFixfNDgVAZtxY/c/Zt1yMOVVANvrxxx/VoUMHeXp6qnPnzqpcubKSk5O1fv16DRw4ULt379aUKVOy5bWvXLmijRs36vXXX1efPn2y5TVKliypK1euKG/evNlyfmewZ88eDR8+XI0aNVKpUqUy/Lx9+/bJzc35/m51q/ds8eLFmjhxotMlVu+++64ee+wxtW3b1uxQnEp2tq0NGzZo+PDh6tKli/z9/X
PsdQHA1ZFUAdnkyJEj6tixo0qWLKlVq1apSJEi9n29e/fWwYMH9eOPP2bb6585c0aS0v0wykoWi0VeXl7Zdn5XY7PZdPXqVXl7e8vT09PscG6J98yYxMRE+fj4mBqDWW3LWds0ADgD/uQEZJMxY8YoISFB//vf/xwSqhvKli2rF1980f74+vXrevvtt1WmTBl5enqqVKlSeu2115SUlOTwvFKlSqlVq1Zav369ateuLS8vL5UuXVpffvml/Zhhw4apZMmSkqSBAwfKYrHYe1m6dOlyyx6XW83dWL58uerXry9/f3/lz59f5cuX12uvvWbff7v5OatWrVKDBg3k4+Mjf39/tWnTRnv37r3l6x08eND+V3E/Pz917dpVly9fvn3F/q1Ro0aqXLmyfv/9dzVs2FD58uVT2bJlNXfuXEnSmjVrVKdOHXl7e6t8+fJasWKFw/P/+usvvfDCCypfvry8vb1VsGBBdejQwWGY37Rp09ShQwdJUuPGjWWxWGSxWPTzzz9LSnsvli1bppo1a8rb21uffvqpfd+N+Sc2m02NGzdW4cKFFRsbaz9/cnKyqlSpojJlyigxMfE/r/mfBgwYoIIFC8pms9nL+vbtK4vFonHjxtnLYmJiZLFYNGnSJEnp37MuXbpo4sSJkmS/vpvbgSRNmTLF3jZr1aqlrVu3pjsmI+97RtufxWJRYmKipk+fbo/p3+bz/Pzzz7JYLJo9e7beeecdFStWTF5eXmratKkOHjyY7vg5c+aoRo0a8vb2VqFChfT000/r5MmT6WLNnz+/Dh06pIcfflgFChRQp06d7PH16dNHc+bMUaVKleTt7a2IiAjt2rVLkvTpp5+qbNmy8vLyUqNGjdINH123bp06dOigEiVKyNPTU8WLF9dLL72UoSHBN89t+uf7dvN243V///13denSxT4MOSQkRN26ddO5c+cc3oOBAwdKksLCwtKd41Zzqg4fPqwOHTooMDBQ+fLl0/3335/uj0WZfW8AwBXRUwVkk4ULF6p06dKqW7duho5/9tlnNX36dD322GN6+eWXtXnzZo0aNUp79+7VvHnzHI49ePCgHnvsMXXv3l1RUVH64osv1KVLF9WoUUP33nuvHn30Ufn7++ull17Sk08+qYcfflj58+fPVPy7d+9Wq1atFB4erhEjRsjT01MHDx7UL7/88q/PW7FihVq0aKHSpUtr2LBhunLlisaPH6969eppx44d6X5QP/744woLC9OoUaO0Y8cOff755woKCtLo0aP/M8YLFy6oVatW6tixozp06KBJkyapY8eOmjFjhvr376/nn39eTz31lN5//3099thjOn78uAoUKCBJ2rp1qzZs2KCOHTuqWLFiOnr0qCZNmqRGjRppz549ypcvnx544AH169dP48aN02uvvaaKFStKkv3/pdQhUU8++aSee+459ejRQ+XLl08Xp8Vi0RdffKHw8HA9//zz+v777yVJb731lnbv3q2ff/45070fDRo00EcffaTdu3ercuXKklJ/qLu5uWndunXq16+fvUySHnjggVue57nnntOpU6e0fPlyffXVV7c85ptvvtGlS5f03HPPyWKxaMyYMXr00Ud1+PBh+zDCzL7v/+Wrr77Ss88+q9q1a6tnz56SpDJlyvzn89577z25ubnplVdeUVxcnMaMGaNOnTpp8+bN9mOmTZumrl27qlatWho1apRiYmL08ccf65dfftGvv/7q0Lt7/fp1NW/eXPXr19cHH3ygfPny2fetW7dOCxYsUO/evSVJo0aNUqtWrTRo0CB98skneuGFF3ThwgWNGTNG3bp106pVq+zPnTNnji5fvqxevXqpYMGC2rJli8aPH68TJ05ozpw5ma6rm73xxhuKjY21f+6XL1+uw4cPq2vXrgoJCbEPPd69e7c2bdoki8WiRx99VPv379e3336rjz76SIUKFZIkFS5c+JavGxMTo7p16+ry5cvq16+fChYsqOnTp6t169aaO3eu2r
Vrl+n3BsCdcoEl1XN7X44NQJaLi4uzSbK1adMmQ8fv3LnTJsn27LPPOpS/8sorNkm2VatW2ctKlixpk2Rbu
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "Wn7FDt7DqSKB",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 906
|
|||
|
|
},
|
|||
|
|
"outputId": "edcab911-ae25-4e85-f8e9-c8f99a11239b"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"show_samples(correct_indices, predict, X_test, y_test, 5)"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 25 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA04AAAN5CAYAAAA/32uUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAADwN0lEQVR4nOzdd3hU1fbw8TWkF4GQhB5C6L0K0kNXaRaCoqC5inIFbNgAxYtSvAIKlkuRIiiCF0S5YAEpIoIoSBMxRCmBUIVQQkkgIdnvH/4yL/ucSc5MMpNk4Pt5njw+a88+e1Zk5WR2ztln25RSSgAAAAAAuSpR1AkAAAAAQHHHxAkAAAAALDBxAgAAAAALTJwAAAAAwAITJwAAAACwwMQJAAAAACwwcQIAAAAAC0ycAAAAAMACEycAAAAAsMDE6f8MHTpUunXrVtRp5GnVqlUSGhoqp0+fLupUUAxNmjRJ6tSpI9nZ2UWdSq4SEhLE19dX9uzZU9SpoBiihuHt+CwBb0cNW1CFaN68eUpE7F8BAQGqZs2aatiwYerkyZP5GnPMmDHamMavTZs2WY5x8OBB5efnp7777jutffr06SouLk5FRUUpEVHx8fEu5ZaVlaUmTpyoqlatqgICAlTDhg3VokWLHPZNSEhQt99+uwoJCVFhYWFq4MCB6tSpU6Z+jRs3VsOHD3cpD7iPJ2p479696sUXX1SNGzdWoaGhqnz58qpHjx7ql19+cXqM1NRUVaZMGfXhhx9q7f/973/VgAEDVI0aNZSIqNjYWJfzmzNnjqpTp44KCAhQNWrUUO+9957DfkePHlX9+vVTpUqVUrfccovq06ePOnDggKlfnz591D333ONyHnAPT9SwUkqNHz9e9e7dW5UtW1aJiBozZoxLx1PDcJanatiV39mO8FkCzvJUDV/vk08+USKiQkJCnD6GGrZWJBOnsWPHqgULFqjZs2er+Ph4VaJECRUTE6MuX77s8pi//vqrWrBggekrKipKhYWFqatXr1qO8cwzz6hatWqZ2qOjo1WZMmXUHXfcoXx9fV0ulJEjRyoRUY8//riaNWuW6tmzpxIR9emnn2r9jhw5oiIiIlT16tXVu+++qyZMmKDCwsJU48aNTflPnz5dBQcHqwsXLriUC9zDEzX8/PPPq9KlS6tBgwapDz74QE2aNElVr15d+fj4qDVr1jg1xtSpU1XJkiVVenq61h4bG6tCQ0NVp06dVFhYmMsfOmfOnKlERPXt21fNmjVLPfTQQ0pE1Jtvvqn1u3jxoqpZs6YqW7asmjhxopoyZYqKiopSlStXVikpKVrfb775RomI2r9/v0u5wD08UcNKKSUiqnz58ur222/P18SJGoazPFXDzv7Ozg2fJeAsT9VwjosXL6qKFSuqkJAQlyZO1LC1Ipk4Gf+S/txzzykRcekvO3lJTk5WNptNPf7445Z9MzIyVEREhBo9erTptUOHDqns7GyllFIhISEuFcrRo0eVn5+fGjZsmL0tOztbtW/fXlWuXFldu3bN3j5kyBAVFBSkDh8+bG9bs2aNEhH1wQcfaOP+9ddfysfHR82dO9fpXOA+nqjhbdu2qYsXL2ptKSkpKjIyUrVt29apMRo1aqQGDhxoak9OTlZZWVlKKaXq16/v0ofOtLQ0FR4ernr27Km1DxgwQIWEhKizZ8/a2yZOnKhERG3dutXetnfvXuXj46NGjRqlHZ+RkaHCwsLUq6++6nQucB9PnYeTkpKUUkqdPn06XxMnahjO8kQNu/I72xE+S8AVnv48PGLECFW7dm37uc4Z1LBzisUap86dO4uISFJSkr3twIEDcuDAgXyN9+mnn4pSSgYMGGDZd9OmTZKSkiJdu3Y1vRYdHS02my1fOSxfvlwyMzNl6NCh9jabzSZDhgyRo0ePyk8//WRv//zzz6VXr15SpUoVe1vXrl2lVq1asmTJEm3csmXLSqNGjWT58uX5ygueUZAabt68uYSGhm
pt4eHh0r59e9m7d6/l8UlJSbJ7926HNRwVFSUlSuTvx3z9+vVy5swZrYZFRIYNGyaXL1+Wr7/+2t62dOlSadGihbRo0cLeVqdOHenSpYuphv38/KRjx47UcDFT0PNw1apV8/3e1DDcoSA17MrvbEf4LAF3cMfn4X379snUqVNlypQp4uvr6/Rx1LBzisXEKacgwsPD7W1dunSRLl265Gu8hQsXSlRUlHTo0MGy7+bNm8Vms0nTpk3z9V652blzp4SEhEjdunW19pYtW9pfFxE5duyYnDp1Sm699VbTGC1btrT3u17z5s1l8+bNbs0XBePuGhYROXnypERERFj2y6mFZs2a5fu9HMmpPWNtNm/eXEqUKGF/PTs7W3bv3p1rDR84cEAuXrxoGmPPnj1y4cIFt+aM/PNEDTuLGoY7FKSGnf2dnRs+S8Ad3HEefvbZZ6VTp07So0cPl96bGnZOkUycUlNTJSUlRY4ePSqLFy+WsWPHSlBQkPTq1avAY//++++ye/dueeCBB5yaHScmJkqZMmWkZMmSBX7v6504cULKlStnyqFChQoiInL8+HF7v+vbjX3Pnj0rV69e1dqrVasmKSkpcurUKbfmDOd5soZFRDZu3Cg//fST3H///ZZ9ExMTRUQkJibGLe+d48SJE+Lj4yNly5bV2v39/SU8PNxewzk1mlsNi/z/es9RrVo1yc7OtueOwufpGnYFNYz8cGcNO/s7Ozd8lkB+uPs8/PXXX8vq1atlypQpLh9LDTvH+Wt4bmS8DBgdHS0LFy6USpUq2dsOHTqUr7EXLlwoIuLUbXoiImfOnJGwsLB8vVde0tPTJSAgwNQeGBhof/36/1r1vf71nHxTUlJMHwhQODxZw6dOnZIHH3xQYmJi5KWXXrLsf+bMGfH19TXd7ldQ6enp4u/v7/C1wMBAl2v4etfXMIqGJ2vYVdQw8sOdNezs7+zc8FkC+eHOGs7IyJDhw4fLE088IfXq1XM5F2rYOUUycZo2bZrUqlVLfH19pVy5clK7du1838N+PaWULFq0SBo0aCCNGjVy6Th3CwoKMs2MRUSuXLlif/36/zrTN0dOvvm93xQF56kavnz5svTq1UsuXrwomzZtcvsHSVcEBQVJRkaGw9euXLlCDXs5T9VwcUIN39jcWcPO/s7OC58l4Cp31vDUqVMlJSVFXn/99XznQw1bK5KJU8uWLR3ew1hQP/74oxw+fFj+/e9/O31MeHi4nDt3zu25VKhQQdavXy9KKe0fNOdSZMWKFe39rm+/3okTJ6RMmTKm2XdOvs6sf4FneKKGMzIy5N5775Xdu3fLt99+Kw0aNHDquPDwcLl27ZpcvHhRbrnlFrflU6FCBcnKypJTp05pf8nJyMiQM2fO2Gs4p0Zzq2GR/1/vOajhouep83B+UMPID3fWsLO/s3PDZwnkh7tqODU1VcaPHy9Dhw6VCxcu2NdeXrp0SZRScujQIQkODs7zqgw17Jwb6s+LCxcuFJvNJg8++KDTx9SpU0fOnTsnqampbs2lSZMmkpaWZnoq2pYtW+yvi4hUqlRJIiMjZdu2baYxtm7dau93vaSkJImIiJDIyEi35oyik52dLQ8//LCsW7dOFi1aJLGxsU4fW6dOHRHRn8LjDjm1Z6zNbdu2SXZ2tv31EiVKSMOGDR3W8JYtW6RatWqmD8NJSUlSokQJqVWrlltzhneihlHUnP2dnRs+S6AonTt3Ti5duiSTJk2SmJgY+9fnn38uaWlpEhMTI4MHD85zDGrYOcV24uTq4xczMzPls88+k3bt2mmPMbTSunVrUUrJ9u3b85OmiPw9009MTNSK7a677hI/Pz+ZPn26vU0pJTNnzpRKlSpJmzZt7O19+/aVr776So4cOWJvW7dunfz555/Sr18/0/tt375dWrdune98UThcqeGnnnpKFi9eLNOnT5d7773XpffJqQVHJxtnpaWlSWJiorZeo3PnzlKmTBmZMWOG1nfGjBkSHBwsPX
v2tLfFxcXJL7/8ouXwxx9/yHfffZdrDdevX19KlSqV75zheQXZFsIV1DA8xdkaduV3tiN8loCnOFPDZcuWl
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "r0m_gom9qL3o",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 906
|
|||
|
|
},
|
|||
|
|
"outputId": "4f195d18-df3a-4dfd-a7cd-6a2a7de0efaa"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"show_samples(incorrect_indices, predict, X_test, y_test, 5)"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 25 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA04AAAN5CAYAAAA/32uUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdZ3hURdsH8H8K6YU0SCAQQkvoJYDSCYQWkI7SESlSpFlAiiIl9AdEpCOgFAUFHorSRaVJCUUxIEgIBAglAQPpbd4PPLsvs2ezZwO7bEL+v+vKh3t2zpyJ3kzO7DlzxkoIIUBERERERES5srZ0B4iIiIiIiPI7TpyIiIiIiIhUcOJERERERESkghMnIiIiIiIiFZw4ERERERERqeDEiYiIiIiISAUnTkRERERERCo4cSIiIiIiIlLBiRMREREREZEKTpz+Z/jw4WjZsqWlu2HQ3r174eLiggcPHli6K5QPzZ07F8HBwcjJybF0V3LFHCZDmMNU0BWEa4nly5ejdOnSSE9Pt3RXKB9iDqsQL9HatWsFAO2Pvb29qFChghgxYoS4e/euSc6xYcMGAUA4OzsbfUx0dLQoUqSI+PnnnxWfrV69WgQHBwt7e3tRvnx58cUXX+SpP5GRkeKNN94QHh4ewtHRUVSpUkUsWrRIqhMRESFee+014e3trT3P6NGjxf379xXt1ahRQ4wdOzZPfSDTMVcOz5gxQ7zxxhuiWLFiAoCYMmVKno5PTEwUnp6eYs2aNYrPduzYIWrVqiXs7e1FqVKlxKeffioyMzNV27x+/br0uz778+2330p1+/fvr7deUFCQol3msGWZI4dv374tevfuLSpWrChcXFyEu7u7qFu3rli3bp3Iyckxqg3mMBmrMF1LHD58ONccPnHihKJ+enq6iIiIEEFBQcLe3l4UK1ZMhIeHi9jYWG2d1NRUUbx4ccW1CL08hSmHhRDizJkzonXr1sLV1VW4uLiIli1binPnzinqNW3aVG+ut27dWqpnyRy2Ne00zDjTpk1DYGAg0tLScPToUSxbtgw//fQTLl68CCcnp+duNykpCePGjYOzs3Oejlu0aBECAwMRGhoqla9YsQJDhw5F165d8f777+PIkSMYNWoUUlJSMH78eNV29+/fjzfeeAO1atXCJ598AhcXF1y7dg23bt2S6kVGRqJmzZro0aMHXF1dcenSJaxatQo//vgjzp8/L/0+7777Lj788ENMnToVrq6uefo9yXRMncOTJ0+Gr68vatWqhX379uX5+DVr1iArKws9e/aUyvfs2YNOnTqhWbNmWLx4Mf7880/MmDED9+/fx7Jly4xqu2fPnggPD5fK6tevr6hnb2+P1atXS2Xu7u6Keszh/MGUORwfH49bt26hW7duKF26NDIzM3HgwAG8/fbb+PvvvzFz5kzVNpjDlFeF5VoCAEaNGoW6detKZeXLl5fizMxMtGvXDsePH8fgwYNRvXp1PHr0CCdPnkRiYiL8/f0BAA4ODujfvz8WLFiAkSNHwsrKKk+/J5lOYcjhs2fPolGjRihVqhSmTJmCnJwcLF26FE2bNsWpU6cQFBQk1ff398esWbOkshIlSkixRXP4Zc7SNDPs06dPS+Xvv/++ACA2bdr0Qu2PHz9eBAUFid69exs9w87IyBDe3t5i8uTJUnlKSorw8vIS7dq1k8o1bT98+NBgu4mJiaJ48eKic+fOIjs7O2+/iBDihx9+0Put6L1794SNjY346quv8twmvThz5fD169eFEEI8ePDgue44Va9eXfTp00dRXrlyZVGjRg3p2/lJkyYJKysrcenSJdU+ARDz5s1TPX///v2N/jfHHLYsc4/Dz2rfvr1wdnYWWVlZqnWZw2SswnQtobnj9P3336v2Yc6cOaJIkSLi5MmTqnXPnDkjAIhDhw6p1iXTK0w5HB4eLjw8PER8fLy27M6dO8LFxUV06dJFqtu0aVNRpUoVo/prqRzOF2ucmj
dvDgC4fv26tuzatWu4du2a0W1cvXoVCxcuxIIFC2Bra/yNtKNHjyI+Ph5hYWFS+eHDh5GQkIDhw4dL5SNGjEBycjJ+/PFHg+1u2rQJ9+7dQ0REBKytrZGcnJyn5/bLlCkDAPj333+l8mLFiqF69erYsWOH0W2R+b1oDmv+fz+P69ev448//lDkcFRUFKKiojBkyBDp38Tw4cMhhMAPP/xg9DmSk5ORkZGhWi87OxuPHz82WIc5nD+ZYhzWVaZMGaSkpKjmDnOYTOFVvJZ41pMnT5CVlaX3s5ycHCxatAidO3dGvXr1kJWVhZSUlFzbCgkJgaenJ3M4n3kVc/jIkSMICwuDl5eXtszPzw9NmzbF7t27kZSUpDgmKytLb/mzLJXD+WLipEmIZ/+jtmjRAi1atDC6jTFjxiA0NFTxOIaa48ePw8rKCrVq1ZLKz507BwCoU6eOVB4SEgJra2vt57k5ePAg3NzccPv2bQQFBcHFxQVubm4YNmwY0tLSFPWFEIiPj8fdu3e1t0BtbGzQrFkzRd2QkBAcP348T78nmZcpcvh5aXKhdu3aUnluOVyiRAn4+/ur5rDG1KlT4eLiAgcHB9StWxf79+/XWy8lJQVubm5wd3eHp6cnRowYkevAxxzOf0yRw6mpqYiPj0dMTAy+/vprrF27FvXr14ejo6PB45jDZAqv4rWExoABA+Dm5gYHBweEhobizJkz0udRUVG4c+cOqlevjiFDhsDZ2RnOzs6oXr06Dh8+rLfN2rVr49ixY8b+ivQSvIo5nJ6ervdvgJOTEzIyMnDx4kWp/MqVK3B2doarqyt8fX3xySefIDMzU2/blshhi6xxSkxMRHx8PNLS0nDs2DFMmzYNjo6OaN++/XO19+OPP2L//v24cOFCno+9fPkyPD094ebmJpXHxcXBxsYGxYoVk8rt7Ozg5eWFO3fuGGz36tWryMrKQseOHTFw4EDMmjULv/zyCxYvXox///0X3377rVT/3r178PPz08b+/v7YtGkTgoODFW2XLVsW8fHxuH//vqJ/9HKYOodfxOXLlwEAgYGBUnlcXBwASHml4efnp5rD1tbWaNWqFTp37oySJUsiOjoaCxYsQNu2bbFz5060a9dOam/cuHGoXbs2cnJysHfvXixduhQXLlzAL7/8ovjWizlseebI4UWLFmHChAnauEWLFli7dq3qccxheh6F4VrCzs4OXbt2RXh4OLy9vREVFYX58+ejcePGOH78uPYi9+rVqwCAhQsXwtPTEytWrAAAzJw5E23atMHp06dRvXp1qe2yZcti/fr1ef5dyXQKQw4HBQXh999/R3Z2NmxsbAAAGRkZOHnyJADg9u3b2rrlypVDaGgoqlWrhuTkZPzwww+YMWMGrly5gs2bNyvatkQOW2TipHsbMCAgABs3bkTJkiW1ZTExMUa1lZGRgbFjx2Lo0KGoXLlynvuSkJAADw8PRXlqairs7Oz0HuPg4IDU1FSD7SYlJSElJQVDhw7FF198AQDo0qULMjIysGLFCkybNg0VKlTQ1vf09MSBAweQlpaGc+fOYdu2bbl+06npb3x8PP9gW4gpc/hFJSQkwNbWFi4uLlK5Jkft7e0Vxzg4OKg+jlS6dGnFiyr69u2LypUr44MPPpAuOnUXcvbo0QMVK1bEpEmT8MMPP6BHjx7S58xhyzNHDvfs2RN16tTBgwcPsHv3bty7d091rASYw/R8CsO1RIMGDdCgQQNt3KFDB3Tr1g3Vq1fHhAkTsHfvXgDQXi88efIE586dQ6lSpQA8ffSrfPnymDt3LjZs2CC17eHhgdTUVKSkpLzQiwjo+RWGHB4+fDiGDRuGgQMHYty4ccjJycGMGTO0X4w9e/xXX30lHdu3b18MGTIEq1atwtixY/H6669Ln1sihy0ycVqyZAkqVqwIW1tbFC9eHEFBQbC2fr6nBhcuXIj4+HhMnTr1ufsjhFCUOTo65vo8fFpamuqjJ5rPdd8Q1atXL6xYsQ
InTpyQJk52dnbaf0Dt27dHixYt0LBhQxQrVkzxzYOmv3wTjuWYMofNRZOD+vY5MCaH9fH09MSAAQMwe/Zs3
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "RsY5uPgU8Yak"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## Great results!\n",
|
|||
|
|
"\n",
|
|||
|
|
"But wouldn't it be nice if we could visualize those convolutions so that we can see what the model is seeing?"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "BgbunToX8Yal"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"def visualize2(model, layer, img, r, c):\n",
"    from keras import Model\n",
"    # expand dimensions so that it represents a single 'sample' (a batch of one)\n",
"    img = np.expand_dims(img, axis=0)\n",
"    outputs = [layer.output]\n",
"    # build a model that maps the input image to the chosen layer's output\n",
"    model = Model(inputs=model.inputs, outputs=outputs)\n",
"    # get the feature maps produced by that layer\n",
"    feature_maps = model.predict(img)\n",
"    for fmap in feature_maps:\n",
"        # plot the first r * c feature maps in an r x c grid\n",
"        ix = 1\n",
"        for _ in range(r):\n",
"            for _ in range(c):\n",
"                # specify subplot and turn off axis ticks\n",
"                ax = plt.subplot(r, c, ix)\n",
"                ax.set_xticks([])\n",
"                ax.set_yticks([])\n",
"                # plot filter channel in grayscale\n",
"                plt.imshow(fmap[:, :, ix-1], cmap='gray')\n",
"                ix += 1"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "ZiDemlgp8Yan",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 770
|
|||
|
|
},
|
|||
|
|
"outputId": "acdcfab0-616d-43ce-a8c5-aa7f38274626"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"plt.figure()\n",
|
|||
|
|
"img = X_test[11]\n",
|
|||
|
|
"plt.imshow(img[:,:,0], cmap='gray', interpolation='none')"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "execute_result",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<matplotlib.image.AxesImage at 0x7ac8705b1360>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"execution_count": 40
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 1 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAuQAAALgCAYAAADcNWJzAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAl3UlEQVR4nO3df6yW9X34/9cB5IjtOTc9IpxzxoGiVF2KUGvhSLTMViJg64ryB7ouwYZp6w6sQjodG0qtJmR+0mm6UTs3I+tSbGsiWl3GoiAYK9gU44xLy4CwAuGHLQnn8GMCgev7x74926mAHLgvXufH45HcCee+r/N+v8zV++TZK9e5T01RFEUAAAApBmQPAAAA/ZkgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASDcoe4HedOHEidu3aFXV1dVFTU5M9DgAAdFtRFHHgwIFobm6OAQNOfw28xwX5rl27oqWlJXsMAAA4Zzt27IiRI0ee9pged8tKXV1d9ggAAFAVZ9K2PS7I3aYCAEBfcSZt2+OCHAAA+hNBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiUoL8mXLlsXHP/7xuPDCC6O1tTV+9rOflbUVAAD0WqUE+Y9+9KNYuHBhLFmyJN56662YMGFCTJs2Ld57770ytgMAgF6rpiiKotqLtra2xsSJE+Pv/u7vIiLixIkT0dLSEvPnz4+/+Iu/OO33dnR0RKVSqfZIAABw3rW3t0d9ff1pj6n6FfKjR4/Gxo0bY+rUqf+7yYABMXXq1Fi/fv0Hjj9y5Eh0dHR0eQAAQH9R9SD/zW9+E8ePH48RI0Z0eX7EiBGxZ8+eDxy/dOnSqFQqnY+WlpZqjwQAAD1W+qesLFq0KNrb2zsfO3bsyB4JAADOm0HVXnDYsGExcODA2Lt3b5fn9+7dG42NjR84vra2Nmpra6s9BgAA9ApVv0I+ePDguOaaa2L16tWdz504cSJWr14dkydPrvZ2AADQq1X9CnlExMKFC2POnDnxmc98JiZNmhSPP/54HDp0KL7yla+UsR0AAPRapQT57Nmz49e//nU8+OCDsWfPnvjUpz4Vq1at+sAvegIAQH9XyueQnwufQw4AQF+R8jnkAADAmRPkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQaFD2AED/UVtbW/oeP/3pT0vf4+qrry59jxdffLHU9WfOnFnq+gCcOVfIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASDQoewCg56itrS11/ccee6zU9SMiPvWpT5W+R1EUpe+xcePG0vcAoGdwhRwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEg7IHAHqOP/uzPyt1/bvvvrvU9SMi1qxZU/oeDz74YO
l7bNiwofQ9AOgZXCEHAIBEghwAABIJcgAASCTIAQAgkSAHAIBEghwAABIJcgAASCTIAQAgUdWD/Jvf/GbU1NR0eVx55ZXV3gYAAPqEUv5S5yc/+cl45ZVX/neTQf4gKAAAnEwppTxo0KBobGw8o2OPHDkSR44c6fy6o6OjjJEAAKBHKuUe8s2bN0dzc3Nceuml8eUvfzm2b99+ymOXLl0alUql89HS0lLGSAAA0CNVPchbW1tj+fLlsWrVqnjiiSdi27Zt8dnPfjYOHDhw0uMXLVoU7e3tnY8dO3ZUeyQAAOixqn7LyowZMzr/PX78+GhtbY3Ro0fHj3/845g7d+4Hjq+trY3a2tpqjwEAAL1C6R97OHTo0Lj88stjy5YtZW8FAAC9TulBfvDgwdi6dWs0NTWVvRUAAPQ6VQ/yb3zjG7Fu3br4r//6r3jjjTfi1ltvjYEDB8Ydd9xR7a0AAKDXq/o95Dt37ow77rgj9u3bF5dccklcf/31sWHDhrjkkkuqvRUAAPR6VQ/yH/7wh9VeEgAA+qzS7yEHAABOTZADAEAiQQ4AAIkEOQAAJKr6L3UCvVdjY2P2COfslVdeKX2PDRs2lL4HAP2HK+QAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQQ5AAAkGpQ9ANBz1NXVlbr+sWPHSl0/IuKVV14pfQ8AqCZXyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEg0KHsA4Mw0NzeXvsfcuXNLXf+NN94odf2IiLfeeqv0PQCgmlwhBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAIJEgBwCARIIcAAASCXIAAEgkyAEAINGg7AGAM7N48eLsEYASXHvttaXv0dLSUvoe58O///u/l7r+f/7nf5a6PpyKK+QAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQQ5AAAkEuQAAJBIkAMAQCJBDgAAiQZlDwCcmS984QvZI5yzp556KnsE+pgnnnii9D3Kfu997GMfK3X9iIghQ4aUvsf50NHRUer6jz32WKnrR0Q8/PDDpe9B7+MKOQAAJBLkAACQSJADAEAiQQ4AAIkEOQAAJBLkAACQSJADAEAiQQ4AAIm6HeSvvfZa3HLLLdHc3Bw1NTXx/PPPd3m9KIp48MEHo6mpKYYMGRJTp06NzZs3V2teAADoU7od5IcOHYoJEybEsmXLTvr6o48+Gt/5znfie9/7Xrz55pvxkY98JKZNmxbvv//+OQ8LAAB9zaDufsOMGTNixowZJ32tKIp4/PHHY/HixfGlL30pIiK+//3vx4gRI+L555+P22+//dymBQCAPqaq95Bv27Yt9uzZE1OnTu18rlKpRGtra6xfv/6k33PkyJHo6Ojo8gAAgP6iqkG+Z8+eiIgYMWJEl+dHjBjR+drvWrp0aVQqlc5HS0tLNUcCAIAeLf1TVhYtWhTt7e2djx07dmSPBAAA501Vg7yxsTEiIvbu3dvl+b1793a+9rtqa2ujvr6+ywMAAPqLqgb5mDFjorGxMVavXt35XEdHR7z55p
sxefLkam4FAAB9Qrc/ZeXgwYOxZcuWzq+3bdsWb7/9djQ0NMSoUaPi3nvvjUceeSQ+8YlPxJgxY+KBBx6I5
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "UklZv03C8Yao",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 747
},
"outputId": "3d5d76c2-c8e7-4202-d3f7-08effa1600fd"
},
"source": [
"visualize2(model, model.layers[1], img, 4, 4)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1/1 [==============================] - 0s 75ms/step\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 900x900 with 16 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAswAAALJCAYAAACgHHWpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABd8ElEQVR4nO3deXRc93ne8R8AYgcIggAJAiS4r5IoUlwkarNEOZIs1YqsUy+K2+PUiWPn2G6jxKnjNG3sxrGzOMc9aWurrq3UcZzErixLkSxrF0WKokhx30ESXEEAJEFiGwDEPv1LqRS+z0tccBZg8P38+Qzv714OfjPzasT7ICsej8cDAAAAAFN2ui8AAAAAGMsYmAEAAAAHAzMAAADgYGAGAAAAHAzMAAAAgIOBGQAAAHAwMAMAAAAOBmYAAADAMWkkf2h4eDg0NTWF0tLSkJWVlexrwgQSj8dDLBYLNTU1ITs7tf/9xr5GsrCvkYnY18hEI93XIxqYm5qaQm1tbcIuDviXGhoawqxZs1J6TvY1ko19jUzEvkYmutq+HtHAXFpamrALSre8vDz5WH9/fwqvBO+Vjj2WSfs6kdatW2fmOTk58pi6ujozv3TpUkKuabxK574uKSlJyzdx3jmnTJli5pMm2R9FJ06cSMQlJdxEfe+Ix+Ohq6srrfs6Ly9vzH3DnJ+fb+bz5s0z84ULF8q1brrpJjO/5ZZbzLyoqMjMN23aJM/x1FNPmfn+/fvlMZksHo+H/v7+q+7rEQ3MY21zXotM+rtkkrH2wT6RqeHFG5hT/b9nx4t07uusrKwx97pS+2S87Z+J/t4xEfe1R12Pes/Mzc2VaxUUFJh5cXFxpFytE4J+jx9rz2uqXe3vP77epQAAAIAUY2AGAAAAHAzMAAAAgGNE/4Z5PKqpqTHz9vb21F4IMEb9zu/8jplXVFSY+caNG+Vavb29CbkmjH/ejTPq307W19cn63Kuau7cufKxgYEBM4/FYkm6GoxV3r87rq6uNvOqqqrI51E3Sr/yyitmfurUKTN/++235TkuXLgQ+brAN8wAAACAi4EZAAAAcDAwAwAAAA4GZgAAAMDBwAwAAAA4GJgBAAAAx7ivlSsvLzdzVV/U09OTzMsJIfi/XlH9znlquZAs3/zmN838lltuMfO/+Iu/MPPXXnstYdeEzOX9St501sf9wR/8gZl3d3fLY3bv3m3m+/fvT8g1YfzIy8uTj6lfgX348OHIaz3zzDNmnooZwXvtgm+YAQAAABcDMwAAAOBgYAYAAAAcDMwAAACAg4EZAAAAcIyLloyioiL52PTp0828q6vLzFV7RgghDA4OmvmcOXPMfNWqVWbe3t4uz3Hw4EEzpyUD1+LrX/+6fOxzn/ucmX/5y18285dffjkh14SJ6eTJk2k9/49+9CMzX79+vZm/+uqrcq0zZ84k5Jow/g0NDcnHjhw5YuZlZWVmfunSpYRcE1KLb5gBAAAABwMzAAAA4GBgBgAAABwMzAAAAICDgRkAAABwjIuWDO/3m1dXV5t5T0+Pmff398u11DGqWaO1tdXMT5w4Ic9x8eJF+RhwNaoV5oYbbpDHPPnkk2b+xBNPJOSaMDHl5OSYudcmkEgf/ehHzfy2224z88OHD5v5tm3b5Dnq6+ujXxjGtaysLDOPxWLymNzcXDOnDSOz8A0zAAAA4GBgBgAAABwMzAAAAICDgRkAAABwMDADAAAAjnHRklFaWiofmzJlipmrNgrvrtXh4WEzP378uJl3dnZGOncIIcTjcfkYcDWzZs0y86amJnnMt7/97WRdzphXXFwc6c/39fWZ+eDgYCIuJ6OUl5ebeVtbW0rOv3r1ajPftWuXmT/33HNm/tprr8lzdHV1mXl+fr6Zew0h7e3tZq4+dxKpsLDQzFW7w0SWl5dn5lVVVfKYgYEBM1f7wfuZX7hwwbm6xFDNY9nZ9neoRUVFci31HqvO4e051WJ2+fJlM1czmHofv1
Z8wwwAAAA4GJgBAAAABwMzAAAA4GBgBgAAABwMzAAAAICDgRkAAABwjPtauXPnzpl5XV2dmSeywkdVDlEdh2RZu3atmWdlZcljVC3iRDBpkv0WV11dbeaqPq6lpUWeo6OjI/qFZQBVG5UqqjqqubnZzA8ePGjmqkYsBF2Zpc4xVqlaLpWHEMLkyZOTdTljmtrX6r1kNMfEYrHoFxaR9/NT73/z588389raWrnWqlWrzHzOnDlm7lXUqXlu586dZr5//34zP3z4cORzjATfMAMAAAAOBmYAAADAwcAMAAAAOBiYAQAAAAcDMwAAAOAYUy0ZU6ZMMXPvbs8LFy6YubrbPZG8O4yV3NxcMx8YGLjWy8EEcOONN5p5d3d3iq9kfFCNNeq9Jj8/38y99xOrLScejye0kSedVAOLdyd6oqxYsUI+tmDBAjNXd8Grn6HXctLW1uZcXWaznq+J0ACl9klDQ4M8ZsaMGWZ++vTphFyTZ+rUqWbuvXYWLVpk5mvWrDHze++9V66lGjRycnLM3NtD6rW4ZMkSM3/jjTfkWkpra6t5Tap15734hhkAAABwMDADAAAADgZmAAAAwMHADAAAADgYmAEAAADHmGrJqKioMHOvJSMV7QDqbs+hoSEz9663s7MzIdeEzHbXXXeZeVVVlZl/73vfS+blJNwNN9wgH1N3cDc3N5v51q1b5VrqjneVz5w508ytJgzvuuLxeOjv75fHjCelpaVm7rVLJMonPvEJ+Zj6We3fv9/M9+3bl5Brmiisz7dMaslQjVXqdVtYWCjXSkUbhvLBD37QzBcuXCiPWb58uZmrz5dnn31WrvX666+b+bZt28zca6NQLVArV640c9UQov4eIdg/x+Hh4RCLxeQx7+IbZgAAAMDBwAwAAAA4GJgBAAAABwMzAAAA4GBgBgAAABwMzAAAAIBjTNXKTZkyxcyHh4flMa2trWaenW3/t4C3lqLq7tS5U1Udl5eXZ+Yf+9jHzNyru/vlL39p5umsy5nIVqxYYeZtbW1m/sorryTzcq7quuuuM/PPfe5zZr5kyRK5lqqJ27lzZ+Tr6unpMfMdO3aYeW9vr5l7FXFW7WQm1W8VFBSk7dx33HGHfKylpcXM//Iv/zJZlzNuqc/D4uJieYxVvZhJ+1p9fqq/46RJelxSz2Mia2/VZ4J6Ly0vL5dr1dXVmflXvvIVMz9z5sxVri4xNm7cGClfunSpmS9evFieQ70WRoJvmAEAAAAHAzMAAADgYGAGAAAAHAzMAAAAgIOBGQAAAHCkpSWjsLDQzLOyssz80qVLcq1YLGbmo2nDKCoqMnN1p711F3GiffnLX5aPffjDHzbz5cuXm7lqIQkhhB/84Adm/ld/9VdmfuTIEbkWrp266/rAgQMpvpL/70tf+pJ87LHHHjPz/Px8M/8f/+N/yLW+973vmfmFCxf0xSVIOp/fsaqpqSnp51B3tXd1dclj/uAP/iBZlzNqX//61+Vjq1atMnPV2PLVr3418vmrq6vN/L777jNzr9Fpy5YtV2TDw8OySWa8UTOCaslQ80kIuklGzQje7FBTUxMpP3z4sJmrFpAQQvjHf/xH+dh4oto+Ll68KI+x5rmRtr/wDTMAAADgYGAGAAAAHAzMAAAAgIOBGQAAAHAwMAMAAACOtLRkqN9xrpoB2tra5FqJbKpQd7r29/cn7ByzZs0y8y9+8Ytmru6MDSGEvXv3mvnRo0fN3GvJOHbsmJmP9O5RRDdpkn75qfYX6871RPvWt75l5r//+78vj9mzZ4+Zf/rTn4705zH2qJagRJo8ebKZe3fz19fXJ+ty/tn69evN/G//9m/NvLa2Vq61f/9+M3/xxRejX5iQk5Nj5kuXLjXzy5cvy7Ws6x0eHg4tLS2ju7gxRs0V6n15YGBArtXd3W3mqrVEzTohhNDR0WHmqplKzUBnzpyR58h0XkvGteAbZgAAAMDBwAwAAAA4GJgBAAAABwMzAAAA4GBgBgAAABxpac
koKioyc9We4TU1eE0DUam7U6M2RcyZM0c+9uCDD5q5ugP3hRdekGu99dZbZq5+5/0tt9wi11J3DPf19cljo
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "gt8S9bzR8Yar",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 670
},
"outputId": "fb35c9b7-2219-4adc-d761-a7c9c0b087a2"
},
"source": [
"visualize2(model, model.layers[4], img, 4, 8)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1/1 [==============================] - 0s 152ms/step\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 900x900 with 32 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAs0AAAJ8CAYAAAAF2ZxRAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAsiUlEQVR4nO3de5DddXk/8O9ussnmsolpuCQbAiEQLuVqCBJFwFioWhgLQ2212IqiqK3Y21DtUFqm1FpxCkhtLdSCgKCWdKA4KJZwBycgIkIIl5JI2BAIEBJ2N7fN7jm/P36/tPgbzvd5NuecveX1mjnjH59nn+8nT85+970H8/20VKvVagEAANTUOtwbAACAkU5oBgCAgNAMAAABoRkAAAJCMwAABIRmAAAICM0AABAYnymqVCrFunXrio6OjqKlpaXZexq1qtVq0dPTU3R2dhatrYP7fcSMc8y4+cy4+cy4+cy4+cy4+cx4aKTnXE3o6uqqFkXhlXx1dXVlxmrGZjyiX2ZsxmPhZcZmPBZeZjwy5pz6pLmjoyNTNmLNmDEjrPniF7+Y6vWZz3ym5lp3d3cxd+7cXZpX9msyf5aiKIqNGzcOeg/1mD9/flizfv36VK/NmzeHNfXMeNGiRcX48bXf+suXLx9077GonhlPnDix9JONbdu27fK+xpJm3iv23XffVF3me3fHjh1hTfZTsPvvvz+saWtrS/WaOnVqzbVqtVps2rSpqTM+4YQTUnWZ+Z166qlhTfb7Zu3atWHN7NmzU73KZrFt27bi4osvbuqMszLzy/w9/OQnP0ldb6h/xjbzfnzsscem+mW+d0e7aM6p0DzaP9bP7H/SpEmpXtOmTWvI9Xb1awb7n2eGSmZfjXwf1TPj8ePHl4Zm/q96ZtzS0jLq7xtDYSTcKzLfC9VqtWHXy2jEn7FSqQyq165cP3sfycyvvb091StjwoQJYc3EiRNTvTL7auaMs7K/aEVG6s/YZt6P/Tz8X9GcR+a7AwAARhChGQAAAkIzAAAEhGYAAAgIzQAAEBCaAQAgMKjnjMyaNav0cSzr1q2re0ODkX1u4euvvx7WfP7zn0/1uuSSS2qu7XzEUTNt2LCh6dd4s3vvvTdV98wzz4Q15557br3baQjPYW4+z2EefpnHnBVF7tFjL7/8clizYsWK1PUy+vr6UnXNvh/Omzev9GfeQQcdlOqTeRzaqlWrwpp/+Zd/SV1v5syZYc0nPvGJVK/e3t6aa9u3b0/1GAqPPvpoWJN5ZOzBBx+cut6mTZvCmsyZA0VRPuNKpVL3M6FbW1tLH6V2991319V/d+KTZgAACAjNAAAQEJoBACAgNAMAQEBoBgCAgNAMAAABoRkAAAJCMwAABIRmAAAIDOpEwMypUEOpp6dnyK+5du3aIb9ms/zgBz8Ia0488cRUr4svvrje7RA45JBDwpqnnnoq1avsdCjK7bXXXqm6slPgKpVK8dJLL9W1j2nTppX+Pc6YMSPVZ/369WHNSLv3D5WWlpbSGd9///2pPitXrmzUllIyJyUecMABqV7z58+vubZ58+bi61//enpfzZS5P86ZMyesWbBgQep6hx56aFjz2muvpXp94xvfqLk2MDBQ94mAW7Zsqevr+V8+aQYAgIDQDAAAAaEZAAACQjMAAASEZgAACAjNAAAQEJoBACAgNAMAQGBQh5swtmzatCmsOeKII1K9VqxYUeduiNxzzz1hzUUXXdT0fWTMmzevaG2t/Tv56tWrh3A3jXXsscem6o488siaa9u3by8uvfTSRm3pLR1++OGpuj333DOsWbRoUVhz0003pa6XOXhjpOjt7S19H69atWoId9NYy5cvT9X9+Mc/rrnW19fXqO3Ubdy4cWHNu9/97rDmgx/8YOp6e+yxR1jz3e9+N9Xrscceq7lWrVZTPSgXHeiVnbNPmgEAICA0AwBAQGgGAICA0AwAAAGhGQAAAkIzAAAEhGYAAAgIzQAAEBjVh5
scddRRqboHHnggrLnwwgtTvS6//PJU3Wjwu7/7u8O9hbd0zDHH1FwbGBgofRD8WPZbv/VbYU3mvT4UJk6cWHrYwOLFi1N9sgcwDKXbbrstVff888/XXBsYGKh7H+3t7aUHbyxYsCDV59BDDw1rPvShD4U1p556aup6mXvonXfemerVbK+++mpD+mQOJpo8eXJYM23atNT1rr322rDmy1/+cqrXcIsOpdipu7s7rMncHx966KHU9a677rqwZtu2bale7Lp58+al6k466aTS9b6+vuI73/lO2McnzQAAEBCaAQAgIDQDAEBAaAYAgIDQDAAAAaEZAAACQjMAAASEZgAACAjNAAAQGNUnAr7xxhupur6+vrCm7PSu3VnZqW5vdtFFF4U1s2bNSvX62te+VnOtESepjVYj5bS/jAULFhRtbW0115988slUn/322y+s6e/vD2tefPHF1PUaKftn3FVz5swp/f585zvfmerT1dUV1tx1111hzaZNm1LXGymn/Q2l6DSyoiiK9evXhzUf/vCHU9fLnEA4WlSr1VTdqlWrwpoHH3yw3u0wwmSzW6Mynk+aAQAgIDQDAEBAaAYAgIDQDAAAAaEZAAACQjMAAASEZgAACAjNAAAQGNWHm2QfVj1z5szmbmQEOuyww8KaU045Jaw54IADUtfL1F166aWpXitWrEjVseumTJlSc61arRZbtmypq//5559fTJ06teb6Qw89lOpz8803hzUvvfRSWDMch5s0289+9rPS9aVLl6b69PT0hDW33HJLWLN169bU9UaTww8/vPQAmZ///OepPi0tLY3aEjW8/PLLw72FXTZx4sSaa9VqNXVAW5ljjjmmGD++dtwrO4jqzTL32jVr1oQ1ZXt5s23btqXqhpJPmgEAICA0AwBAQGgGAICA0AwAAAGhGQAAAkIzAAAEhGYAAAgIzQAAEBCaAQAgMKpPBNwdTZgwIVU3e/bssGbGjBlhzcMPP5y63nnnnZeqG0sOOuigsObZZ59t2PUyJ1tOnjw51atSqZSu1Xsi4Pbt20tPffrsZz+b6jN37tyw5k/+5E/CmuyJVzt27EjVjQZXXXXVcG9h1HM6KbVMmjQprOno6Ej1euWVV+rdTqlrr722dC/77rtvqk9vb29Yk8kMzz33XOp6jz/+eFjzT//0T6lejeKTZgAACAjNAAAQEJoBACAgNAMAQEBoBgCAgNAMAAABoRkAAAKp5zRXq9Vm72NM2ZV5Zb8mW9ff3x/WbNu2Lazp6+tLXW+oNXPGWQMDAw3tFyl7tvJgaqK6nWv1zHjz5s2ldd3d3al+medFZ/7MI/UeNhLex2OdGTff7jjjzP6z9+NGXa/W10TPV87ejzPPaY7u/UVRFFu3bk1dbzjyRzTnVGju6elpyGZ2Fz09PcX06dMH/TUZ2cMX7rnnnobUjFTNnHHWqlWrGtovsnHjxobUZNUz4zPOOKNh+xjLRsL7eKwz4+bbHWec+dApU5NVz4wXL17csH2MddGcW6qJX18qlUqxbt26oqOjo2hpaWnoBseSarVa9PT0FJ2dnUVr6+D+ny9mnGPGzWfGzWfGzWfGzWfGzWfGQyM751RoBgCA3Zl/CAgAAAGhGQAAAkIzAAAEhGYAAAgIzQAAEBCaAQAgIDQDAEBAaAYAgIDQDAAAAaEZAAACQjMAAASEZgAACAjNAAAQEJoBACAgNAMAQEBoBgCAgNAMAAABoRkAAAJCMwAABIRmAAAICM0AABAQmgEAICA0AwBAQGgGAICA0AwAAAGhGQAAAkIzAAAEhGYAAAgIzQAAEBCaAQAgIDQDAEBAaAYAgIDQDAAAAaEZAAACQjMAAASEZgAACAjNAAAQEJoBACAgNAMAQEBoBgCAgNAMAAABoRkAAAJCMwAABIRmAAAICM0AABAQmgEAICA0AwBAQGgGAICA0AwAAAGhGQAAAkIzAAAEhGYAAAgIzQAAEBCaAQAgIDQDAEBAaAYAgIDQDAAAAaEZAA
ACQjMAAASEZgAACAjNAAAQEJoBACAgNAMAQEBoBgCAgNAMAAABoRkAAAJCMwAABIRmAAAICM0AABAQmgEAI
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "okUztZOd79by"
},
"source": [
"# Transfer learning"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8MRwLZgGhK-K"
},
"source": [
"## Helper functions"
]
},
{
"cell_type": "code",
"metadata": {
"id": "1AIiBfX78luv"
},
"source": [
"# Load the specified subset of digits from the MNIST dataset\n",
"def get_mnist(digits):\n",
"    # Reload the full MNIST data\n",
"    (X_train, y_train), (X_test, y_test) = mnist.load_data()\n",
"\n",
"    # Boolean masks selecting only the requested digits\n",
"    # (np.isin replaces np.in1d, which is deprecated since NumPy 2.0)\n",
"    train_ids = np.isin(y_train, digits)\n",
"    test_ids = np.isin(y_test, digits)\n",
"    X_train = X_train[train_ids]\n",
"    X_test = X_test[test_ids]\n",
"    y_train = y_train[train_ids]\n",
"    y_test = y_test[test_ids]\n",
"\n",
"    # Add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1)\n",
"    X_train = np.expand_dims(X_train, axis=-1)\n",
"    X_test = np.expand_dims(X_test, axis=-1)\n",
"\n",
"    X_train = X_train.astype('float32')  # change integers to 32-bit floating point numbers\n",
"    X_test = X_test.astype('float32')\n",
"\n",
"    X_train /= 255  # scale pixel values from [0, 255] to [0, 1]\n",
"    X_test /= 255\n",
"\n",
"    return X_train, y_train, X_test, y_test"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Zf0T6tdPBC70"
},
"source": [
"# Prepare the base model\n",
"def prepare_model(input_shape, class_count):\n",
"    model = Sequential()  # Linear stacking of layers\n",
"\n",
"    # Convolution Layer 1\n",
"    model.add(Conv2D(16, (3, 3), padding=\"same\", input_shape=input_shape))\n",
"    model.add(Activation('relu'))\n",
"    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"\n",
"    # Convolution Layer 2\n",
"    model.add(Conv2D(32, (3, 3), padding=\"same\"))\n",
"    model.add(Activation('relu'))\n",
"    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"\n",
"    # Convolution Layer 3\n",
"    model.add(Conv2D(64, (3, 3), padding=\"same\"))\n",
"    model.add(Activation('relu'))\n",
"    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"\n",
"    model.add(Flatten())\n",
"\n",
"    # Fully Connected Layer\n",
"    model.add(Dense(64))\n",
"    model.add(Activation('relu'))\n",
"\n",
"    # Output layer: one unit per class, softmax yields class probabilities\n",
"    model.add(Dense(class_count))\n",
"    model.add(Activation('softmax'))\n",
"\n",
"    return model"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "EX6bXUPBGDyx"
},
"source": [
"# One-hot encoding of non-consecutive labels (e.g. digits 5..9)\n",
"def one_hot(labels):\n",
"    from sklearn.preprocessing import OneHotEncoder\n",
"    # Columns follow the sorted unique labels, so column 0 is the smallest digit\n",
"    encoder = OneHotEncoder(categories='auto')\n",
"    l = labels.reshape(-1, 1)\n",
"    output = encoder.fit_transform(l)  # returns a sparse matrix\n",
"    return output.toarray()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "-VP_9oFFiE6a"
},
"source": [
"def test_model(model, X_test, Y_test, y_test, digits):\n",
"    from sklearn.metrics import confusion_matrix\n",
"\n",
"    score = model.evaluate(X_test, Y_test)\n",
"    print('Test score:', score[0])\n",
"    print('Test accuracy:', score[1])\n",
"\n",
"    # model.predict returns class probabilities (predict_classes was removed in newer\n",
"    # Keras versions); argmax picks the most probable output unit for each example.\n",
"    predicted = model.predict(X_test)\n",
"    predicted_classes = np.argmax(predicted, axis=1)\n",
"\n",
"    # Map output-unit indices back to digit labels (index i -> digits[i])\n",
"    predicted_digits = np.array([digits[x] for x in predicted_classes])\n",
"\n",
"    # Check which items we got right / wrong\n",
"    correct_indices = np.nonzero(predicted_digits == y_test)[0]\n",
"    incorrect_indices = np.nonzero(predicted_digits != y_test)[0]\n",
"\n",
"    cnf_matrix = confusion_matrix(y_test, predicted_digits)\n",
"\n",
"    class_names = [str(i) for i in digits]\n",
"\n",
"    # Plot non-normalized confusion matrix\n",
"    plt.figure()\n",
"    plot_confusion_matrix(cnf_matrix, classes=class_names,\n",
"                          title='Confusion matrix, without normalization')\n",
"\n",
"    plt.show()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Fi7BwsX8h5Dm"
},
"source": [
"## Training the base model: digits 0..4"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZyAj3Kk4Bw4d",
"outputId": "57531b4b-9442-40c0-ef12-ab3083bacfa8"
},
"source": [
"digits = [0, 1, 2, 3, 4]\n",
"\n",
"X_train, y_train, X_test, y_test = get_mnist(digits)\n",
"\n",
"Y_train = one_hot(y_train)\n",
"Y_test = one_hot(y_test)\n",
"\n",
"model = prepare_model( (28, 28, 1), 5)\n",
"\n",
"adam = tf.optimizers.Adam(learning_rate=0.001)\n",
"model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])\n",
"\n",
"model.summary()\n",
"\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model: \"sequential_2\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" conv2d_2 (Conv2D) (None, 28, 28, 16) 160 \n",
" \n",
" activation_7 (Activation) (None, 28, 28, 16) 0 \n",
" \n",
" max_pooling2d_2 (MaxPoolin (None, 14, 14, 16) 0 \n",
" g2D) \n",
" \n",
" conv2d_3 (Conv2D) (None, 14, 14, 32) 4640 \n",
" \n",
" activation_8 (Activation) (None, 14, 14, 32) 0 \n",
" \n",
" max_pooling2d_3 (MaxPoolin (None, 7, 7, 32) 0 \n",
" g2D) \n",
" \n",
" conv2d_4 (Conv2D) (None, 7, 7, 64) 18496 \n",
" \n",
" activation_9 (Activation) (None, 7, 7, 64) 0 \n",
" \n",
" max_pooling2d_4 (MaxPoolin (None, 3, 3, 64) 0 \n",
" g2D) \n",
" \n",
" flatten_1 (Flatten) (None, 576) 0 \n",
" \n",
" dense_5 (Dense) (None, 64) 36928 \n",
" \n",
" activation_10 (Activation) (None, 64) 0 \n",
" \n",
" dense_6 (Dense) (None, 5) 325 \n",
" \n",
" activation_11 (Activation) (None, 5) 0 \n",
" \n",
"=================================================================\n",
"Total params: 60549 (236.52 KB)\n",
"Trainable params: 60549 (236.52 KB)\n",
"Non-trainable params: 0 (0.00 Byte)\n",
"_________________________________________________________________\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "BjEZRMnvChne"
},
"source": [
"gen = ImageDataGenerator(rotation_range=8, width_shift_range=0.08, shear_range=0.3,\n",
" height_shift_range=0.08, zoom_range=0.08, validation_split=0.2)\n",
"\n",
"test_gen = ImageDataGenerator()\n",
"\n",
"train_generator = gen.flow(X_train, Y_train, batch_size=128, subset='training')\n",
"valid_generator = gen.flow(X_train, Y_train, batch_size=128, subset='validation')\n",
"test_generator = test_gen.flow(X_test, Y_test, batch_size=128)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "byGKttIXDfE3",
"outputId": "8247c380-84e0-4ce9-d5ed-ea4d9bcd1adc"
},
"source": [
"# fit() accepts generators (fit_generator is deprecated); omitting steps_per_epoch uses len(train_generator) batches per epoch, so the input no longer runs out of data\n",
"model.fit(train_generator, epochs=5, verbose=1, validation_data=valid_generator)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/5\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"<ipython-input-49-1bedfa515cd7>:1: UserWarning: `Model.fit_generator` is deprecated and will be removed in a future version. Please use `Model.fit`, which supports generators.\n",
" model.fit_generator(train_generator, steps_per_epoch=25000//128, epochs=5, verbose=1, validation_data=valid_generator, validation_steps = 5000 // 128)\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"191/195 [============================>.] - ETA: 0s - loss: 0.2686 - accuracy: 0.9136"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 975 batches). You may need to use the repeat() function when building your dataset.\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r195/195 [==============================] - 12s 49ms/step - loss: 0.2677 - accuracy: 0.9139 - val_loss: 0.0842 - val_accuracy: 0.9722\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<keras.src.callbacks.History at 0x7ac88c26df00>"
]
},
"metadata": {},
"execution_count": 49
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "3sUGcowEMLHk",
"outputId": "2d4ff1d7-bf82-4350-d63f-16165c51be19"
},
"source": [
"test_model(model, X_test, Y_test, y_test, digits)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"161/161 [==============================] - 1s 3ms/step - loss: 0.0327 - accuracy: 0.9885\n",
"Test score: 0.03266139328479767\n",
"Test accuracy: 0.9885191917419434\n",
"161/161 [==============================] - 0s 2ms/step\n",
"Confusion matrix, without normalization\n",
"[[ 976 3 1 0 0]\n",
" [ 0 1130 4 0 1]\n",
" [ 15 7 998 7 5]\n",
" [ 1 0 4 1002 3]\n",
" [ 1 5 2 0 974]]\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 900x900 with 2 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1UAAAN6CAYAAABmBWMlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB8BUlEQVR4nOzdeVhU5d/H8c8ZlEUQcAPEFbVcUjSXjNyXn7iUlaVZlrhkZS6pqeVTmZpl2maaqW0uZbm0WFpZ5m7ummVu5ZamAioKggsK8/xBTI1ogQc8Z/D98jpXzZkzM5/DjDjf+d73PYbT6XQKAAAAAHBVHFYHAAAAAABPRlEFAAAAACZQVAEAAACACRRVAAAAAGACRRUAAAAAmEBRBQAAAAAmUFQBAAAAgAkUVQAAAABgQgGrAwAAAAC4OufOnVNqaqrVMbLF29tbvr6+VsfIExRVAAAAgAc6d+6c/AoXky6esTpKtoSFhWn//v35srCiqAIAAAA8UGpqqnTxjHyqxUhe3lbH+XdpqYrdMUOpqakUVQAAAABsxstbhs2LKqfVAfIYRRUAAADgyQxHxmZnds9nUv4+OwAAAADIYxRVAAAAAGACRRUAAAAAmMCcKgAAAMCTGZIMw+oU/87m8cyiUwUAAAAAJlBUAQAAAIAJDP8DAAAAPBlLqlsuf58dAAAAAOQxiioAAAAAMIHhfwAAAIAnMwwPWP3P5vlMolMFAAAAACZQVAEAAACACRRVAAAAAGACc6oAAAAAT8aS6pbL32cHAAAAAHmMogoAAAAATGD4HwAAAODJWFLdcnSqAAAAAMAEiioAAAAAMIGiCgAAAABMYE4VAAAA4NE8YEn1fN7Lyd9nBwAAAAB5jKIKAAAAAExg+B8AAADgyVhS3XJ0qgAAAADABIoqAAAAADCB4X8AAACAJzM8YPU/u+czKX+fHQAAAADkMYoqAAAAADCBogoAAAAATGBOFQAAAODJWFLdcnSqAAAAAMAEiioAAAAAMIHhfwAAAIAnY0l1y+XvswMAAACAPEZRBQAAAAAmMPwPAAAA8GSs/mc5OlUAAAAAYAJFFQAAAACYQFEFAAAAACYwpwoAAADwZCypbrn8fXYAAAAAkMcoqgAAAADABIb/AQAAAJ7MMOw/vI4l1QEAAAAAV0JRBQAAAAAmMPwPAAAA8GQOI2OzM7vnM4lOFQAAAACYQFEFAAAAACZQVAEAAACACcypAgAAADyZ4fCAJdVtns+k/H12AAAAAJDHKKoAAAAAwASG/wEAAACezDAyNjuzez6T6FQBAAAAgAkUVQAAAABgAkUVAAAAAJjAnCoAAADAk7GkuuXy99kBAAAAQB6jqAIAAAAAExj+BwAAAHgyllS3HJ0qAAAAADCBogoAAAAATGD4HwAAAODJWP3Pcvn77AAAAAAgj1FUAQAAAIAJFFUAAAAAYAJzqgAAAABPxpLqlqNTBQAAAAAmUFQBAAAAgAkM/wMAAAA8GUuqWy5/nx0AAAAA5DGKKiAf+/3339WqVSsFBQXJMAzNnz8/V+//wIEDMgxD06dPz9X7zQ/Kly+vbt26WR0ji5w8Z5nHvvrqq3kfDJc1YsQIGZdM7rbqtWXX1zQA2AFFFZDH9u7dq0cffVQVKlSQr6+vAgMD1aBBA7355ps6e/Zsnj52TEyMtm3bphdffFEffvih6tatm6ePlx/t2LFDI0aM0IEDB6yOkme++eYbjRgxwuoYWbz00ku5/kEA/t2aNWs0YsQInTp1yuooAHIic/U/u2/5GHOqgDz09ddfq2PHjvLx8VHXrl1VvXp1paamavXq1RoyZIi2b9+ud955J08e++zZs1q7dq2eeeYZ9e3bN08eo1y5cjp79qwKFiyYJ/dvBzt27NDIkSPVtGlTlS9fPtu32717txwO+31udbnn7JtvvtGkSZNsV1i99NJLuvfee3XXXXdZHcVW8vK1tWbNGo0cOVLdunVTcHDwNXtcAPB0FF
VAHtm/f786d+6scuXKaenSpSpZsqTruj59+mjPnj36+uuv8+zxjx07JklZ3hjlJsMw5Ovrm2f372mcTqfOnTsnPz8/+fj4WB3nsnjOzElJSZG/v7+lGax6bdn1NQ0AdsBHTkAeGTdunJKTk/X++++7FVSZKlWqpCeeeMJ1+eLFi3rhhRdUsWJF+fj4qHz58vq///s/nT9/3u125cuX1+23367Vq1frlltuka+vrypUqKCZM2e6jhkxYoTKlSsnSRoyZIgMw3B1Wbp163bZjsvl5m4sXrxYDRs2VHBwsAICAlS5cmX93//9n+v6K83PWbp0qRo1aiR/f38FBwfrzjvv1M6dOy/7eHv27HF9Kh4UFKTu3bvrzJkzV/7B/qVp06aqXr26fvnlFzVp0kSFChVSpUqV9Omnn0qSVqxYofr168vPz0+VK1fWDz/84Hb7P/74Q48//rgqV64sPz8/FStWTB07dnQb5jd9+nR17NhRktSsWTMZhiHDMLR8+XJJfz8X3333nerWrSs/Pz9NnTrVdV3m/BOn06lmzZqpRIkSio+Pd91/amqqatSooYoVKyolJeU/z/mfBg0apGLFisnpdLr29evXT4ZhaMKECa59cXFxMgxDkydPlpT1OevWrZsmTZokSa7zu/R1IEnvvPOO67VZr149bdy4Mcsx2Xnes/v6MwxDKSkpmjFjhivTv83nWb58uQzD0Ny5c/Xiiy+qdOnS8vX1VYsWLbRnz54sx8+bN0916tSRn5+fihcvrgcffFCHDx/OkjUgIEB79+5V27ZtVbhwYXXp0sWVr2/fvpo3b56qVasmPz8/RUVFadu2bZKkqVOnqlKlSvL19VXTpk2zDB9dtWqVOnbsqLJly8rHx0dlypTRwIEDszUk+NK5Tf983i7dMh/3l19+Ubdu3VzDkMPCwtSjRw+dOHHC7TkYMmSIJCkiIiLLfVxuTtW+ffvUsWNHFS1aVIUKFdKtt96a5cOinD43AOCJ6FQBeWTBggWqUKGCbrvttmwd//DDD2vGjBm699579eSTT2r9+vUaM2aMdu7cqS+++MLt2D179ujee+9Vz549FRMTow8++EDdunVTnTp1dNNNN6lDhw4KDg7WwIEDdf/996tt27YKCAjIUf7t27fr9ttvV2RkpEaNGiUfHx/t2bNHP/7447/e7ocfflCbNm1UoUIFjRgxQmfPntXEiRPVoEEDbdmyJcsb6k6dOikiIkJjxozRli1b9N577ykkJERjx479z4wnT57U7bffrs6dO6tjx46aPHmyOnfurFmzZmnAgAF67LHH9MADD+iVV17Rvffeq0OHDqlw4cKSpI0bN2rNmjXq3LmzSpcurQMHDmjy5Mlq2rSpduzYoUKFCqlx48bq37+/JkyYoP/7v/9T1apVJcn1XyljSNT999+vRx99VL169VLlypWz5DQMQx988IEiIyP12GOP6fPPP5ckPf/889q+fbuWL1+e4+5Ho0aN9MYbb2j79u2qXr26pIw36g6HQ6tWrVL//v1d+ySpcePGl72fRx99VEeOHNHixYv14YcfXvaYjz/+WKdPn9ajjz4qwzA0btw4dejQQfv27XMNI8zp8/5fPvzwQz388MO65ZZb9Mgjj0iSKlas+J+3e/nll+VwODR48GAlJiZq3Lhx6tKli9avX+86Zvr06erevbvq1aunMWPGKC4uTm+++aZ+/PFH/fTTT27d3YsXLyo6OloNGzbUq6++qkKFCrmuW7Vqlb766iv16dNHkjRmzBjdfvvtGjp0qN5++209/vjjOnnypMaNG6cePXpo6dKlrtvOmzdPZ86cUe/evVWsWDFt2LBBEydO1J9//ql58+bl+Gd1qWeffVbx8fGuv/eLFy/Wvn371L17d4WFhbmGHm/fvl3r1q2TYRjq0KGDfvvtN33yySd64403VLx4cUlSiRIlLvu4cXFxuu2223TmzBn1799fxYoV04wZM9S+fXt9+umnuvvuu3P83AC4Wh
6wpHp+7+U4AeS6xMREpyTnnXfema3jt27d6pTkfPjhh932Dx482CnJuXTpUte+cuXKOSU5V65c6doXHx/v9
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0fkbvsjbiIi5"
},
"source": [
"## Transferring knowledge to a new model: digits 5..9"
]
},
{
"cell_type": "code",
"metadata": {
"id": "MfJ3UmNrMwfC"
},
"source": [
"# Extract the specified range of layers into a new Sequential model\n",
"def extract_layers(main_model, starting_layer_ix, ending_layer_ix):\n",
"    # create an empty model\n",
"    new_model = Sequential()\n",
"    for ix in range(starting_layer_ix, ending_layer_ix + 1):\n",
"        curr_layer = main_model.get_layer(index=ix)\n",
"        # Add the same layer object (not a copy): its weights are shared with\n",
"        # main_model, so changing e.g. `trainable` here affects both models\n",
"        new_model.add(curr_layer)\n",
"    return new_model"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "aAY_aLnOiTct"
},
"source": [
"The new model is built on top of the trained base model. The part responsible for detecting features in the image (the convolutional layers) is reused as the foundation of the new network. The classification part is added from scratch, with randomly initialized weights.\n",
"\n",
"The convolutional layers were already trained in the base model, so they are frozen. Only the classification part is trained."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qzDfXKIMNXMK",
"outputId": "2c637b2d-91a2-4e18-a317-a26654b73cf8"
},
"source": [
"# Reuse layers 0..9 of the base model (the convolutional part up to Flatten)\n",
"new_model = extract_layers(model, 0, 9)\n",
"\n",
"# Fully Connected Layer (new, randomly initialized)\n",
"new_model.add(Dense(64))\n",
"new_model.add(Activation('relu'))\n",
"\n",
"# Output layer for the 5 new classes\n",
"new_model.add(Dense(5))\n",
"new_model.add(Activation('softmax'))\n",
"\n",
"# Freeze the transferred layers so that only the new dense layers are trained\n",
"for ix in range(0, 9 + 1):\n",
"    new_model.get_layer(index=ix).trainable = False\n",
"\n",
"new_digits = [5, 6, 7, 8, 9]\n",
"\n",
"X_train, y_train, X_test, y_test = get_mnist(new_digits)\n",
"\n",
"Y_train = one_hot(y_train)\n",
"Y_test = one_hot(y_test)\n",
"\n",
"# Compile after changing `trainable` so that the freezing takes effect\n",
"new_model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), metrics=['accuracy'])\n",
"\n",
"new_model.summary()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model: \"sequential_3\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" conv2d_2 (Conv2D) (None, 28, 28, 16) 160 \n",
" \n",
" activation_7 (Activation) (None, 28, 28, 16) 0 \n",
" \n",
" max_pooling2d_2 (MaxPoolin (None, 14, 14, 16) 0 \n",
" g2D) \n",
" \n",
" conv2d_3 (Conv2D) (None, 14, 14, 32) 4640 \n",
" \n",
" activation_8 (Activation) (None, 14, 14, 32) 0 \n",
" \n",
" max_pooling2d_3 (MaxPoolin (None, 7, 7, 32) 0 \n",
" g2D) \n",
" \n",
" conv2d_4 (Conv2D) (None, 7, 7, 64) 18496 \n",
" \n",
" activation_9 (Activation) (None, 7, 7, 64) 0 \n",
" \n",
" max_pooling2d_4 (MaxPoolin (None, 3, 3, 64) 0 \n",
" g2D) \n",
" \n",
" flatten_1 (Flatten) (None, 576) 0 \n",
" \n",
" dense_7 (Dense) (None, 64) 36928 \n",
" \n",
" activation_12 (Activation) (None, 64) 0 \n",
" \n",
" dense_8 (Dense) (None, 5) 325 \n",
" \n",
" activation_13 (Activation) (None, 5) 0 \n",
" \n",
"=================================================================\n",
"Total params: 60549 (236.52 KB)\n",
"Trainable params: 37253 (145.52 KB)\n",
"Non-trainable params: 23296 (91.00 KB)\n",
"_________________________________________________________________\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "quqhC141PBTO"
},
"source": [
"gen = ImageDataGenerator(rotation_range=8, width_shift_range=0.08, shear_range=0.3,\n",
" height_shift_range=0.08, zoom_range=0.08, validation_split=0.2)\n",
"\n",
"test_gen = ImageDataGenerator()\n",
"\n",
"train_generator = gen.flow(X_train, Y_train, batch_size=128, subset='training')\n",
"valid_generator = gen.flow(X_train, Y_train, batch_size=128, subset='validation')\n",
"test_generator = test_gen.flow(X_test, Y_test, batch_size=128)"
],
"execution_count": null,
"outputs": []
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "oqy-_NxblqDy"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"Training the new model"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"id": "l0H5XRxWPKvC",
|
|||
|
|
"outputId": "8707e644-66da-4ad4-b04d-b1d8585bdf73"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# Model.fit accepts generators directly (fit_generator is deprecated); steps per epoch are inferred from len(train_generator)\n",
"new_model.fit(train_generator, epochs=5, verbose=1, validation_data=valid_generator)"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"Epoch 1/5\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stderr",
|
|||
|
|
"text": [
|
|||
|
|
"<ipython-input-54-05ca407e5c2c>:1: UserWarning: `Model.fit_generator` is deprecated and will be removed in a future version. Please use `Model.fit`, which supports generators.\n",
|
|||
|
|
" new_model.fit_generator(train_generator, steps_per_epoch=24000//128, epochs=5, verbose=1, validation_data=valid_generator, validation_steps = 6000 // 128)\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"183/187 [============================>.] - ETA: 0s - loss: 0.3467 - accuracy: 0.8922"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stderr",
|
|||
|
|
"text": [
|
|||
|
|
"WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 935 batches). You may need to use the repeat() function when building your dataset.\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r187/187 [==============================] - 10s 48ms/step - loss: 0.3461 - accuracy: 0.8922 - val_loss: 0.1978 - val_accuracy: 0.9406\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "execute_result",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<keras.src.callbacks.History at 0x7ac8819626b0>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"execution_count": 54
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 1000
|
|||
|
|
},
|
|||
|
|
"id": "FRuh3MN-PeDr",
|
|||
|
|
"outputId": "2e6dbbb2-3fd5-4e9c-acb6-a93d2b72892b"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"test_model(new_model, X_test, Y_test, y_test, new_digits)"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"152/152 [==============================] - 0s 3ms/step - loss: 0.1151 - accuracy: 0.9656\n",
|
|||
|
|
"Test score: 0.1151459813117981\n",
|
|||
|
|
"Test accuracy: 0.9656449556350708\n",
|
|||
|
|
"152/152 [==============================] - 0s 2ms/step\n",
|
|||
|
|
"Confusion matrix, without normalization\n",
|
|||
|
|
"[[876 1 3 9 3]\n",
|
|||
|
|
" [ 3 949 0 6 0]\n",
|
|||
|
|
" [ 0 0 999 11 18]\n",
|
|||
|
|
" [ 18 3 10 936 7]\n",
|
|||
|
|
" [ 18 4 26 27 934]]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 2 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0wAAAN6CAYAAAC9vskHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB/l0lEQVR4nOzdeZyN5f/H8fd9hlnMZp0ZYx1kCymksWbJEpWUvoqMtcWWJaIiEiKVSGTJkiRUikT2JbsiIbLLmoYZgzFj5vz+kPPrNO5pppy578Pr6XE/6tznPue8z8ztOJ/zua7rGE6n0ykAAAAAQBoOqwMAAAAAgF1RMAEAAACACQomAAAAADBBwQQAAAAAJiiYAAAAAMAEBRMAAAAAmKBgAgAAAAATFEwAAAAAYCKb1QEAAAAA/DuJiYlKSkqyOkaG+Pr6yt/f3+oYmUbBBAAAAHihxMREBQTnka5esjpKhkREROjQoUNeVzRRMAEAAABeKCkpSbp6SX5lYyQfX6vjpC8lSad2T1dSUhIFEwAAAIAs5OMrw+YFk9PqAP8BBRMAAADgzQzHtc3O7J4vHd6bHAAAAAA8jIIJAAAAAExQMAEAAACACeYwAQAAAN7MkGQYVqdIn83jpYcOEwAAAACYoGACAAAAABMMyQMAAAC8GcuKe5T3JgcAAAAAD6NgAgAAAAATDMkDAAAAvJlheMEqeTbPlw46TAAAAABggoIJAAAAAExQMAEAAACACeYwAQAAAN6MZcU9ynuTAwAAAICHUTABAAAAgAmG5AEAAADejGXFPYoOEwAAAACYoGACAAAAABMUTAAAAABggjlMAAAAgFfzgmXFvbhP473JAQAAAMDDKJgAAAAAwARD8gAAAABvxrLiHkWHCQAAAABMUDABAAAAgAmG5AEAAADezPCCVfLsni8d3pscAAAAADyMggkAAAAATFAwAQAAAIAJ5jABAAAA3oxlxT2KDhMAAAAAmKBgAgAAAAATDMkDAAAAvBnLinuU9yYHAAAAAA+jYAIAAAAAEwzJAwAAALwZq+R5FB0mAAAAADBBwQQAAAAAJiiYAAAAAMAEc5gAAAAAb8ay4h7lvckBAAAAwMMomAAAAADABEPyAAAAAG9mGPYf8say4gAAAABw66FgAgAAAAATDMkDAAAAvJnDuLbZmd3zpYMOEwAAAACYoGACAAAAABMUTAAAAABggjlMAAAAgDczHF6wrLjN86XDe5MDAAAAgIdRMAEAAACACYbkAQAAAN7MMK5tdmb3fOmgwwQAAAAAJiiYAAAAAMAEBRMAAAAAmGAOEwAAAODNWFbco7w3OQAAAAB4GAUTAAAAAJhgSB4AAADgzVhW3KPoMAEAAACACQomAAAAADDBkDwAAADAm7FKnkd5b3IAAAAA8DAKJgAAAAAwQcEEAAAAACaYwwQAAAB4M5YV9yg6TAAAAABggoIJAAAAAEwwJA8AAADwZiwr7lHemxwAAAAAPIyCCbiF/frrr2rQoIFCQ0NlGIbmz59/U+//8OHDMgxD06ZNu6n3eysoWrSo2rZta3WMNDLzO7t+7KhRozwfDDc0aNAgGX+bKG3VuWXXcxoAPI2CCfCwAwcO6Nlnn1WxYsXk7++vkJAQVa9eXe+9954uX77s0ceOiYnRzp07NXToUH388ceqXLmyRx/vVrR7924NGjRIhw8ftjqKxyxatEiDBg2yOkYaw4YNu+lFPtK3fv16DRo0SOfPn7c6CoDMuL5Knt03L8UcJsCDvvnmG7Vo0UJ+fn5q06aNypUrp6SkJK1bt059+vTRrl27NHHiRI889uXLl7Vhwwa98sor6tq1q0ceo0iRIrp8+bKyZ8/ukfu3g927d2vw4MG6//77VbRo0Qzfbu/evXI47PeZ1I1+Z4sWLdK4ceNsVzQNGzZMjz/+uJo1a2Z1FFvx5Lm1fv16DR48WG3btlXOnDmz7HEBwM4omAAPOXTokFq2bKkiRYpoxYoVyp
8/v+u6Ll26aP/+/frmm2889vi///67JKV503MzGYYhf39/j92/t3E6nUpMTFRAQID8/PysjnND/M7+m4sXLyowMNDSDFadW3Y9pwHA0/ioCPCQkSNHKiEhQVOmTHErlq4rUaKEXnjhBdflq1evasiQISpevLj8/PxUtGhRvfzyy7py5Yrb7YoWLaqmTZtq3bp1uvfee+Xv769ixYppxowZrmMGDRqkIkWKSJL69OkjwzBc3ZG2bdvesFNyo7kSS5cuVY0aNZQzZ04FBQWpVKlSevnll13Xm82HWbFihWrWrKnAwEDlzJlTjzzyiPbs2XPDx9u/f7/r0+zQ0FC1a9dOly5dMv/B/un+++9XuXLl9NNPP6l27drKkSOHSpQooXnz5kmSVq9erapVqyogIEClSpXSsmXL3G5/5MgRde7cWaVKlVJAQIDy5MmjFi1auA29mzZtmlq0aCFJqlOnjgzDkGEYWrVqlaT//10sWbJElStXVkBAgD788EPXddfnezidTtWpU0f58uXTmTNnXPeflJSk8uXLq3jx4rp48eI/Pue/6tWrl/LkySOn0+na161bNxmGoTFjxrj2nT59WoZhaPz48ZLS/s7atm2rcePGSZLr+f39PJCkiRMnus7NKlWqaMuWLWmOycjvPaPnn2EYunjxoqZPn+7KlN78mVWrVskwDM2ZM0dDhw5VwYIF5e/vr3r16mn//v1pjp87d64qVaqkgIAA5c2bV61bt9bx48fTZA0KCtKBAwf04IMPKjg4WK1atXLl69q1q+bOnauyZcsqICBA0dHR2rlzpyTpww8/VIkSJeTv76/7778/zZDOtWvXqkWLFipcuLD8/PxUqFAh9ezZM0PDdP8+l+ivv7e/b9cf96efflLbtm1dQ4MjIiLUvn17/fHHH26/gz59+kiSoqKi0tzHjeYwHTx4UC1atFDu3LmVI0cO3XfffWk+CMrs7wYA7IYOE+AhCxYsULFixVStWrUMHd+xY0dNnz5djz/+uHr37q1NmzZp+PDh2rNnj7788ku3Y/fv36/HH39cHTp0UExMjD766CO1bdtWlSpV0p133qnmzZsrZ86c6tmzp5588kk9+OCDCgoKylT+Xbt2qWnTpqpQoYJef/11+fn5af/+/fr+++/Tvd2yZcvUuHFjFStWTIMGDdLly5c1duxYVa9eXT/88EOaN8tPPPGEoqKiNHz4cP3www+aPHmywsLCNGLEiH/MeO7cOTVt2lQtW7ZUixYtNH78eLVs2VKffPKJevTooeeee05PPfWU3nrrLT3++OM6duyYgoODJUlbtmzR+vXr1bJlSxUsWFCHDx/W+PHjdf/992v37t3KkSOHatWqpe7du2vMmDF6+eWXVaZMGUly/Ve6NkzpySef1LPPPqtOnTqpVKlSaXIahqGPPvpIFSpU0HPPPacvvvhCkvTaa69p165dWrVqVaa7FjVr1tS7776rXbt2qVy5cpKuvQl3OBxau3atunfv7tonSbVq1brh/Tz77LM6ceKEli5dqo8//viGx8yaNUsXLlzQs88+K8MwNHLkSDVv3lwHDx50De3L7O/9n3z88cfq2LGj7r33Xj3zzDOSpOLFi//j7d588005HA69+OKLiouL08iRI9WqVStt2rTJdcy0adPUrl07ValSRcOHD9fp06f13nvv6fvvv9ePP/7o1pW9evWqGjZsqBo1amjUqFHKkSOH67q1a9fq66+/VpcuXSRJw4cPV9OmTdW3b1998MEH6ty5s86dO6eRI0eqffv2WrFiheu2c+fO1aVLl/T8888rT5482rx5s8aOHavffvtNc+fOzfTP6u9effVVnTlzxvX3funSpTp48KDatWuniIgI13DgXbt2aePGjTIMQ82bN9e+ffv06aef6t1331XevHklSfny5bvh454+fVrVqlXTpUuX1L17d+XJk0fTp0/Xww8/rHnz5unRRx/N9O8GwL/lBcuKe3OfxgngpouLi3NKcj7yyC
MZOn779u1OSc6OHTu67X/xxRedkpwrVqxw7StSpIhTknPNmjWufWfOnHH6+fk5e/fu7dp36NAhpyTnW2+95
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "2zgQd-nJjU2F"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## Fine-tuning"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "7kctU-jel8k-"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"Once the classification part has been trained, all layers can be unfrozen again and the whole network fine-tuned on the new data, without the risk posed by large gradients in the early phases of training."
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"id": "yu4ks06pR1aO",
|
|||
|
|
"outputId": "1ed39e42-c1a6-4b11-d75c-d23808ba121f"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
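"# unfreeze layers 0-9 (the conv/pooling blocks and Flatten) that were frozen for transfer learning; the dense layers are already trainable\n",
"for ix in range(0, 9+1):\n",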
|
|||
|
|
" new_model.get_layer(index=ix).trainable=True\n",
|
|||
|
|
"\n",
|
|||
|
|
"# recompile so that the changed trainable flags take effect\n",
"new_model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.legacy.Adam(), metrics=['accuracy'])\n",
|
|||
|
|
"new_model.summary()"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"Model: \"sequential_3\"\n",
|
|||
|
|
"_________________________________________________________________\n",
|
|||
|
|
" Layer (type) Output Shape Param # \n",
|
|||
|
|
"=================================================================\n",
|
|||
|
|
" conv2d_2 (Conv2D) (None, 28, 28, 16) 160 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_7 (Activation) (None, 28, 28, 16) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" max_pooling2d_2 (MaxPoolin (None, 14, 14, 16) 0 \n",
|
|||
|
|
" g2D) \n",
|
|||
|
|
" \n",
|
|||
|
|
" conv2d_3 (Conv2D) (None, 14, 14, 32) 4640 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_8 (Activation) (None, 14, 14, 32) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" max_pooling2d_3 (MaxPoolin (None, 7, 7, 32) 0 \n",
|
|||
|
|
" g2D) \n",
|
|||
|
|
" \n",
|
|||
|
|
" conv2d_4 (Conv2D) (None, 7, 7, 64) 18496 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_9 (Activation) (None, 7, 7, 64) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" max_pooling2d_4 (MaxPoolin (None, 3, 3, 64) 0 \n",
|
|||
|
|
" g2D) \n",
|
|||
|
|
" \n",
|
|||
|
|
" flatten_1 (Flatten) (None, 576) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" dense_7 (Dense) (None, 64) 36928 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_12 (Activation) (None, 64) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" dense_8 (Dense) (None, 5) 325 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_13 (Activation) (None, 5) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
"=================================================================\n",
|
|||
|
|
"Total params: 60549 (236.52 KB)\n",
|
|||
|
|
"Trainable params: 60549 (236.52 KB)\n",
|
|||
|
|
"Non-trainable params: 0 (0.00 Byte)\n",
|
|||
|
|
"_________________________________________________________________\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"id": "g9x5tXRzS9yE",
|
|||
|
|
"outputId": "48754b34-1097-441c-d66a-c63fdd06fb99"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# Model.fit accepts generators directly (fit_generator is deprecated); steps per epoch are inferred from len(train_generator)\n",
"new_model.fit(train_generator, epochs=2, verbose=1, validation_data=valid_generator)"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"Epoch 1/2\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stderr",
|
|||
|
|
"text": [
|
|||
|
|
"<ipython-input-57-5ed88bbb5f18>:1: UserWarning: `Model.fit_generator` is deprecated and will be removed in a future version. Please use `Model.fit`, which supports generators.\n",
|
|||
|
|
" new_model.fit_generator(train_generator, steps_per_epoch=24000//128, epochs=2, verbose=1, validation_data=valid_generator, validation_steps = 6000 // 128)\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"184/187 [============================>.] - ETA: 0s - loss: 0.1362 - accuracy: 0.9557"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stderr",
|
|||
|
|
"text": [
|
|||
|
|
"WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 374 batches). You may need to use the repeat() function when building your dataset.\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r187/187 [==============================] - 10s 49ms/step - loss: 0.1362 - accuracy: 0.9557 - val_loss: 0.0849 - val_accuracy: 0.9733\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "execute_result",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<keras.src.callbacks.History at 0x7ac8817f1330>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"execution_count": 57
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 1000
|
|||
|
|
},
|
|||
|
|
"id": "pbfKwpJ9TI-e",
|
|||
|
|
"outputId": "cd721adf-b87f-49cf-f542-94dfdd5c3c24"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"test_model(new_model, X_test, Y_test, y_test, new_digits)"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"152/152 [==============================] - 0s 3ms/step - loss: 0.0470 - accuracy: 0.9831\n",
|
|||
|
|
"Test score: 0.047003742307424545\n",
|
|||
|
|
"Test accuracy: 0.9831310510635376\n",
|
|||
|
|
"152/152 [==============================] - 0s 2ms/step\n",
|
|||
|
|
"Confusion matrix, without normalization\n",
|
|||
|
|
"[[ 880 4 2 3 3]\n",
|
|||
|
|
" [ 1 954 0 3 0]\n",
|
|||
|
|
" [ 0 0 1013 5 10]\n",
|
|||
|
|
" [ 6 7 5 951 5]\n",
|
|||
|
|
" [ 4 3 7 14 981]]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 2 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1UAAAN6CAYAAABmBWMlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB89UlEQVR4nOzde3zO9f/H8ednYwebbU7bzHHIKaSQ5pjIuUgpRebYAQkhvuUcopNDciyHJFEhEjkfciZyyvmUbCtjM2xju35/4Pq1qDafbZ/PNY+72+dW1+d6X5/r+dn1cbleex8uw+FwOAQAAAAAuCtuVgcAAAAAAFdGUQUAAAAAJlBUAQAAAIAJFFUAAAAAYAJFFQAAAACYQFEFAAAAACZQVAEAAACACRRVAAAAAGBCNqsDAAAAALg78fHxSkxMtDpGqnh4eMjLy8vqGBmCogoAAABwQfHx8fLOmUe6fsXqKKkSHBysEydOZMnCiqIKAAAAcEGJiYnS9SvyLBsuuXtYHeffJSUq4sBMJSYmUlQBAAAAsBl3Dxk2L6ocVgfIYBRVAAAAgCsz3G5sdmb3fCZl7bMDAAAAgAxGUQUAAAAAJlBUAQAAAIAJFFUAAACAKzMkGYbNt7Sf1vr16/XEE08oJCREhmFo4cKFKe53OBwaOHCg8ufPL29vb9WrV09HjhxJ0SY6OlqtW7eWn5+fAgIC1LFjR8XFxaVo88svv6hmzZry8vJSoUKFNHr06DRnpagCAAAAYDuXL1/WAw88oAkTJtzx/tGjR2vcuHGaNGmStm7dKh8fHzVo0EDx8fHONq1bt9b+/fu1YsUKLVmyROvXr9dLL73kvD82Nlb169dXkSJFtHPnTr333nsaPHiwpkyZkqashsPhyOorHAIAAABZTmxsrPz9/eX5wMsy3D2tjvOvHEkJStgzWTExMfLz80vz4w3D0IIFC9S8efMbx3M4FBISojfeeEO9e/eWJMXExCgoKEgzZsxQq1atdPDgQZUtW1bbt29X5cqVJUnLli1T48aN9dtvvykkJEQTJ07UW2+9pYiICHl43FiWvl+/flq4cKF+/fXXVOejpwoAAABwZbeWVLf7phuF4F+3hISEuzrlEydOKCIiQvXq1XPu8/f3V9WqVbV582ZJ0ubNmxUQEOAsqCSpXr16cnNz09atW51tatWq5SyoJKlBgwY6dOiQLly4kOo8FFUAAAAAMkWhQoXk7+/v3EaOHHlXx4mIiJAkBQUFpdgfFBTkvC8iIkKBgYEp7s+WLZty586dos2djvHX50gNvvwXAAAAQKY4c+ZMiuF/np72HraYWhRVAAAAgCu7tcKend3M5+fnd1dzqv4uODhYkhQZGan8+fM790dGRqpixYrONlFRUSked/36dUVHRzsfHxwcrMjIyBRtbt2+1SY1GP4HAAAAwKWEhoYqODhYq1atcu6LjY3V1q1bFRYWJkkKCwvTxYsXtXPnTmeb1atXKzk5WVWrVnW2Wb9+va5du+Zss2LFCpUqVUq5cuVKdR6KKgAAAAC2ExcXp927d2v37t2SbixOsXv3bp0+fVqGYahHjx5655139N1332nv3r1q27atQkJCnCsElilTRg0bNlTnzp21bds2/fTTT+rWrZtatWqlkJAQSdILL7wgDw8PdezYUfv379dXX32lsWPHqlevXmnKyvA/AAAAALazY8cO1alTx3n7VqETHh6uGTNmqG/fvrp8+bJeeuklXbx4UTVq1NCyZcvk5eXlfMwXX3yhbt26qW7dunJzc9PTTz+tcePGOe/39/fXjz/+qK5du6pSpUrKmzevBg4cmOK7rFKD76kCAAAAXJDze6oe6uYa31O16+O7/p4qu2P4HwAAAACYQFEFAAAAACYwpwoAAABwZS60pHpWRU8VAAAAAJhAUQUAAAAAJlBUAQAAAIAJzKkCAAAAXJqbZNi9r8Tu+czJ2mcHAAAAABmMogoAAAAATGD4HwAAAODKWF
LdcvRUAQAAAIAJFFUAAAAAYALD/wAAAABXZrjA6n92z2dS1j47AAAAAMhgFFUAAAAAYAJFFQAAAACYwJwqAAAAwJWxpLrl6KkCAAAAABMoqgAAAADABIb/AQAAAK6MJdUtl7XPDgAAAAAyGEUVAAAAAJjA8D8AAADAlbH6n+XoqQIAAAAAEyiqAAAAAMAEiioAAAAAMIE5VQAAAIArY0l1y2XtswMAAACADEZRBQAAAAAmMPwPAAAAcGWGYf/hdSypDgAAAAD4JxRVAAAAAGACw/8AAAAAV+Zm3NjszO75TKKnCgAAAABMoKgCAAAAABMoqgAAAADABOZUAQAAAK7McHOBJdVtns+krH12AAAAAJDBKKoAAAAAwASG/wEAAACuzDBubHZm93wm0VMFAAAAACZQVAEAAACACRRVAAAAAGACc6oAAAAAV8aS6pbL2mcHAAAAABmMogoAAAAATGD4HwAAAODKWFLdcvRUAQAAAIAJFFUAAAAAYALD/wAAAABXxup/lsvaZwcAAAAAGYyiCgAAAABMoKgCAAAAABOYUwUAAAC4MpZUtxw9VQAAAABgAkUVAAAAAJjA8D8AAADAlbGkuuWy9tkBAAAAQAajqAKysCNHjqh+/fry9/eXYRhauHBhuh7/5MmTMgxDM2bMSNfjZgVFixZVu3btrI5xm7S8Zrfavv/++xkfDHc0ePBgGX+b3G3VtWXXaxoA7ICiCshgx44d08svv6xixYrJy8tLfn5+ql69usaOHaurV69m6HOHh4dr7969Gj58uD7//HNVrlw5Q58vKzpw4IAGDx6skydPWh0lwyxdulSDBw+2OsZtRowYke6/CMC/27RpkwYPHqyLFy9aHQVAWtxa/c/uWxbGnCogA33//fdq2bKlPD091bZtW5UrV06JiYnauHGj+vTpo/3792vKlCkZ8txXr17V5s2b9dZbb6lbt24Z8hxFihTR1atXlT179gw5vh0cOHBAQ4YM0aOPPqqiRYum+nGHDh2Sm5v9fm91p9ds6dKlmjBhgu0KqxEjRuiZZ55R8+bNrY5iKxl5bW3atElDhgxRu3btFBAQkGnPCwCujqIKyCAnTpxQq1atVKRIEa1evVr58+d33te1a1cdPXpU33//fYY9/x9//CFJt30wSk+GYcjLyyvDju9qHA6H4uPj5e3tLU9PT6vj3BGvmTmXL1+Wj4+PpRmsurbsek0DgB3wKycgg4wePVpxcXH69NNPUxRUt5QoUUKvv/668/b169c1bNgwFS9eXJ6enipatKj+97//KSEhIcXjihYtqqZNm2rjxo16+OGH5eXlpWLFimnWrFnONoMHD1aRIkUkSX369JFhGM5elnbt2t2xx+VOczdWrFihGjVqKCAgQL6+vipVqpT+97//Oe//p/k5q1evVs2aNeXj46OAgAA1a9ZMBw8evOPzHT161PlbcX9/f7Vv315Xrlz55x/sTY8++qjKlSunX375RbVr11aOHDlUokQJff3115KkdevWqWrVqvL29lapUqW0cuXKFI8/deqUunTpolKlSsnb21t58uRRy5YtUwzzmzFjhlq2bClJqlOnjgzDkGEYWrt2raT/fy2WL1+uypUry9vbW5MnT3bed2v+icPhUJ06dZQvXz5FRUU5j5+YmKjy5curePHiunz58n+e81/16tVLefLkkcPhcO577bXXZBiGxo0b59wXGRkpwzA0ceJESbe/Zu3atdOECRMkyXl+f78OJGnKlCnOa7NKlSravn37bW1S87qn9vozDEOXL1/WzJkznZn+bT7P2rVrZRiG5s2bp+HDh6tgwYLy8vJS3bp1dfTo0dvaz58/X5UqVZK3t7fy5s2rNm3a6OzZs7dl9fX11bFjx9S4cWPlzJlTrVu3dubr1q2b5s+fr7Jly8rb21thYWHau3evJGny5MkqUaKEvLy89Oijj942fHTDhg1q2bKlChcuLE9PTxUqVEg9e/ZM1ZDgv89t+uvr9vft1v
P+8ssvateunXMYcnBwsDp06KDz58+neA369OkjSQoNDb3tGHeaU3X8+HG1bNlSuXPnVo4cOfTII4/c9suit
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "e6ssF6eHeHro"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# Test on a different dataset"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "s00LerSteLv5",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"outputId": "54a09fc6-ea36-4740-abbb-6f4f96a72402"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"!wget -nc https://raw.githubusercontent.com/googlecreativelab/quickdraw-dataset/master/categories.txt"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"--2024-04-17 13:17:15-- https://raw.githubusercontent.com/googlecreativelab/quickdraw-dataset/master/categories.txt\n",
|
|||
|
|
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
|
|||
|
|
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
|
|||
|
|
"HTTP request sent, awaiting response... 200 OK\n",
|
|||
|
|
"Length: 2791 (2.7K) [text/plain]\n",
|
|||
|
|
"Saving to: ‘categories.txt.2’\n",
|
|||
|
|
"\n",
|
|||
|
|
"\rcategories.txt.2 0%[ ] 0 --.-KB/s \rcategories.txt.2 100%[===================>] 2.73K --.-KB/s in 0s \n",
|
|||
|
|
"\n",
|
|||
|
|
"2024-04-17 13:17:15 (48.8 MB/s) - ‘categories.txt.2’ saved [2791/2791]\n",
|
|||
|
|
"\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "H4d-lVtqgzgC"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"import os\n",
|
|||
|
|
"import urllib.request\n",
|
|||
|
|
"\n",
|
|||
|
|
"def download_and_load(class_names, test_split=0.2, max_items_per_class=10000):\n",
"    root = 'data'\n",
"    if not os.path.exists(root):\n",
"        os.makedirs(root)\n",
"\n",
"    print('downloading ...')\n",
"    base = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'\n",
"    for c in class_names:\n",
"        cc = c.replace('_', '%20')\n",
"        path = base + cc + '.npy'\n",
"        print(path)\n",
"        if not os.path.exists(f'{root}/{cc}.npy'):\n",
"            urllib.request.urlretrieve(path, f'{root}/{cc}.npy')\n",
"        else:\n",
"            print(\"Already downloaded\")\n",
"    print('loading ...')\n",
"\n",
"    # initialize variables\n",
"    x = np.empty([0, 784])\n",
"    y = np.empty([0])\n",
"\n",
"    # load each data file\n",
"    for idx, file in enumerate(class_names):\n",
"        file = file.replace('_', '%20')\n",
"        data = np.load(f'{root}/{file}.npy')\n",
"        data = data[0:max_items_per_class, :]\n",
"        labels = np.full(data.shape[0], idx)\n",
"\n",
"        x = np.concatenate((x, data), axis=0)\n",
"        y = np.append(y, labels)\n",
"\n",
"        data = None\n",
"        labels = None\n",
"\n",
"    # randomize the dataset\n",
"    permutation = np.random.permutation(y.shape[0])\n",
"    x = x[permutation, :]\n",
"    y = y[permutation]\n",
"\n",
"    # reshape and invert the colors\n",
"    x = 255 - np.reshape(x, (x.shape[0], 28, 28))\n",
"\n",
"    # separate into training and testing\n",
"    test_size = int(x.shape[0] * test_split)\n",
"\n",
"    x_test = x[0:test_size, :]\n",
"    y_test = y[0:test_size]\n",
"\n",
"    x_train = x[test_size:x.shape[0], :]\n",
"    y_train = y[test_size:y.shape[0]]\n",
"\n",
"    print('Training Data : ', x_train.shape[0])\n",
"    print('Testing Data : ', x_test.shape[0])\n",
"    return x_train, y_train, x_test, y_test"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "kAiJYOJBgXup"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# set your random seed value, put any number here\n",
|
|||
|
|
"RANDOM_SEED = 1234"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "CsfWTt6NeUEJ",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 1000
|
|||
|
|
},
|
|||
|
|
"outputId": "1136782a-7706-4944-f093-135b31374d0c"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"import random\n",
|
|||
|
|
"\n",
|
|||
|
|
"with open('categories.txt') as f:\n",
|
|||
|
|
" all_class_names = f.readlines()\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"random.seed(RANDOM_SEED)\n",
"np.random.seed(RANDOM_SEED)  # download_and_load shuffles with NumPy, so seed it as well\n",
|
|||
|
|
"\n",
|
|||
|
|
"all_class_names = [x.strip().replace(' ', '_') for x in all_class_names]\n",
|
|||
|
|
"\n",
|
|||
|
|
"nb_classes = 10\n",
|
|||
|
|
"\n",
|
|||
|
|
"# select 10 random classes\n",
|
|||
|
|
"class_names = random.sample(all_class_names, nb_classes)\n",
|
|||
|
|
"\n",
|
|||
|
|
"# you can also try those \"hard\" classes instead\n",
|
|||
|
|
"#class_names = ['wheel', 'pizza', 'smiley_face', 'apple', 'potato', 'basketball', 'soccer_ball', 'brain', 'clock', 'circle']\n",
|
|||
|
|
"\n",
|
|||
|
|
"print(class_names)\n",
|
|||
|
|
"\n",
|
|||
|
|
"X_train, y_train, X_test, y_test = download_and_load(class_names)\n",
|
|||
|
|
"\n",
|
|||
|
|
"X_train = X_train.reshape(X_train.shape[0], 28, 28, 1) #add an additional dimension to represent the single-channel\n",
|
|||
|
|
"X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)\n",
|
|||
|
|
"\n",
|
|||
|
|
"X_train = X_train.astype('float32') # change integers to 32-bit floating point numbers\n",
|
|||
|
|
"X_test = X_test.astype('float32')\n",
|
|||
|
|
"\n",
|
|||
|
|
"X_train /= 255 # scale every pixel value from [0, 255] to [0, 1]\n",
|
|||
|
|
"X_test /= 255\n",
|
|||
|
|
"\n",
|
|||
|
|
"Y_train = to_categorical(y_train, nb_classes)\n",
|
|||
|
|
"Y_test = to_categorical(y_test, nb_classes)\n",
|
|||
|
|
"\n",
|
|||
|
|
"visualize_classes(X_train, y_train)"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"['pizza', 'cannon', 'ambulance', 'bulldozer', 'swing_set', 'baseball_bat', 'zebra', 'bridge', 'cactus', 'matches']\n",
|
|||
|
|
"downloading ...\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/pizza.npy\n",
|
|||
|
|
"Already downloaded\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/cannon.npy\n",
|
|||
|
|
"Already downloaded\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/ambulance.npy\n",
|
|||
|
|
"Already downloaded\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/bulldozer.npy\n",
|
|||
|
|
"Already downloaded\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/swing%20set.npy\n",
|
|||
|
|
"Already downloaded\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/baseball%20bat.npy\n",
|
|||
|
|
"Already downloaded\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/zebra.npy\n",
|
|||
|
|
"Already downloaded\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/bridge.npy\n",
|
|||
|
|
"Already downloaded\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/cactus.npy\n",
|
|||
|
|
"Already downloaded\n",
|
|||
|
|
"https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/matches.npy\n",
|
|||
|
|
"Already downloaded\n",
|
|||
|
|
"loading ...\n",
|
|||
|
|
"Training Data : 80000\n",
|
|||
|
|
"Testing Data : 20000\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 1000x2000 with 1 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxoAAAMaCAYAAAABQDBSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9Z3Oc55klfDrnnNBAIwcSzCIpyZIlJ3nG493Zra2t3a3dX7A/aH/Ffpipt3bG4xnbGkuWLStSTCBybgCNzjmH9wPrXLy7BZBgBqjnVLEEkUCj+3nu576vcM65dL1erwcNGjRo0KBBgwYNGjRoeIHQv+43oEGDBg0aNGjQoEGDhjcPWqKhQYMGDRo0aNCgQYOGFw4t0dCgQYMGDRo0aNCgQcMLh5ZoaNCgQYMGDRo0aNCg4YVDSzQ0aNCgQYMGDRo0aNDwwqElGho0aNCgQYMGDRo0aHjh0BINDRo0aNCgQYMGDRo0vHBoiYYGDRo0aNCgQYMGDRpeOLREQ4MGDRo0aNCgQYMGDS8cWqKhQYMGDRo0aNCgQYOGFw7j634DGjRo0KBBgwYNGt5sNBoNVCoVdDodNBoNNBoNAIDRaIRer4fZbIbX64XFYnnN7/TloNPpoNPpoNlsIpfLoVKpwGQyweFwwGg0wmq1wm63Q69/s3oAWqKhQYMGDRo0aNCg4aWiUqlgb28P9XoduVwOuVwOAGCz2STJmJubeyMTjV6vh2aziXq9jlKphIWFBezv78PlcmF4eBh2ux2BQAAWi0VLNDRo+CGj1+sBAHQ63Wt+Jxo0aNCgQcPZAAPtUqmEarWKXC6HTCYDALDb7RJgN5tNdLtd6HS6N+qc5eev1WqoVCrI5/NIp9NoNpuw2+1otVqwWq1oNBrQ6XTQ6/VvTMKhJRoaNJwQrVYL9XodvV4PFovljay6aNCgQYMGDS8K3W4X7XYbnU4Ha2tr+Jd/+RfkcjmUy2WUSiUAgNlshsFgwPj4OOx2OwDAarXC5XLBYDC8zrf/wtBoNHDnzh3cuXMHhUIBa2trSCaTsFgs8Pl8sFgsmJ2dxY0bN+B2uzE0NIRIJPJGJBtaoqFBwwnRaDRQKpXQ6XTgdrthNpvfqIqLBg0aNGjQ8CLR7XbRaDTQarWwurqKf/qnf8L+/j5arRZarZZ8n06nw4ULFzA3NweXywWPxwO73f5GJBq9Xg+NRgPfffcd/vEf/xHFYhEHBwfI5/PQ6/UwGAwwGAx4++230el0EAqFAAChUEhLNF4nSGHhfwG8ca02Dc+GZrOJRqPRtza63S6azSY6nQ5MJhOsVitMJhMMBgOMxv7HgC3OTqeDdruNRqOBdrst7c5Op4NarYZGowGj0Sj8UrY7NWh4VqhrdhDa3qZBg4azhm63i1qthnq9jnK5jGq1imq1CrPZDKfTiW63i0qlgnq9jlqtJv9us9nQ7XZf99t/bnQ6HWFDVCoVFAoFuQ4Uw5MmRToVABQKBdTrdZjNZhHLn1WcyUSj2+2i1Wqh2+1KMAg8bLVZrdbX/O40vE50u11sb29jeXkZzWZTArdqtYqNjQ1ks1nEYjG8++67CAaD8Pl8CIVCfclGo9HA1tYW0uk0MpkMFhYWxCEin8+j1+thaGgI0WgULpcL169fx8zMDIxGo7hHaNDwtFD3NeBR0sGK15vG29WgQcObj3K5jOXlZWSzWWxubqJWq0Gn0+HKlSu4efMm6vU6/vznP2NhYQGNRgOHh4fY3t5Gp9NBNBo98zFdsVhEKpVCNptFPB5HIpFAvV5HvV6X7+n1euh2u9jZ2cHHH38Mt9uNbrcLt9sNp9OJSCQCj8fzGj/F8+FMRkS9Xg+tVksyxWazCeDhgWyxWLTK3w8Y3W4Xh4eHuH37NqrVqvx9LpfD119/jZ2dHVy5cgU+nw/NZhM6nQ6BQKDvNVqtFhKJBDY3N7G9vY3f//732Nvbk2oEAIyNjWFsbAyhUAg+nw+RSEQSXS3R0P
As4L7W7XbR6/Uk0TAajej1etDr9drepkGDhjOFarWKeDyO/f19JBIJqeJPTU3ho48+QqVSwebmJh48eIB2u41cLodEIgGn04lOp/Oa3/3zodfroVKp4PDwEJlMRhIOlTLG7+v1ekgmk0ilUrDb7RgbG8P8/Dx8Pp9Qyc4qXmlExKyNbSMmCp1OR/6N/2VVTz10+TU9mNnNaLVa0Ol0cDqdcLlcknCQGmOxWGA0GoUyo1YHNbw54DppNpsol8sol8toNptotVoolUpoNBrQ6/Wo1+vY3d1Fu92GwWDAyMiIVIr1er0EfPV6Hc1mE3q9XtaPyWTqW4O1Wg3FYhG5XA4mkwnNZhNms1neD/CI8mIwGOBwOETs9kNDu91GtVqV/xaLRbTb7b7rZDKZ5Fq7XC55ds1m82t7XsmvTafTqFarcv+Pa+ure9Xg3x9F+SS4pgYTDdL71H1Nr9fLf/nv5PmaTCbodDoYjUb5ubPeen+Z6PV6aLfb0lHi2cJ/Ax5RGwwGA2w2m1xjLfHToOHxIHWqUqmg3W7L+ehwOOD1emV+Bp8lxnVH7aFnERTDM9Z9Enh+8DqoZ+RZxStLNJg81Ot1PHjwAKurq6jVakilUrIAecg2m02xOFO/pqCIr6Ue6DqdDg6HA06nE2azGcPDw/D7/XA4HBgeHobL5UIgEMD4+Lhw6jXXoDcLDM4KhQK2t7eRy+WkggA8DBZcLhcymQx+85vfwGQy4de//jUmJibQ6/VgtVphsVjQ6XRQKpWQyWRQq9XgcDgQCAQkSe10OtDr9SgWiwCA9fV1seYzGAzQ6/WyUbASzWE8Fy9exNTU1A+yOl2pVLC2toZcLof19XV8+eWXIq7v9XowGo0IBoPweDzw+/24du0ahoaG4Ha7EY1GX8vzyoQikUjgX//1X7G2toZmsymFEt5nFdznBoNV9fA86uf4s/z+brcr64Rry+VyweFwwGKxIBAIwGazwW63w+/3w2w2w+12w+v1wmg0wu/3w+VywWQywe12w2q1asHxEWi32ygWi2g0GkJvqNVqMlwLgOwNDocDk5OTItI0Go3a9dSg4TEgHWp3dxe1Wk2Si9HRUUxNTSGbzcLj8YjoWy3ynfUAG3j4eSqVihTZTvqZ+HNms1nkAWcVryzR4OHaarWQTqexsbGBSqWCeDyOQqGAVqslN4KVYh7YPLRrtZosvsGbxUTD4XDAZrNhcnISQ0ND8Hq9aLfb8Pl8AIBIJCLVPyYoGt4c9Ho91Ot1FAoF5HI5HBwcIJFIwGw2IxwOw+VyoVqt4uDgAJ1OBxcvXkSlUpGADHgkHOd6NJlMIkyjMFyn08n35HI5JJNJ+f2seHNTYZXZ4XBgbGzsjdg8nwWtVgvZbBaHh4dYW1vDV199hUwmI0G3yWTCyMgIgsGgWPsxgXtdLXS1C7uysoLbt2+j0WhIN4bVKhXcq1qtlqwHrgn+3eM6IoNgcmAwGODz+eB0OmG32zE8PAyHwwGPxyPUvUAggEajIddNr9fDarXC4XBoM2COAZ9rdtkODg5QLpf7uuVM6DweD4aGhvqSSO16atBwPDqdDiqVCkqlksyKMBqNcLlc8Hq96Ha7fUPq1I7GmwC1o3HShEE9L57mrDiteKmJBqkA7XYbiUQC8XgcpVIJt2/fxtLSUh+VotfribrebDbDZrP1Hci8WaS1lMtl0WeQ8+d2u6Wy1+12kc1mUa1WUa/XYbfbsbW1hXg8DqvVCr/fL9/r9Xol0LTb7dLa03C2wMrv6Ogo3n//fZRKJSSTSWSzWakGm81mpNNprKysiPNFvV5HsVhEMplEs9lEPp/H1tYWkskkqtWqBBuhUAizs7MwmUzSEWNVM51Oo9VqSWWUASoAmbnhdrtx8eJFCVJI4XuTQdpRu91GMpnEwsICtra2cHBwAKPRCLfbLZQ04KEeoVKpIJlM4s6dO9jd3cX09LQ82yaT6ZV1NtrtNjKZDIrFIuLxOFKpFDKZDPR6Pex2+7H3jl
QcHhTs2KpdjlAoBK/X2/ca/JpdDP4/90GVAmU0GqU4w99hMBiEPmo2m+HxeOB0OmEymeB0OmGz2fq6bseBN
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "W5Jm2OPyk4FF"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## Create the model"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "1k0Qnajnk3ey"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"model = Sequential() # Linear stacking of layers\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Convolution Layer 1\n",
|
|||
|
|
"model.add(Conv2D(64, (5, 5), input_shape=(28,28,1))) # 64 different 5x5 kernels -- so 64 feature maps\n",
|
|||
|
|
"model.add(BatchNormalization()) # normalize the previous layer's outputs over each batch\n",
|
|||
|
|
"model.add(Activation('relu')) # activation\n",
|
|||
|
|
"model.add(MaxPooling2D(pool_size=(2,2))) # Pool the max values over a 2x2 kernel\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Convolution Layer 2\n",
|
|||
|
|
"model.add(Conv2D(32, (5, 5))) # 32 different 5x5 kernels -- so 32 feature maps\n",
|
|||
|
|
"model.add(BatchNormalization()) # normalize the previous layer's outputs over each batch\n",
|
|||
|
|
"model.add(Activation('relu')) # activation\n",
|
|||
|
|
"model.add(MaxPooling2D(pool_size=(2,2))) # Pool the max values over a 2x2 kernel\n",
|
|||
|
|
"\n",
|
|||
|
|
"model.add(Flatten()) # Flatten final output matrix into a vector\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Fully Connected Layer\n",
|
|||
|
|
"model.add(Dense(128)) # 128 FC nodes\n",
|
|||
|
|
"model.add(Activation('relu')) # activation\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Fully Connected Layer\n",
|
|||
|
|
"model.add(Dense(10)) # final 10 FC nodes\n",
|
|||
|
|
"model.add(Activation('softmax')) # softmax activation\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"# we'll use the same optimizer\n",
|
|||
|
|
"adam = tf.optimizers.Adam(learning_rate=0.001)\n",
|
|||
|
|
"model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": []
|
|||
|
|
},
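{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hand-worked sketch (not Keras output) of how the spatial size shrinks through the network defined above. `BatchNormalization` and `Activation` layers leave the shape unchanged, so only the convolution and pooling layers matter. The helpers `conv_out` and `pool_out` are illustrative names. The arithmetic assumes the Keras defaults that the cell relies on: `padding='valid'` with stride 1 for `Conv2D`, and a stride equal to `pool_size` for `MaxPooling2D`. The flattened vector fed to `Dense(128)` ends up with 4 x 4 x 32 = 512 values."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Assumes Conv2D's default padding='valid' with stride 1 and MaxPooling2D's default stride == pool_size.\n",
"def conv_out(size, kernel, stride=1):\n",
"    # 'valid' convolution: the kernel must fit entirely inside the input\n",
"    return (size - kernel) // stride + 1\n",
"\n",
"def pool_out(size, pool):\n",
"    # non-overlapping pooling; leftover rows/columns are dropped\n",
"    return size // pool\n",
"\n",
"side = 28                  # input: 28x28x1\n",
"side = conv_out(side, 5)   # Conv2D(64, (5, 5)) -> 24x24x64\n",
"side = pool_out(side, 2)   # MaxPooling2D(2x2)  -> 12x12x64\n",
"side = conv_out(side, 5)   # Conv2D(32, (5, 5)) -> 8x8x32\n",
"side = pool_out(side, 2)   # MaxPooling2D(2x2)  -> 4x4x32\n",
"flat = side * side * 32    # Flatten -> vector fed to Dense(128)\n",
"print('Flatten output size:', flat)"
],
"execution_count": null,
"outputs": []
},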
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "HmR_tm0gmR0u"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# data augmentation helps reduce overfitting by randomly perturbing the training images\n",
|
|||
|
|
"# Keras's ImageDataGenerator applies these random transforms on the fly, batch by batch\n",
|
|||
|
|
"\n",
|
|||
|
|
"gen = ImageDataGenerator(rotation_range=8, width_shift_range=0.08, shear_range=0.3,\n",
|
|||
|
|
" height_shift_range=0.08, zoom_range=0.08, validation_split=0.2)\n",
|
|||
|
|
"\n",
|
|||
|
|
"test_gen = ImageDataGenerator()\n",
|
|||
|
|
"\n",
|
|||
|
|
"train_generator = gen.flow(X_train, Y_train, batch_size=128, subset='training')\n",
|
|||
|
|
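"# note: subset='validation' draws from the same augmenting generator, so validation images are randomly transformed too\n",
"valid_generator = gen.flow(X_train, Y_train, batch_size=128, subset='validation')\n",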
|
|||
|
|
"test_generator = test_gen.flow(X_test, Y_test, batch_size=128)"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "p-jeOuI8nfl9",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"outputId": "19180704-a705-4043-8475-e50ef757ec61"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"model.fit(train_generator, steps_per_epoch=64000//128, epochs=5, verbose=1, validation_data=valid_generator, validation_steps=16000//128)"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"Epoch 1/5\n",
|
|||
|
|
"500/500 [==============================] - 27s 49ms/step - loss: 0.7637 - accuracy: 0.7616 - val_loss: 0.8774 - val_accuracy: 0.7241\n",
|
|||
|
|
"Epoch 2/5\n",
|
|||
|
|
"500/500 [==============================] - 24s 48ms/step - loss: 0.4911 - accuracy: 0.8494 - val_loss: 0.5371 - val_accuracy: 0.8357\n",
|
|||
|
|
"Epoch 3/5\n",
|
|||
|
|
"500/500 [==============================] - 24s 48ms/step - loss: 0.4233 - accuracy: 0.8712 - val_loss: 0.5017 - val_accuracy: 0.8528\n",
|
|||
|
|
"Epoch 4/5\n",
|
|||
|
|
"500/500 [==============================] - 24s 48ms/step - loss: 0.3844 - accuracy: 0.8838 - val_loss: 0.5333 - val_accuracy: 0.8395\n",
|
|||
|
|
"Epoch 5/5\n",
|
|||
|
|
"500/500 [==============================] - 24s 48ms/step - loss: 0.3574 - accuracy: 0.8913 - val_loss: 0.7751 - val_accuracy: 0.7667\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "execute_result",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<keras.src.callbacks.History at 0x7ac881903010>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"execution_count": 65
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "OlzMJklYooX8",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"outputId": "606d31d6-feb1-45b1-9179-109ee31013a6"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"score = model.evaluate(X_test, Y_test)\n",
|
|||
|
|
"print('Test score:', score[0])\n",
|
|||
|
|
"print('Test accuracy:', score[1])"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"625/625 [==============================] - 2s 3ms/step - loss: 1.1437 - accuracy: 0.6999\n",
|
|||
|
|
"Test score: 1.143710732460022\n",
|
|||
|
|
"Test accuracy: 0.6998999714851379\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "sAhuIjARoqhp",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 1000
|
|||
|
|
},
|
|||
|
|
"outputId": "f3589fdb-13ba-468f-cf29-6cb3bac97446"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# model.predict returns per-class probabilities; np.argmax picks the highest-probability class\n",
"# for each input example (Sequential.predict_classes no longer exists in newer Keras versions).\n",
|
|||
|
|
"predicted = model.predict(X_test)\n",
|
|||
|
|
"predicted_classes = np.argmax(predicted, axis=1)\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Check which items we got right / wrong\n",
|
|||
|
|
"correct_indices = np.nonzero(predicted_classes == y_test)[0]\n",
|
|||
|
|
"\n",
|
|||
|
|
"incorrect_indices = np.nonzero(predicted_classes != y_test)[0]\n",
|
|||
|
|
"\n",
|
|||
|
|
"cnf_matrix = confusion_matrix(y_test, predicted_classes)\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Plot non-normalized confusion matrix\n",
|
|||
|
|
"plt.figure()\n",
|
|||
|
|
"plot_confusion_matrix(cnf_matrix, classes=class_names,\n",
|
|||
|
|
" title='Confusion matrix, without normalization')\n",
|
|||
|
|
"\n",
|
|||
|
|
"plt.show()"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"625/625 [==============================] - 1s 2ms/step\n",
|
|||
|
|
"Confusion matrix, without normalization\n",
|
|||
|
|
"[[1212 20 647 38 7 0 77 3 21 35]\n",
|
|||
|
|
" [ 0 1232 643 62 2 0 68 4 5 7]\n",
|
|||
|
|
" [ 1 3 1959 6 1 0 4 1 2 1]\n",
|
|||
|
|
" [ 1 11 1085 896 3 1 19 0 2 4]\n",
|
|||
|
|
" [ 1 7 221 24 1603 1 48 6 3 23]\n",
|
|||
|
|
" [ 3 582 75 40 25 977 21 14 122 124]\n",
|
|||
|
|
" [ 0 13 102 18 0 0 1820 2 2 4]\n",
|
|||
|
|
" [ 2 48 487 96 16 0 178 1192 1 10]\n",
|
|||
|
|
" [ 0 8 273 18 5 1 47 3 1648 16]\n",
|
|||
|
|
" [ 4 58 237 63 25 3 66 10 62 1459]]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 2 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA18AAANgCAYAAADTatIRAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzddVhUaRsG8HuGlpQWSUERAwsDFcVE7BYT7LVYdc21XT917Q5MrMXYtbsQbFGxsFuXsEjpme8PltkdQQWBOQPcv73Otc4575x5zsRhnnnf9zkiqVQqBRERERERERUosdABEBERERERFQdMvoiIiIiIiBSAyRcREREREZECMPkiIiIiIiJSACZfRERERERECsDki4iIiIiISAGYfBERERERESkAky8iIiIiIiIFUBU6ACIiIiIiUoykpCSkpKQIHcZ3qaurQ1NTU+gw8h2TLyIiIiKiYiApKQlaukZA2mehQ/kuc3NzPH/+vMglYEy+iIiIiIiKgZSUFCDtMzQqeAMq6kKH83XpKYgI80dKSgqTLyIiIiIiKsRU1CFS4uRLKnQABYjJFxERERFRcSISZyzKSpljy6Oie2RERERERERKhMkXERERERGRAnDYIRERERFRcSICIBIJHcXXKXFoecWeLyIiIiIiIgVg8kVERERERKQATL6IiIiIiIgUgHO+iIiIiIiKE5aaF0zRPTIiIiIiIiIlwuSLiIiIiIhIATjskIiIiIioOBGJlLzUvBLHlkfs+SIiIiIiIlIAJl9EREREREQKwGGHRERERETFCasdCqboHhkREREREZESYfJFRERERESkABx2SERERERUnLDaoWDY80VERERERKQATL6IiIiIiIgUgMkXERERERGRAnDOFxERERFRsaLkpeaLcP9Q0T0yIiIiIiIiJcLki4iIiIiISAE47JCIiIiIqDhhqXnBsOeLiIiIiIhIAZh8ERERERERKQCHHRIRERERFSciJa92qMyx5VHRPTIiIiIiIiIlwuSLiIiIiIhIATjskIiIiIioOGG1Q8Gw54uIiIiIiEgBmHwREREREREpAJMvIiIiIiIiBeCcLyIiIiKi4oSl5gVTdI+MiIiIiIhIiTD5IiIiIiIiUgAOOyQiIiIiKk5Yal4w7PkiIiIiIiJSACZfRERERERECsBhh0RERERExQmrHQqm6B4ZERERERGREmHyRUREREREpAAcdkhEREREVJyIRMo9tI/VDomIiIiIiCgvmHwREREREREpAJMvIiIiIiIiBeCcLyIiIiKi4kQsyliUlTLHlkfs+SIiIiIiIlIAJl9EREREREQKwGGHRERERETFiUis5KXmlTi2PCq6R0ZERERERKREmHwREREREREpAIcdEhEREREVJyJRxqKslDm2PGLPFxERERERkQIw+SIiIiIiIlIAJl9EREREREQKwDlfRERERETFCUvNC6boHhkREREREZESYfJFRERERESkABx2SERERERUnLDUvGDY80VERERERKQATL6IiIiIiIgUgMMOiYiIiIiKE1Y7FEzRPTIiIiIiIiIlwuSLiIiIiIhIATjskIiIiIioOGG1Q8Gw54uIiIiIiEgBmHwREREREREpAJMvIiIiIiIiBeCcLyIiIiKi4oSl5gVTdI+MiIiIiIhIiTD5IiIiIiIiUgAOOyQiIiIiKk5Yal4w7PkiIiIiIiJSACZfRERERERECsBhh0RERERExYqSVzsswv1DRffIiIiIiIiIlAiTLyIiIiIiIgXgsEMiIiIiouKE1Q4Fw54vIiIiIiIiBWDyRUREREREpABMvoiIiIiIiBSAyRcRERERUXEiEmWUmlfaJXdzvoKCgtCmTRtYWFhAJBJh3759XxyuKNtl/vz5sja2trZZts+dO1duP7dv34abmxs0NTVhZWWFefPm5fqpZ/JFRERERESFVkJCAqpUqYKVK1dmuz08PFxu2bhxI0QiETp16iTXbubMmX
LtRowYIdsWGxuL5s2bw8bGBtevX8f8+fMxffp0+Pn55SpWVjskIiIiIqJCy9PTE56enl/dbm5uLnd7//79aNSoEcqUKSO3XldXN0vbTNu3b0dKSgo2btwIdXV1VKxYEaGhoVi0aBEGDRqU41jZ80VEREREVJwIPqwwBwsyepv+uyQnJ+f50CMjI3H48GH0798/y7a5c+fCyMgI1apVw/z585GWlibbdunSJTRo0ADq6uqydR4eHnj48CE+ffqU48dn8kVERERERErHysoK+vr6smXOnDl53qe/vz90dXXRsWNHufW+vr4ICAjA2bNnMXjwYMyePRvjxo2TbY+IiICZmZncfTJvR0RE5PjxOeyQiIiIiIiUzuvXr6Gnpye7raGhked9bty4ET179oSmpqbc+tGjR8v+7ezsDHV1dQwePBhz5szJl8fNxOSLiIiIiKg4EYlyXVFQof6JTU9PTy75yqvg4GA8fPgQO3fu/G7b2rVrIy0tDS9evICjoyPMzc0RGRkp1ybz9tfmiWWHww6JiIiIiKjI27BhA2rUqIEqVap8t21oaCjEYjFMTU0BAK6urggKCkJqaqqszcmTJ+Ho6IiSJUvmOAYmX0REREREVGjFx8cjNDQUoaGhAIDnz58jNDQUr169krWJjY3F7t27MWDAgCz3v3TpEpYsWYJbt27h2bNn2L59O0aNGoVevXrJEqsePXpAXV0d/fv3x71797Bz504sXbpUbrhiTnDYIRERERFRcfKfioJKKZexhYSEoFGjRrLbmQmRt7c3Nm/eDAAICAiAVCpF9+7ds9xfQ0MDAQEBmD59OpKTk2FnZ4dRo0bJJVb6+vo4ceIEhg0bhho1asDY2BhTp07NVZl5gD1fRKTEHj9+jObNm0NfXz/bK9bn1YsXLyASiWQnZvqXra0tfHx8hA4ji9y8ZpltFyxYUPCBUbamT58O0RfzSoR6bynre5qI8s7d3R1SqTTL8t+/FYMGDcLnz5+hr6+f5f7Vq1fH5cuXER0djcTERISFhWHixIlZCm04OzsjODgYSUlJePPmDcaPH5/rWJl8EdE3PX36FIMHD0aZMmWgqakJPT091KtXD0uXLkViYmKBPra3tzfu3LmD//3vf9i6dStcXFwK9PGKorCwMEyfPh0vXrwQOpQCc+TIEUyfPl3oMLKYPXt2vv9gQN928eJFTJ8+HdHR0UKHQkSULQ47JKKvOnz4MLp06QINDQ306dMHlSpVQkpKCs6fP4+xY8fi3r178PPzK5DHTkxMxKVLlzBp0iQMHz68QB7DxsYGiYmJUFNTK5D9K4OwsDDMmDED7u7usLW1zfH9Hj58CLFY+X6fy+41O3LkCFauXKl0Cdjs2bPRuXNntG/fXuhQlEpBvrcuXryIGTNmwMfHBwYGBgp7XCKinGLyRUTZev78Oby8vGBjY4MzZ86gVKlSsm3Dhg3DkydPcPjw4QJ7/Hfv3gFAli9Q+UkkEmW5zkdxJpVKkZSUBC0trXy9pkl+4muWNwkJCdDW1hY0BqHeW8r6niYSRCEpNV8U8ScgIsrWvHnzEB8fjw0bNsglXpkcHBzw888/y26npaXht99+g729PTQ0NGBra4tff/0VycnJcveztbVF69atcf78edSqVQuampooU6YMtmzZImszffp02NjYAADGjh0LkUgk67Xx8fHJtgcnu7klJ0+eRP369WFgYAAdHR04Ojri119/lW3/2vyhM2fOwM3NDdra2jAwMEC7du1w//79bB/vyZMnsl/Z9fX10bdvX3z+/PnrT+w/3N3dUalSJdy+fRsNGzZEiRIl4ODggD179gAAzp07h9q1a0NLSwuOjo44deqU3P1fvnyJoUOHwtHREVpaWjAyMkKXLl3khhdu3rwZXbp0AQA0atQIIpEIIpEIgYGBAP59LY4fPw4XFxdoaWlh7dq1sm2Z82OkUikaNWoEExMTREVFyfafkpKCypUrw97eHgkJCd895v8aPXo0jIyMIJ
VKZetGjBgBkUiEZcuWydZFRkZCJBJh9erVALK+Zj4+Pli5ciUAyI7vy/cBAPj5+cnemzVr1sS1a9eytMnJ6
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
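{
"cell_type": "markdown",
"metadata": {},
"source": [
"A short sketch that turns the confusion matrix above into two summary numbers. Per-class recall is the diagonal divided by the row sum. Overall accuracy is the trace divided by the total. It reuses the `cnf_matrix` variable from the previous cell. As with scikit-learn's `confusion_matrix`, rows are true classes and columns are predicted classes. For the run recorded above, the trace is 13998 out of 20000 test images (0.6999), which matches the reported test accuracy. In that run, column 2 also collects far more predictions than any other column."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"# rows = true classes, columns = predicted classes\n",
"per_class_recall = np.diag(cnf_matrix) / cnf_matrix.sum(axis=1)\n",
"overall_accuracy = np.trace(cnf_matrix) / cnf_matrix.sum()\n",
"predictions_per_class = cnf_matrix.sum(axis=0)  # how often each class was predicted\n",
"\n",
"for k, r in enumerate(per_class_recall):\n",
"    print(f'class {k}: recall {r:.3f}, predicted {predictions_per_class[k]} times')\n",
"print('overall accuracy:', overall_accuracy)"
],
"execution_count": null,
"outputs": []
},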
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "BKDRjvPFow6L",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 906
|
|||
|
|
},
|
|||
|
|
"outputId": "921c5592-17ac-4516-e110-5bf80d79a3d8"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"show_samples(correct_indices, predicted, X_test, y_test, 5, class_names)"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 25 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA4YAAAN5CAYAAABUtyXBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd1RUx9sH8O/SdqkiiNIEEUVRQJTYUETFghUraixYscYWY4nxZzexx15iXjSW2EWjsUsUS+xiwS6gYkERUbBQdt4/OHvD3LvALi6C8nzOyTmZ2dnZufjc2Xv3TpExxhgIIYQQQgghhBRbeoXdAEIIIYQQQgghhYtuDAkhhBBCCCGkmKMbQ0IIIYQQQggp5ujGkBBCCCGEEEKKOboxJIQQQgghhJBijm4MCSGEEEIIIaSYoxtDQgghhBBCCCnm6MaQEEIIIYQQQoo5ujEkhBBCCCGEkGKObgyLiXPnzsHIyAhxcXGF3ZQcpaeno2zZsli+fHlhN+WLVa5cOfTu3buwm1HovoR4T0xMhKmpKf7+++/CbkqRIJPJMGzYMJ3W+TWdD48ePYJCocCpU6cKuym5qlOnDsaOHVvYzSAF5EvoW+laIv+oH85dceiHi8SN4dq1ayGTyYT/FAoF3NzcMGzYMDx//rywm/dVmDhxIrp16wZnZ2ch79y5cxgyZAh8fHxgaGgImUymdb2nT59G/fr1YWJiAltbWwwfPhwpKSmSch8/fsS4ceNgb28PY2Nj1K5dG4cPH+bKGBoaYvTo0Zg5cyY+fPig/UEWAIrNwjdr1iyEh4dr9R518Q4AN2/eRGBgIMzMzGBlZYWePXvixYsXGtWZkpKCkSNHwtHREXK5HO7u7lixYkWe7xswYABkMhlat27N5VtbW6N///6YNGmS5geWC4rVr9u0adNQu3Zt1KtXT8i7ffs2Ro0aBV9fXygUCshkMsTGxmpVr6bnhFKpxJw5c+Di4gKFQgEvLy/8+eefknLjxo3DsmXL8OzZM62P8VPROVDw6FoidxSDX7di0Q+zIiAsLIwBYNOmTWPr169nv/32GwsJCWF6enrMxcWFpaamFnYTv2iXL19mANjp06e5/MmTJzNDQ0Pm4+PD3NzcmLbhcPnyZaZQKFj16tXZihUr2MSJE5lcLmeBgYGSsl27dmUGBgZszJgxbNWqVaxu3brMwMCARUZGcuWSkpKYkZER+/3337U/0ALwpcXmhw8fWFpaWmE3Q6dMTU1ZSEiIxuVzivdHjx6xUqVKMVdXV7Zo0SI2c+ZMVrJkSVatWjX28ePHXOvMyMhgvr6+zMjIiI0aNYotX76cBQUFMQBs5syZOb7v/PnzzMDAgCkUCtaqVSvJ69HR0QwAO3r0qMbHl5MvLVbFALChQ4fqtE5nZ2etYqeoSkhIYIaGhmzTpk1cflhYGNPT02MeHh7M29ubAWAxMTEa16vNOTF+/HgGgA0YMICtXr2atWrVigFgf/75J1cuMzOT2draskmTJuX7ePPrSz8Hijq6lsjblx6D1A/nrLj0w0XqxvD8+fNc/ujRoxkAyT8C0c7w4cOZk5MTUyqVXP6zZ8/Yu3fvGGOMDR06VOvOvEWLFszOzo4lJycLeb/99hsDwA4ePCjknT17lgFgc+fOFfLev3/PXF1dWd26dSX1tm7dmvn5+WnVloJCsVn4tL0xzCneBw8ezIyNjVlcXJyQd/jwYQaArVq1Ktc6t27dygBILjI6duzIFAoFe/78ueQ9SqWS1a1bl/Xt25c5OzurvTFkjDEPDw/Ws2dPTQ8vR196rNIFSc4WLFjAjI2N2du3b7n8xMRE9ubNG8YYY3PnztX6gkTTc+Lx48fM0NCQ+/dRKpXMz8+POTo6soyMDK7eYcOGMWdnZ8k5WNC+9HOgqKNribx96TFI/XDOiks/XCSGkuakcePGAICYmBgh7/79+7h//75G73/9+jVGjRqFcuXKQS6Xw9HREb169cLLly8BAG
lpafjf//4HHx8flChRAqampvDz80NERARXT2xsLGQyGebNm4fVq1fD1dUVcrkcNWvWxPnz57myvXv3hpmZGeLj49GuXTuYmZnBxsYGY8aMQWZmJlc2NTUV33//PcqWLQu5XI5KlSph3rx5YIxx5VRjvsPDw+Hh4QG5XI6qVaviwIEDGv0dwsPD0bhxY8nwjjJlysDY2FijOsTevHmDw4cPo0ePHrCwsBDye/XqBTMzM2zdulXI2759O/T19REaGirkKRQK9OvXD2fOnMGjR4+4ups2bYqTJ0/i1atX+Wrb5/ApsZmeno6pU6eiYsWKUCgUsLa2Rv369YXhMHv27IFMJsPVq1eF9+zYsQMymQwdOnTg6nJ3d0eXLl2EtHgsv2pYy6lTpzB69GjY2NjA1NQU7du3lwxTUCqVmDJlCuzt7WFiYoJGjRohOjo6X/MDDh8+jPr168PS0hJmZmaoVKkSfvzxR67Mx48fMXnyZFSoUAFyuRxly5bF2LFj8fHjR6GMTCZDamoq1q1bJwzNyastOcX7jh070Lp1azg5OQl5TZo0gZubGxev6kRGRgIAunbtyuV37doVHz58wO7duyXvWb9+Pa5fv46ZM2fmWnfTpk3x119/Sc57XfmUWH316hXGjBkDT09PmJmZwcLCAi1atEBUVBRX7p9//oFMJsPWrVsxdepUODg4wNzcHJ06dUJycjI+fvyIkSNHonTp0jAzM0OfPn24f+fsNm7ciEqVKkGhUMDHxwcnTpzgXu/duzfKlSsned+UKVPyHMKWn+OZOXMmHB0doVAoEBAQgHv37knqPXv2LFq2bImSJUvC1NQUXl5eWLRoEVfm1q1b6NSpE6ysrKBQKPDNN99gz549ubZXJTw8HLVr14aZmRmXb2VlBXNzc43qUEfTc2L37t1IT0/HkCFDhDyZTIbBgwfj8ePHOHPmDFdv06ZNERcXhytXruS7bbpE1xJZ6Fqi8FA//GnHQ/3w5+uHDfJ3GJ+H6oSxtrYW8gICAgAgz/G7KSkp8PPzw82bN9G3b1/UqFEDL1++xJ49e/D48WOUKlUKb968wZo1a9CtWzcMGDAAb9++xe+//47mzZvj3Llz8Pb25urctGkT3r59i4EDB0Imk2HOnDno0KEDHjx4AENDQ6FcZmYmmjdvjtq1a2PevHk4cuQI5s+fD1dXVwwePBgAwBhD27ZtERERgX79+sHb2xsHDx7EDz/8gPj4eCxcuJD77JMnT2Lnzp0YMmQIzM3NsXjxYnTs2BEPHz7k/j5i8fHxePjwIWrUqJHn31sb165dQ0ZGBr755hsu38jICN7e3rh8+bKQd/nyZbi5uXGdPgDUqlULAHDlyhWULVtWyPfx8QFjDKdPn5bMyyoqPiU2p0yZgp9//hn9+/dHrVq18ObNG1y4cAGXLl1C06ZNUb9+fchkMpw4cQJeXl4Asm5M9PT0cPLkSaGeFy9e4NatWxpNFP/uu+9QsmRJTJ48GbGxsfj1118xbNgwbNmyRSgzYcIEzJkzB23atEHz5s0RFRWF5s2baz1H48aNG2jdujW8vLwwbdo0yOVy3Lt3j5usrVQq0bZtW5w8eRKhoaFwd3fHtWvXsHDhQty5c0eYU7h+/Xrh76S6GHB1dc3xs3OK9/j4eCQkJEjiFciKw7wWgPn48SP09fVhZGTE5ZuYmAAALl68iAEDBgj5b9++xbhx4/Djjz/C1tY217p9fHywcOFC3LhxAx4eHrmWzY9PidUHDx4gPDwcnTt3houLC54/f45Vq1bB398f0dHRsLe358r//PPPMDY2xvjx43Hv3j0sWbIEhoaG0NPTQ1JSEqZMmYJ///0Xa9euhYuLC/73v/9x7z9+/Di2bNmC4cOHQy6XY/ny5QgMDMS5c+d08rfR9nh++eUX6OnpYcyYMUhOTsacOXPQvXt3nD17Vihz+PBhtG7dGnZ2dhgxYgRsbW1x8+ZN7N27FyNGjACQdU7Uq1cPDg4OGD9+PExNTb
F161a0a9cOO3bsQPv27XNsc3p6Os6fPy98d+iKNufE5cuXYWpqCnd3d0k51ev169cX8n18fAAAp06dQvXq1
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "lgtxWrarps6b",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 906
|
|||
|
|
},
|
|||
|
|
"outputId": "8ba4ca25-2781-4c3e-f6d9-a74096f3d7ab"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"show_samples(incorrect_indices, predicted, X_test, y_test, 5, class_names)"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 900x900 with 25 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA4wAAAN5CAYAAABDlbUIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd1wUx/sH8M/RmyAgKDZARBHsKFiDioiIggWxC2oUFTVGjSXR2DX2LtZgj73GFlGMWGLvFRF7BQFFpM/vD353X2Z3Oe4QBPR5v1788czNzs4ez+3t3u7MyhhjDIQQQgghhBBCiIBGYXeAEEIIIYQQQkjRRCeMhBBCCCGEEEIk0QkjIYQQQgghhBBJdMJICCGEEEIIIUQSnTASQgghhBBCCJFEJ4yEEEIIIYQQQiTRCSMhhBBCCCGEEEl0wkgIIYQQQgghRBKdMBJCCCGEEEIIkUQnjF9o3bp1kMlkuHTpUmF3BYMHD4aHh0dhd0OpI0eOwMjICO/evSvsrpA8KEr5Llcc8n7FihWoWLEiUlJSCrsr36SilJeUj8VLQeTOyZMnIZPJcPLkSUVZYGAgbGxscl328ePHkMlkWLduXb7152vavn07zMzMkJiYWNhdydGdO3egpaWFW7duFXZX8gXlcP6iHJZW5E8Y5R8E+Z+enh6qVKmCIUOG4M2bN4XdvSIjOjoaa9aswa+//ip6be3atahWrRr09PRgb2+PJUuWqNRmYGAg994L/168eKGo+88//6Bfv36oXr06NDU1c9yptG7dGpUrV8bMmTPztJ1FQXHPyRkzZmDv3r2F3Y18URB5L3flyhX4+PjAzMwMBgYGqF69OhYvXqx4Xf6lmNNf//79FXUDAwORmpqKlStX5n1jC1Fxz/mvpTDzUSg+Ph6WlpaQyWTYuXMn91pxzUfKw6ItIyMDEydOxNChQ2FkZMS9dvbsWTRp0gQGBgYoU6YMhg0blqcD8tOnTyv+/zExMZJ1tm3bhoYNG8LQ0BAlS5ZEo0aNcOLECcXrjo6O8Pb2xu+//672+r8U5XDRVlRyWM7DwwMymQxDhgzhygsjh7W+2pq+0JQpU2Bra4vk5GScPn0aISEhOHToEG7dugUDA4PC7l6hW7RoEWxtbdG8eXOufOXKlRg4cCA6deqEESNGICIiAsOGDUNSUhLGjBmjtM2goCC0bNmSK2OMYeDAgbCxsUG5cuUU5Vu2bMG2bdtQt25dlC1bNtd2R40ahcmTJ6NEiRJqbmnRUVxzcsaMGfDz80P79u0LuytfrCDyHsj6AaRdu3aoU6cOJkyYACMjI0RFReH58+eKOhYWFti4caNo2SNHjmDz5s1o1aqVokxPTw8BAQGYP38+hg4dCplM9gVbXXiKa85/LYWZj0K///47kpKSJF8r7vlIeVg0HThwAPfv38eAAQO48mvXrsHd3R3VqlXD/Pnz8fz5c8ydOxeRkZE4fPiwyu1nZmZi6NChMDQ0xKdPnyTrTJo0CVOmTIGfnx8CAwORlpaGW7ducT9wA8DAgQPRpk0bREVFwc7OTv2N/UKUw0VTUchhud27d+PcuXM5vv7Vc5gVcaGhoQwAu3jxIlc+YsQIBoBt2bKlkHqWJaf+qSIzM5MlJSV9cR9SU1NZqVKl2Pjx47nypKQkZm5uzry9vbnyHj16MENDQ/b+/Xu11xUREcEAsOnTp3PlL168YKmpqYwxxry9vZm1tXWObbx584ZpamqytWvXqr3+oqCo52RuDA0NWUBAQJ6W/ZJ8z28FlfcJCQmsdOnSrEOHDiwjI0Ptfrm7uzNjY2P2+fNnrvzSpUsMADt+/LjabRa2op7z3/J+OC/5ePPmTaalpcWmTJnCALAdO3aI6hTHfCyIPCyIfVp4eDgDwMLDwxVlAQEBSr8X5aKjoxkAFhoamm/9yS+JiYlKX/fx8WFNmjQRlXt5eTErKyuWkJCgKFu9ejUDwI4ePary+kNCQpi5uTn76aefGA
D27t077vVz584xmUzG5s+fn2tbqampzNTUlE2YMEHl9ecHyuHCVdRzWO7z58/MxsZGsQ8PDg4W1fnaOVzkb0nNSYsWLQBk3QIkFxUVhaioqFyXVXYb2ePHjxX17t27Bz8/P5iZmUFPTw/16tXD/v37JdtMSkpCUFAQzM3NYWxsjN69eyMuLo6rY2Njg7Zt2+Lo0aOoV68e9PX1FbcEhYaGokWLFrC0tISuri4cHR0REhKi0ntx+vRpxMTEiK4GhoeHIzY2FoMHD+bKg4OD8enTJxw8eFCl9rPbsmULZDIZunfvzpWXLVsW2traKrVhaWmJmjVrYt++fWqvvyj7kpyU36Zy+vRpDBs2DBYWFihZsiSCgoKQmpqK+Ph49O7dG6ampjA1NcXo0aPBGOPamDt3Lho1agRzc3Po6+vD2dlZdCuaTCbDp0+fsH79ekXOBwYGKl5/8eIF+vXrh7Jly0JXVxe2trYYNGgQUlNTuXZSUlIwYsQIWFhYwNDQEB06dJAcl3r48GE0bdoUhoaGKFGiBLy9vXH79m2uzuvXr9GnTx+UL18eurq6sLKygq+vL/dZlFJQeb9lyxa8efMG06dPh4aGBj59+oTMzEyly8i9evUK4eHh6NixI/T09LjXnJ2dYWZm9k3lPe2H/6co5eNPP/2EDh06oGnTpjnW+Zby8UvyUE6V3JHJZJg0aZJoWRsbG24/qqr4+HgEBgbCxMQEJUuWREBAAOLj4yXrnjhxQrEvLVmyJHx9fXH37l3F67ndIp/d+fPn0bp1a5iYmMDAwABubm44c+YMV2fSpEmQyWS4c+cOunfvDlNTUzRp0iTHbUlOTsaRI0dE+f/hwwccO3YMPXv2hLGxsaK8d+/eMDIywvbt21V6r96/f4/x48djypQpKFmypGSdhQsXokyZMvjpp5/AGFN6u6C2tjaaNWtWZPKfcphyOLvZs2cjMzMTo0aNyrHO187hYnNLqpD8Q2Rubq4oc3d3B4BcDzSlbiMbP3483r59q7hn+fbt22jcuDHKlSuHsWPHwtDQENu3b0f79u2xa9cudOjQgVt+yJAhKFmyJCZNmoT79+8jJCQET548UQwclrt//z66deuGoKAg9O/fH1WrVgUAhISEwMnJCT4+PtDS0sKBAwcwePBgZGZmIjg4WOn2nD17FjKZDHXq1OHKr169CgCoV68eV+7s7AwNDQ1cvXoVPXv2VNp2dmlpadi+fTsaNWqk0sBnZZydnb+ZcXRyX5KTckOHDkWZMmUwefJk/Pfff1i1ahVKliyJs2fPomLFipgxYwYOHTqEOXPmoHr16ujdu7di2UWLFsHHxwc9evRAamoqtm7dis6dO+Pvv/+Gt7c3gKzc//HHH+Hi4qK45UJ+K8PLly/h4uKC+Ph4DBgwAA4ODnjx4gV27tyJpKQk6OjocP00NTXFxIkT8fjxYyxcuBBDhgzBtm3bFHU2btyIgIAAeHp6YtasWUhKSkJISAiaNGmCq1evKnKoU6dOuH37NoYOHQobGxu8ffsWx44dw9OnT5XmWUHlfVhYGIyNjfHixQu0b98eDx48gKGhIXr16oUFCxaITgSz27p1KzIzM9GjRw/J1+vWrSv6UivOaD/8P0UlH3fs2IGzZ8/i7t27uf4PvpV8zI99r6q5k18YY/D19cXp06cxcOBAVKtWDXv27EFAQICoblhYGLy8vFCpUiVMmjQJnz9/xpIlS9C4cWNcuXIFNjY2krfIp6Wl4eeff+b23SdOnICXlxecnZ0xceJEaGhoKH4oiYiIgIuLC9dG586dYW9vjxkzZoh+pMzu8uXLSE1NRd26dbnymzdvIj09XZT/Ojo6qF27tuLzkZsJEyagTJkyCAoKwtSpUyXrHD9+HI0aNcLixYsxbdo0xMbGokyZMvjtt99EY8CArM/gvn378OHDB+5EoDBQDlMOyz19+hR//PEH/vzzT+jr6ytt86vm8Fe5jvkF5Jfaw8LC2Lt379izZ8/Y1q1bmb
m5OdPX12fPnz9X1LW2tlbpcrnQ7NmzGQC2YcMGRZm7uzurUaMGS05OVpRlZmayRo0aMXt7e1H/nJ2dFbdkZ
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "r29_OvwJtyM7"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"# CIFAR-10"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "nuo0JHXDzxT7"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## Loading the dataset"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
},
|
|||
|
|
"id": "pPrx-UWet0Ng",
|
|||
|
|
"outputId": "7311e643-6b69-4e2e-a2ca-ef4faf46454d"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"from keras.datasets import cifar10\n",
|
|||
|
|
"\n",
|
|||
|
|
"(X_train, y_train), (X_test, y_test) = cifar10.load_data()\n",
|
|||
|
|
"\n",
|
|||
|
|
"X_train = X_train.astype('float32') # change integers to 32-bit floating point numbers\n",
|
|||
|
|
"X_test = X_test.astype('float32')\n",
|
|||
|
|
"\n",
|
|||
|
|
"X_train /= 255 # scale pixel values from [0, 255] to [0, 1]\n",
|
|||
|
|
"X_test /= 255\n",
|
|||
|
|
"\n",
|
|||
|
|
"y_train = y_train.reshape((1,-1))[0] # labels come as shape (N, 1); flatten to shape (N,)\n",
|
|||
|
|
"y_test = y_test.reshape((1,-1))[0]\n",
|
|||
|
|
"\n",
|
|||
|
|
"print(\"Training matrix shape\", X_train.shape, y_train.shape)\n",
|
|||
|
|
"print(\"Testing matrix shape\", X_test.shape, y_test.shape)\n",
|
|||
|
|
"\n",
|
|||
|
|
"# one-hot format classes\n",
|
|||
|
|
"\n",
|
|||
|
|
"nb_classes = 10\n",
|
|||
|
|
"\n",
|
|||
|
|
"Y_train = to_categorical(y_train, nb_classes)\n",
|
|||
|
|
"Y_test = to_categorical(y_test, nb_classes)\n",
|
|||
|
|
"\n",
|
|||
|
|
"cifar_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"Training matrix shape (50000, 32, 32, 3) (50000,)\n",
|
|||
|
|
"Testing matrix shape (10000, 32, 32, 3) (10000,)\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
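{
"cell_type": "markdown",
"metadata": {},
"source": [
"This sketch shows what the label reshape and `to_categorical` calls above do, using plain NumPy on a toy label array (the three labels are made up). CIFAR-10 labels are loaded with shape `(N, 1)`, and flattening them gives shape `(N,)`. Indexing rows of a 10x10 identity matrix with those labels produces the same float32 one-hot vectors that `to_categorical` returns."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"labels = np.array([[3], [0], [9]])           # same layout as cifar10.load_data() labels: (N, 1)\n",
"flat = labels.reshape(-1)                    # -> (N,)\n",
"one_hot = np.eye(10, dtype='float32')[flat]  # row k of the identity matrix = one-hot vector for class k\n",
"print(labels.shape, flat.shape, one_hot.shape)\n",
"print(one_hot[0])"
],
"execution_count": null,
"outputs": []
},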
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "Q2LGp6AVzzqC"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## Previewing the training set"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/",
|
|||
|
|
"height": 828
|
|||
|
|
},
|
|||
|
|
"id": "o6OJ7XPdxe1i",
|
|||
|
|
"outputId": "626c8984-08cb-413c-9329-f7aa03110986"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
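"# Build a preview grid: one column per class, 10 example images stacked vertically in each column\n",
"for i in range(0, 10):\n",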
|
|||
|
|
" img_batch = X_train[y_train == i][0:10]\n",
|
|||
|
|
" img_batch = np.reshape(img_batch, (img_batch.shape[0]*img_batch.shape[1], img_batch.shape[2], img_batch.shape[3]))\n",
|
|||
|
|
" if i > 0:\n",
|
|||
|
|
" img = np.concatenate([img, img_batch], axis = 1)\n",
|
|||
|
|
" else:\n",
|
|||
|
|
" img = img_batch\n",
|
|||
|
|
"plt.figure(figsize=(10,20))\n",
|
|||
|
|
"plt.axis('off')\n",
|
|||
|
|
"plt.imshow(img) # RGB data is shown as-is; a colormap only applies to single-channel images\n"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "execute_result",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<matplotlib.image.AxesImage at 0x7ac88cff5570>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"execution_count": 71
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"output_type": "display_data",
|
|||
|
|
"data": {
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 1000x2000 with 1 Axes>"
|
|||
|
|
],
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxoAAAMaCAYAAAABQDBSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz92Y4ky5Kmi32ig5m5e0Rk5pr23lXV1ezqw74gCPCOV+cZ+CJ8Dj4Ar/gWBME7vsABAaKJ0ziNc6rBbtS41t5ryMyY3N3MdOKFiJp7ZOYeagAIEKELsSIywt3cBlVRkV9++UVaa43X8Tpex+t4Ha/jdbyO1/E6Xsfr+Fcc7v/XJ/A6XsfreB2v43W8jtfxOl7H6/j/v/EaaLyO1/E6XsfreB2v43W8jtfxOv7Vx2ug8Tpex+t4Ha/jdbyO1/E6Xsfr+Fcfr4HG63gdr+N1vI7X8Tpex+t4Ha/jX328Bhqv43W8jtfxOl7H63gdr+N1vI5/9fEaaLyO1/E6XsfreB2v43W8jtfxOv7Vx2ug8Tpex+t4Ha/jdbyO1/E6Xsfr+Fcfr4HG63gdr+N1vI7X8Tpex+t4Ha/jX32EP/WF/6//+P+kAXL1OxF58XX9u/7zy9c6+5189v4Xr2/Q+LyP4OV4X/rb5+e8vbo1Pu1LeH1urTVqLfaal6/tP18f4+X369/31+nP989wf3z53hgC3jm89/gYQYTmAg1PAyp6fzwOaYI0kAaugW8g/bP6BTqhCRSgXt+D1pCq79V40lEr1KL3tkr95GG+vN5Pf95+h10fUIGGfPb6T+9VrfWTY1zdd6Bun9OQ+APhq/87zf+Wv/tvz/z2H094d2CKv8bJwDhMhBARu0fOCdPo8d6xG2/Yj7cEP3IY3uFdZDdNDMPAvJ55fP7Amhd+ef5HTusD8ZCZ7hbCUDm8XYhjwfmE85laE2k501olBD2+kwFhBwitVlqrtJZobcG7yH73juAHvI94H3BScMy0WpnnzJoqyyycnhylePKyo+RALpFUhqv7UjmfnliWEzkXUsqUUpmXM6kk+t0vpXI8zpTUeFf/grft31DWlfPDAyVn8popuSDO48Koc96+3rx5w1dff8U0TXzz7bdM47g9n3meuf/4gWVZ+PHv/577X37h+PzMh/fvKbXqPLO5J87hnCcMA845duPIEAdijEy7Hd57pt2OGCPDOLDf7fHBs9sdCCESh4kQJ0IIjKO+Po4RHwI+BOIwIOKIw4D3AXEe8bZubP5dZo+d1/W8rjpZX9gA0fX/3/7r/8L/5f/8f+Jv/tt/IcaI9/6FLXIiOHHbvwFKezmXcymczjO1VoJ4PI4gjsHet+SVXCsEh8SAc44YPN45YozEGAkhstvtEBHO5xPrulJyoSz55VqplVquPl9AnNpP54Tg9ZynMTIEz24cuLvZUWvl4+Mjy7qSaiOVps9lOuCc2oVWIefEup4RgbdvJ/a7iHOCc2q3vffQ4JfvH3n//ZPaEbMJuRZqa/ZM7PRCRLxnGEemaffC1pdSKKXgvd+uvZRMq9Wmln5uCAHvPeM4stvm04RzjnVdSGllWReej8+01hjHkRAC+/2eu7s7u059/fl05Hw8UkuhrAuttj49dB6HAR8Cb999xTTtbO0l1nXl/v17lnXZrq3Pa+c94zQRYqDWSsqZdV357Y8/8nw8klJmTYkQB27efI2PkXk+sSwLjkaQhpPG3sPo4TB63u0HYnDcHEaG6NmNkf004MThndd7E6KuBXHgHLU1jqcz67pSG5QKzjkO+x3Be6TpQ261knOm2vdSCs7rXESgUmmtkXMi5RXnHfvdDu+dPVmdhzllm4Ieh9uerXOOyZ5Bn9/OObyPgJByodRq97HbI0dDqLVScqE13c9KbfwP//nv+B/+57+nVf0sQZBBEC/s73a8++4OH/xmAR4fjtx/eKLWRmseEIZpYBgGQvRMU8B5x+iF6B273cTd3R0NOC8zKWfmZe
E0n/EhcDjc4L2nnBN1KeRlZXk8Qmvsx5EhRnY3e27fvQGBuSRKqyznM8v5TMmFedG5JjgEoRUouVJr5TyfyTnjR0/cRV0btVCb/r1WfR616XfM53Aiuic5RwzB1nGllMK73/x7/vf/h/8jb3/979TumY9S7b0eMY8AfF+TLgCOCpQGpVbmkqit6SMCpGZIKyIwxUAMnhg80xgAoRQ9vbUKq20SYo+50M/D7EMzO/vCWP/x8cI/6dbmE5/lC96hrvGrP3zmywK39QM39SPr6YmPv/s71vnE8f1PzM/31Cq05nA+cvfVrxl3N9y+ecebr74jDiP7N18R4kgIDh+E0/GBn374b8zzkY8//8Dz00dazpR5gQauOgTHMI6MuxFxHh/V/vzFv/13/Oo3f8bh5pZvv/sVwXvS+UhJKx8/vufH331PKYXWMq1Vai2Umjkdj/zww/csy5n5fCKtM6UWcs66prLa+mXOnE6JUivLmiilUluhtUKMgdvbAyF4dvuRYQg0Mo1F51CBWhslQ1YTwP/t//r/+KPP7U8ONPrD+Oe8Sr4QBXzpdy8O8clsefn6ly/4YpAhsk1AEfmi0/zp3/pbvvR6/R2ffe7ve/0fGw1dFILQV0HrP1+95tNL2/6sH2zHgSZt++N2mq198qaX/2ifftj1q64Dms+Wbv8s+b2vuD7Gpz9/+fXN7sdlSI+ybHMTV9UAt4Zns1VXAd4Xvq7d0Mu+ZhujHd++Lq9uXzhB+WKg3fp9kIvJ7Ca0te71igVkAk3f02p78Xj6z3U796vHt52L2EW3yzO+nIxehqjzrwGAo3lwLuC6Ey0OBJz3OOf1teL093bdegyPcw7n/eb0ex9Ais57GuLUyRF3ufbrZ1FrRUSotVCLo5ZKqRkKlJIRBOcy4jIClJChNZyX7Ua7fi3Z6bl5cDib+/qiJhbwXuMQ10bhk81JkC/YjM8WgH5Ga5/ZqmZ/u3oq9G1V59XFNvTXX8c/23d5+foNZ2lX869dzVs77xeP/cVxZAs8+OT3zuaD1NpvkL1fEJtPl89oXM9mEfkTHIJPjLZcPY8/9K5PwKpm7315DeCuAqrt9Zcz/GzLELkce7tPvPysPmX689ye0TaHL85ebU3XLDqHnDn3cg2qXF8Tl8/of6utIrVSawdfGtW17RxaBZq7CoovjiXNwromNBytVaQ5GpXWdPNvterrun0R9NydILUiFmjUDpLYd7VJFpBvDmHbHNvWqu1NDSwQwZxG2cJNUff1au3349cKIlVnWv9cUJvT9DwRA2/sb9s1bKGrXKbtlUnv9qavnS9NuW32Xq+p7Xou6+56vlyO1a7+pvPQez3v7uh7p4AXAh4HteG8w3l9lt45KlUDMoRKo5kj7vpxt89uiBNck26p1DG1f9Oq/ts5vFfb288Ds8vOu24YbN68vBt9e2tXc7Nfa2tfuoPX9/Vqz3/hXnxmlb54nE+P9k8dX/S1ro3tZ76jXP3/D5+B7ql97dm9v5obmy3cnpV8+mYuc7Vtx/v0/AHd09A54r3uw957+7rMqT4vxF1suPcefVZen5c0GmbfnQL620d/8lg62LjZqH5NX3J5vniPLk/7n+Lz/smBhnOfs6wum4H75AH0xSsvvusN+P3Zjy9u6Fef83K0z37+7HVXE9C5awPejdNlExazRLVejKi+5uIQOnf522WxXWdALq+vrTuveh7dEDWB0hqlFFKpIA4JIN6rF9XvUf+Ixmacqh2rlQq1UnOj1KITx3fny0xzreSSzcnxgMOJx7twed22Qdr9++TOfmmhgF1buyyrF0/ik3v8pYzG1aWp0bX72WwzFQ8SGmEsDPuVda78cp+hBm4O3zAOB4LzhuQIuTqCE0qCtDSiHymTJ/hIJZHbSiMx3gixOd4eHLvikLHA7oj4TA4zRQpSGmQUCc9B7Za3YKFBbQkw5xxPjCPeB7zzDIPDC9AC5JFaIec9rcJyXklr5vS88vDhrM5GM2eDSqWQc+V0Pl
NLpZGgCTkL66rzMmehVk+3prUKreqzDHFkCndQKtOwg9YYh5EhDPgQGacdiKNaUDFOE9OWXVCkuJZCrZVx3
|
|||
|
|
},
|
|||
|
|
"metadata": {}
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "6EF56o2gz3Mz"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## Preparing the model"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "6UvabhsguqWc"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"def generate_model():\n",
|
|||
|
|
" model = Sequential() # Linear stacking of layers\n",
|
|||
|
|
"\n",
|
|||
|
|
" # Convolution Layer 1\n",
|
|||
|
|
" model.add(Conv2D(16, (3, 3), input_shape=(32,32,3)))\n",
|
|||
|
|
" model.add(Activation('relu'))\n",
|
|||
|
|
"\n",
|
|||
|
|
" # ...\n",
|
|||
|
|
"\n",
|
|||
|
|
" model.add(Flatten()) # Flatten final output matrix into a vector\n",
|
|||
|
|
"\n",
|
|||
|
|
" # ...\n",
|
|||
|
|
"\n",
|
|||
|
|
" # Fully Connected Layer\n",
|
|||
|
|
" model.add(Dense(10)) # final 10 FC nodes\n",
|
|||
|
|
" model.add(Activation('softmax')) # softmax activation\n",
|
|||
|
|
"\n",
|
|||
|
|
" model.summary()\n",
|
|||
|
|
"\n",
|
|||
|
|
" adam = tf.optimizers.Adam(learning_rate=0.001)\n",
|
|||
|
|
" model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])\n",
|
|||
|
|
"\n",
|
|||
|
|
" return model"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": []
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "PSQkAq9IvJsS",
|
|||
|
|
"outputId": "5a311cb1-e467-4085-8ab5-045a6f6d5d78",
|
|||
|
|
"colab": {
|
|||
|
|
"base_uri": "https://localhost:8080/"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"model = generate_model()"
|
|||
|
|
],
|
|||
|
|
"execution_count": null,
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"name": "stdout",
|
|||
|
|
"text": [
|
|||
|
|
"Model: \"sequential_5\"\n",
|
|||
|
|
"_________________________________________________________________\n",
|
|||
|
|
" Layer (type) Output Shape Param # \n",
|
|||
|
|
"=================================================================\n",
|
|||
|
|
" conv2d_7 (Conv2D) (None, 30, 30, 16) 448 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_18 (Activation) (None, 30, 30, 16) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" flatten_3 (Flatten) (None, 14400) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
" dense_11 (Dense) (None, 10) 144010 \n",
|
|||
|
|
" \n",
|
|||
|
|
" activation_19 (Activation) (None, 10) 0 \n",
|
|||
|
|
" \n",
|
|||
|
|
"=================================================================\n",
|
|||
|
|
"Total params: 144458 (564.29 KB)\n",
|
|||
|
|
"Trainable params: 144458 (564.29 KB)\n",
|
|||
|
|
"Non-trainable params: 0 (0.00 Byte)\n",
|
|||
|
|
"_________________________________________________________________\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
},
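{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is a hand check of the `Param #` column in the summary above. It is a sketch, and `conv2d_params` and `dense_params` are illustrative helpers, not Keras functions. A `Conv2D` layer has one kh x kw x in_channels kernel plus one bias per filter. A `Dense` layer has one weight per input-output pair plus one bias per unit. The result, 448 + 144010 = 144458, matches the summary. Almost all parameters sit in the `Dense` layer fed by the 14400-value `Flatten` output, so shrinking the feature maps before `Flatten` (for example with pooling, as in the earlier model) cuts the parameter count sharply."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def conv2d_params(kh, kw, in_ch, filters):\n",
"    return kh * kw * in_ch * filters + filters  # weights + one bias per filter\n",
"\n",
"def dense_params(n_in, units):\n",
"    return n_in * units + units  # weights + one bias per unit\n",
"\n",
"conv = conv2d_params(3, 3, 3, 16)   # 3x3 kernels over 3 RGB channels, 16 filters\n",
"flat = (32 - 3 + 1) ** 2 * 16       # 'valid' 3x3 conv on 32x32 -> 30x30x16 values\n",
"dense = dense_params(flat, 10)      # Dense(10) on the flattened vector\n",
"print('Conv2D:', conv, 'Flatten:', flat, 'Dense:', dense, 'Total:', conv + dense)"
],
"execution_count": null,
"outputs": []
},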
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "YMtJVU_oz55g"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"## Training"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"metadata": {
|
|||
|
|
"id": "sjZaiBLJvQkP"
|
|||
|
|
},
|
|||
|
|
"source": [
|
|||
|
|
"gen = ImageDataGenerator(rotation_range=8, width_shift_range=0.08, shear_range=0.3,\n",
|
|||
|
|
" height_shift_range=0.08, zoom_range=0.08, validation_split=0.2)\n",
|
|||
|
|
"\n",
|
|||
|
|
"train_generator = gen.flow(X_train, Y_train, batch_size=128, subset='training')\n",
|
|||
|
|
"valid_generator = gen.flow(X_train, Y_train, batch_size=128, subset='validation')"
|
|||
|
|
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "MlMHmbOsvbWP",
"outputId": "4921a9b5-477c-4ae4-eb53-510d2d057b64"
},
"source": [
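"# The 'training' subset holds 40,000 images (80% of X_train), so 40000 // 128 = 312 steps per epoch\n",
"model.fit(train_generator, steps_per_epoch=40000 // 128, epochs=2, verbose=1,\n",
"          validation_data=valid_generator, validation_steps=10000 // 128)"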
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/2\n",
"312/312 [==============================] - 28s 86ms/step - loss: 1.7245 - accuracy: 0.3867 - val_loss: 1.5536 - val_accuracy: 0.4515\n",
"Epoch 2/2\n",
"312/312 [==============================] - 26s 85ms/step - loss: 1.5090 - accuracy: 0.4636 - val_loss: 1.4395 - val_accuracy: 0.4973\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<keras.src.callbacks.History at 0x7ac88cfd2c20>"
]
},
"metadata": {},
"execution_count": 75
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mN5zKMDNz8Jp"
},
"source": [
"## Test"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "lLZORONWvqex",
"outputId": "f83c604e-cbf2-428c-a751-fd23a7653a08"
},
"source": [
"# evaluate() returns [loss, accuracy] for the metrics the model was compiled with\n",
"score = model.evaluate(X_test, Y_test)\n",
"print('Test score:', score[0])\n",
"print('Test accuracy:', score[1])\n",
"\n",
"# model.predict returns per-class probabilities; np.argmax picks the\n",
"# highest-probability class for each input example.\n",
"predicted = model.predict(X_test)\n",
"predicted_classes = np.argmax(predicted, axis=1)\n",
"\n",
"# Flatten the integer labels (CIFAR-10 ships them with shape (N, 1)) so the\n",
"# element-wise comparisons below do not broadcast to an (N, N) matrix\n",
"y_true = np.ravel(y_test)\n",
"\n",
"# Check which items we got right / wrong\n",
"correct_indices = np.nonzero(predicted_classes == y_true)[0]\n",
"incorrect_indices = np.nonzero(predicted_classes != y_true)[0]\n",
"\n",
"cnf_matrix = confusion_matrix(y_true, predicted_classes)\n",
"class_names = [str(i) for i in range(10)]\n",
"\n",
"# Plot non-normalized confusion matrix\n",
"plt.figure()\n",
"plot_confusion_matrix(cnf_matrix, classes=class_names,\n",
"                      title='Confusion matrix, without normalization')\n",
"\n",
"plt.show()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"313/313 [==============================] - 1s 3ms/step - loss: 1.4067 - accuracy: 0.5069\n",
"Test score: 1.4067156314849854\n",
"Test accuracy: 0.5069000124931335\n",
"313/313 [==============================] - 1s 2ms/step\n",
"Confusion matrix, without normalization\n",
"[[523 99 41 11 9 6 20 23 211 57]\n",
" [ 17 765 4 2 1 4 14 16 61 116]\n",
" [ 74 59 378 44 63 60 136 94 56 36]\n",
" [ 26 56 73 216 39 193 180 104 37 76]\n",
" [ 42 35 171 44 249 45 196 167 32 19]\n",
" [ 16 37 108 118 43 369 84 149 33 43]\n",
" [ 5 48 69 37 32 34 665 51 16 43]\n",
" [ 19 47 38 38 28 60 41 626 22 81]\n",
" [ 85 105 7 6 2 3 16 7 722 47]\n",
" [ 32 248 9 7 4 7 23 26 88 556]]\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 900x900 with 2 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0wAAAN5CAYAAAA7KrupAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzddVgUaxsG8HtpJKVBGgREwA7sBFHsTuzuPMaxFTuP3d3d3d3dgk1YpISw3x/I6sruJ57D7uzi/TvXXMd9593hmdmYefaNEYnFYjGIiIiIiIgoCw2hAyAiIiIiIlJVTJiIiIiIiIjkYMJEREREREQkBxMmIiIiIiIiOZgwERERERERycGEiYiIiIiISA4mTERERERERHJoCR0AERERERH9O0lJSUhJSRE6jGzT0dGBnp6e0GH8FiZMRERERERqKCkpCfpG5sDXRKFDyTYbGxuEhYWpVdLEhImIiIiISA2lpKQAXxOh6x0CaOoIHc6vpaUg4v4qpKSkMGEiIiIiIiIl0dSBSA0SJrHQAfxLTJiIiIiIiNSZSCNjUXXqEKMM6hk1ERERERGREjBhIiIiIiIikoMJExERERERkRwcw0REREREpM5EAEQioaP4NTUIURa2MBEREREREcnBhImIiIiIiEgOdskjIiIiIlJnnFZcodQzaiIiIiIiIiVgwkRERERERCQHEyYiIiIiIiI5OIaJiIiIiEidiURqMq24GsQoA1uYiIiIiIiI5GDCREREREREJAe75BERERERqTNOK65Q6hk1ERERERGREjBhIiIiIiIikoNd8oiIiIiI1BlnyVMotjARERERERHJwYSJiIiIiIhIDiZMREREREREcnAMExERERGRWlOTacXVtK1GPaMmIiIiIiJSAiZMREREREREcrBLHhERERGROuO04grFFiYiIiIiIiI5mDARERERERHJwYSJiIiIiIhIDo5hIiIiIiJSZyI1mVZcHWKUQT2jJiIiIiIiUgImTERERERERHKwSx4RERERkTrjtOIKxRYmIiIiIiIiOZgwERERERERycEueURERERE6oyz5CmUekZNRERERESkBEyYiIiIiIiI5GDCREREREREJAfHMBERERERqTNOK65QbGEiIiIiIiKSgwkTERERERGRHOySR0RERESkzjituEKpZ9RERERERERKwISJiIiIiIhIDiZMREREREREcnAMExERERGROhOJ1GN8EKcVJyIiIiIiyl2YMBEREREREcnBLnlEREREROpMQ5SxqDp1iFEGtjARERERERHJwYSJiIiIiIhIDnbJIyIiIiJSZyINNZklTw1ilEE9oyYiIiIiIlICJkxERERERERyMGEiIiIiIiKSg2OYiIiIiIjUmUiUsag6dYhRBrYwERERERERycGEiYiIiIiISA52ySMiIiIiUmecVlyh1DNqIiIiIiIiJWDCREREREREJAcTJiIiIiIiIjk4homIiIiISJ1xWnGFYgsTERERERGRHEyYiIiIiIiI5GCXPCIiIiIidcZpxRVKPaMmIiIiIiJSAiZMREREREREcjBhIiIiIiJSZ5mz5KnD8hucnZ0hEomyLD169AAAJCUloUePHjA3N4ehoSEaNmyIyMhIqW28fPkStWrVQp48eWBlZYVBgwbh69evvxUHEyYiIiIiIlI5V65cwbt37yTLkSNHAACNGzcGAPTr1w979uzBli1bcOrUKbx9+xYNGjSQPD8tLQ21atVCSkoKzp8/j1WrVmHlypUYOXLkb8UhEovF4pzbLSIiIiIiUobY2FiYmJhAt8o4iLT0hA7nl8Rfk5B8/G+8evUKxsbGknJdXV3o6ur+8vl9+/bF3r178eTJE8TGxsLS0hLr169Ho0aNAAAPHz5EgQIFcOHCBZQuXRoHDhxAcHAw3r59C2trawDAwoULMWTIEERHR0NHRydbcbOFiYiIiIiIlMbBwQEmJiaSJTQ09JfPSUlJwdq1a9
G+fXuIRCJcu3YNqampqFatmqSOl5cXHB0dceHCBQDAhQsX4OvrK0mWACAwMBCxsbG4d+9etuPltOJEREREROpMzaYVl9XC9Cs7d+7E58+f0bZtWwBAREQEdHR0YGpqKlXP2toaERERkjo/JkuZ6zPXZRcTJiIiIiIiUhpjY2OphCk7li1bhqCgINjZ2SkoKvnUIBUlIiIiIqI/1YsXL3D06FF07NhRUmZjY4OUlBR8/vxZqm5kZCRsbGwkdX6eNS/zcWad7GDCRERERESkzoSeKlxB04pnWrFiBaysrFCrVi1JWbFixaCtrY1jx45Jyh49eoSXL1/C398fAODv7487d+4gKipKUufIkSMwNjaGt7d3tv8+u+QREREREZFKSk9Px4oVKxASEgItre+pi4mJCTp06ID+/fvDzMwMxsbG6NWrF/z9/VG6dGkAQEBAALy9vdG6dWtMmTIFERERGDFiBHr06JGtcVOZmDAREREREZFKOnr0KF6+fIn27dtnWTdz5kxoaGigYcOGSE5ORmBgIObPny9Zr6mpib1796Jbt27w9/eHgYEBQkJCMHbs2N+KgV3yiHKxJ0+eICAgACYmJhCJRNi5c2eObj88PBwikQgrV67M0e3mBs7OzpKZfFTJ77xmmXWnTZum+MBIptGjR0P0UxcWod5bqvqeJqLcLSAgAGKxGB4eHlnW6enpYd68efj48SMSEhKwffv2LGOTnJycsH//fiQmJiI6OhrTpk2TaqnKDiZMRAr27NkzdOnSBa6urtDT04OxsTHKli2L2bNn48uXLwr92yEhIbhz5w4mTJiANWvWoHjx4gr9e7nR/fv3MXr0aISHhwsdisLs378fo0ePFjqMLCZOnJjjST79f+fPn8fo0aOzDKImIlWn8X1qcVVe1DT1YJc8IgXat28fGjduDF1dXbRp0wY+Pj5ISUnB2bNnMWjQINy7dw+LFy9WyN/+8uULLly4gOHDh6Nnz54K+RtOTk748uULtLW1FbJ9VXD//n2MGTMGlSpVgrOzc7af9+jRI2hoqN6JQdZrtn//fsybN0/lkqaJEyeiUaNGqFevntChqBRFvrfOnz+PMWPGoG3btlnubaKq72kiIkVjwkSkIGFhYWjWrBmcnJxw/Phx2NraStb16NEDT58+xb59+xT296OjowEgy0VPThKJRNDT01PY9tWNWCxGUlIS9PX1f2swqTLxNftvEhISYGBgIGgMQr23VPU9TUSkaPypiEhBpkyZgvj4eCxbtkwqWcrk7u6OPn36SB5//foV48aNg5ubG3R1deHs7Ixhw4YhOTlZ6nnOzs4IDg7G2bNnUbJkSejp6cHV1RWrV6+W1Bk9ejScnJwAAIMGDYJIJJK0jrRt21ZmS4mssRJHjhxBuXLlYGpqCkNDQ3h6emLYsGGS9fLGwxw/fhzly5eHgYEBTE1NUbduXTx48EDm33v69Knk12wTExO0a9cOiYmJ8g/sN5UqVYKPjw9u376NihUrIk+ePHB3d8fWrVsBAKdOnUKpUqWgr68PT09PHD16VOr5L168QPfu3eHp6Ql9fX2Ym5ujcePGUl3vVq5cicaNGwMAKleuDJFIBJFIhJMnTwL4/locOnQIxYsXh76+PhYtWiRZlzneQywWo3LlyrC0tJSa2jQlJQW+vr5wc3NDQkLCL/f5R/3794e5uTnEYrGkrFevXhCJRJgzZ46kLDIyEiKRCAsWLACQ9TVr27Yt5s2bBwCS/fv5fQAAixcvlrw3S5QogStXrmSpk53XPbvvP5FIhISEBKxatUoS0/8bP3Py5EmIRCJs3rwZEyZMgL29PfT09FC1alU8ffo0S/0tW7agWLFi0NfXh4WFBVq1aoU3b95kidXQ0BDPnj1DzZo1YWRkhJYtW0ri69mzJ7Zs2QJvb2/o6+tLpq8FgEWLFsHd3R16enqoVKlSli6dZ86cQePGjeHo6AhdXV04ODigX79+2eqm+/NYoh9ft5+XzL97+/ZttG
3bVtI12MbGBu3bt8eHDx+kXoNBgwYBAFxcXLJsQ9YYpufPn6Nx48YwMzNDnjx5ULp06Sw/BP3ua0NE/4LQU
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ktfHkYc1zHsY"
},
"source": [
"def show_samples_rgb(indices, preds, images, labels, count=3, names=None):\n",
"    \"\"\"Plot the first count*count samples from `indices` in a count x count grid.\n",
"\n",
"    Each title shows the predicted class (P) and the expected, i.e. true, class (E),\n",
"    each followed by the probability the model assigned to that class.\n",
"    \"\"\"\n",
"    labels = np.ravel(labels)  # accept labels of shape (N,) or (N, 1)\n",
"    plt.figure()\n",
"    for i, sample in enumerate(indices[:count**2]):\n",
"        pred_id = int(np.argmax(preds[sample]))\n",
"        real_id = int(labels[sample])\n",
"        pred_score = preds[sample][pred_id]\n",
"        real_score = preds[sample][real_id]\n",
"        plt.subplot(count, count, i + 1)\n",
"        plt.imshow(images[sample], interpolation='none')\n",
"        plt.axis('off')\n",
"        if names is not None and len(names) > 0:\n",
"            plt.title(\"P: {} ({:.2f})\\nE: {} ({:.2f})\".format(names[pred_id], pred_score, names[real_id], real_score))\n",
"        else:\n",
"            plt.title(\"P: {} ({:.2f})\\nE: {} ({:.2f})\".format(pred_id, pred_score, real_id, real_score))\n",
"\n",
"    plt.tight_layout()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "F8hS785GzoGJ"
},
"source": [
"## Correct classifications"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 906
},
"id": "IUcJVbOlzm5z",
"outputId": "89f6a4f9-85cf-41ae-f8b6-fd84ea6d7158"
},
"source": [
"show_samples_rgb(correct_indices, predicted, X_test, y_test, 5, cifar_names)"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 900x900 with 25 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA4EAAAN5CAYAAAC2az64AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOydd3wVRff/P7eX9B4glNCki4ZiQalSFawIYgMRC/aGgI+ABUUU8asCVkDsotjLA3bRRxQFQTokdNJ7bm6d3x/8cnfPuWlAkOA979crr+zszO7O7p6Znb17PnMMSikFQRAEQRAEQRAEISwwnugKCIIgCIIgCIIgCP8c8hIoCIIgCIIgCIIQRshLoCAIgiAIgiAIQhghL4GCIAiCIAiCIAhhhLwECoIgCIIgCIIghBHyEigIgiAIgiAIghBGyEugIAiCIAiCIAhCGCEvgYIgCIIgCIIgCGGEvAQKgiAIgiAIgiCEEfIS2Ih49913ER8fj7KyshNdlRrZtGkTzGYzNm7ceKKrItSDa6+9FpGRkfUqazAYMHPmzAavw/Dhw3H99dc3+H4bkkWLFqFFixZwu90nuiqCDrHf+nH//fejd+/eJ7oaYcv27dsxePBgxMTEwGAw4MMPPzzRVQqyd+9e2O12rF69+kRXpUa8Xi+aN2+OBQsWnOiqCCeIfv36oUuXLnWWy8rKgsFgwJIlS4LrZs6cCYPB0OB1euKJJ9ChQwcEAoEG33dDcaxj8pPyJXDJkiUwGAzBP7vdjvbt2+OWW25Bdnb2ia4eKioqMHPmTHz33Xf13sbv92PGjBm49dZbQwY9P//8M/r06QOn04nU1FTcdttt9X5R1F8n/d/jjz9OylU1Iv5nt9tJuU6dOmHEiBF48MEH631uQuO32ePF6tWr8d///hdTpkwh6wOBAJ544gmkp6fDbrejW7dueOutt+q1T34t9X+HDh0KlsvPz8fcuXNx7rnnIikpCbGxsTjjjDPwzjvvhOzz2muvhcfjwQsvvHBsJ/wvRey34eyXc/3118NgMOD8888PyWvVqlW1dn7jjTeScnfccQfWr1+Pjz/++KjqEC4cLzu+5pprsGHDBjz66KNYtmwZevTo0YC1PjYeeugh9O7dG2effTZZv3//fowePRqxsbGIjo7GqFGjsGvXrnrts1+/ftXa5dChQ0PKut1uTJkyBU2bNoXD4UDv3r2xcuVKUsZiseCuu+7Co48+isrKyqM/2X+Ik6U/fPPNNzF//vwTXY2TkpKSEsyZMwdTpkyB0UhflT7++GOcfvrpsNvtaNGiBWbMmAGfz1fnPmsaZ1f9Vf1QEwgEsGTJEowcORLNmzdHREQEunTpgkceeSSkfRzrmNx8VFs1Eh566CGkp6ejsrISP/30ExYuXIjPP/8cGzduhNPpPGH1qqiowKxZswAc7izrwyeffIKtW7di0qRJZP26deswcOBAdOzYEfPmzcO+ffvw5JNPYvv27fjiiy/qte/zzjsPV199NVl32mmnVVt24cKF5CXUZDKFlLnxxhsxfPhw7Ny5E23atKlXHYTDNFabBQCXywWzuWG7hLlz52LgwIFo27YtWT99+nQ8/vjjuP7669GzZ0989NFHuOKKK2AwGDBmzJh67bvqWuqJjY0NLv/yyy+YPn06hg8fjgceeABmsxnvv/8+xowZg02bNgXbKADY7XZcc801mDdvHm699dbj8qvivwGx38M0hP0CwO+//44lS5aE/Nimp3v37rj77rvJuvbt25N0amoqRo0ahSeffBIjR46s9/HDlYa0Y5fLFexrbrnlluNU46MjNzcXS5cuxdKlS8n6srIy9O/fH8XFxZg2bRosFguefvpp9O3bF+vWrUNCQkKd+05LS8Njjz1G1jVt2jSk3LXXXovly5fjjjvuQLt27bBkyRIMHz4c3377Lfr06RMsN378eNx///148803MWHChKM843+WxtwfAodfAjdu3Ig77rjjRF
elwWjZsiVcLhcsFstxPc6rr74Kn8+HsWPHkvVffPEFLrzwQvTr1w/PPvssNmzYgEceeQQ5OTlYuHBhrfu8+OKLQ54lADBt2jSUlZWhZ8+eAA6/Q4wfPx5nnHEGbrzxRiQnJ+OXX37BjBkz8PXXX+Obb74hY5RjGpOrk5DFixcrAOq3334j6++66y4FQL355psnqGaHyc3NVQDUjBkz6r3NyJEjVZ8+fULWDxs2TDVp0kQVFxcH17300ksKgPrqq6/q3C8ANXny5DrLzZgxQwFQubm5dZb1eDwqLi5O/ec//6mzrHCYE2Wz11xzjYqIiDgu+66L7OxsZTab1csvv0zW79u3T1ksFmKXgUBAnXPOOSotLU35fL5a91vTteTs2rVLZWVlkXWBQEANGDBA2Ww2VVZWRvJ+//13BUB9/fXX9Tm9sELsV+NY7Ve/zZlnnqkmTJigWrZsqUaMGBFSpqb11bF8+XJlMBjUzp0761U+HDkedrx7924FQM2dO7fOsrzPOd7MmzdPORwOVVpaStbPmTNHAVBr1qwJrtu8ebMymUxq6tSpde63b9++qnPnznWW+/XXX0OujcvlUm3atFFnnnlmSPnzzz9fnXPOOXXu90TT2MegVYwYMUK1bNnyRFejXtTXpqqjavzakHTr1k1deeWVIes7deqkTj31VOX1eoPrpk+frgwGg9q8efMRH2fPnj3KYDCo66+/PrjO7Xar1atXh5SdNWuWAqBWrlxJ1h/LmPykdAetiQEDBgAAMjMzg+t27tyJnTt31mv7oqIi3HnnnWjVqhVsNhvS0tJw9dVXIy8vDwDg8Xjw4IMPIiMjAzExMYiIiMA555yDb7/9NriPrKwsJCUlAQBmzZoV/Mxbm1alsrISX375JQYNGkTWl5SUYOXKlbjyyisRHR0dXH/11VcjMjIS7777br3OCzj8a2V93CyUUigpKYFSqsYyFosF/fr1w0cffVTv4wvVcyw26/V6MWvWLLRr1w52ux0JCQno06dPiKsNcNj158ILL0RkZCSSkpJwzz33wO/3kzLcTqtcF7Zs2YLRo0cjOjoaCQkJuP322+tlS5999hl8Pl+IXX/00Ufwer24+eabybFvuukm7Nu3D7/88kud+66itLQ05DyqSE9PR8uWLUPO8cILL4Tb7Q5xfcrIyEB8fLzY9REg9qsd+0jtd9myZdi4cSMeffTROst6PB6Ul5fXWqaqnmK/R87R2vHMmTODfcy9994Lg8GAVq1aBfMMBgM2bdqEK664AnFxccEvXz6fDw8//DDatGkDm82GVq1aYdq0aSGa5EAggJkzZ6Jp06ZwOp3o378/Nm3ahFatWuHaa6+t87w+/PBD9O7dO0Risnz5cvTs2TP45QEAOnTogIEDBx7RuMLn89UqTVm+fDlMJhPxcLLb7bjuuuvwyy+/YO/evaT8eeedh59++gkFBQX1rkNj4lj6w4KCAtxzzz3o2rUrIiMjER0djWHDhmH9+vWkXJUralZWFln/3XffwWAwBGVI/fr1w2effYbdu3cHx6FVtgkAOTk5uO6665CSkgK73Y5TTz015Itxlf7uySefxPPPP4/WrVvD6XRi8ODB2Lt3L5RSePjhh5GWlgaHw4FRo0ZVe+8WLFiAzp07w2azoWnTppg8eTKKioqqvQ5r167FWWedBYfDgfT0dCxatKjaOuk1gTXx+uuvIyMjAw6HA/Hx8RgzZkyIzVVHZmYm/vrrr5C+f9OmTdi0aRMmTZpEvE5uvvlmKKWwfPnyOvfNeeutt6CUwrhx44LrrFYrzjrrrJCyF110EQBg8+bNZP2xjMn/VS+BVQ1N78owcOBADBw4sM5ty8rKcM455+DZZ5/F4MGD8cwzz+DGG2/Eli1bsG/fPgCHX8pefvll9OvXD3PmzMHMmTORm5uLIUOGYN26dQCApKSk4Cfhiy66CMuWLcOyZctw8cUX13jstWvXwuPx4PTTTyfrN2zYAJ/PF6IvsFqt6N69O/
7888+6LwoOdxoRERFwOBzo1KkT3nzzzRrLtm7dGjExMYiKisKVV15Zo397RkYGNm7ciJKSknrVQaieY7HZm
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PiyibL4yzpup"
},
"source": [
"## Misclassifications"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 906
},
"id": "ECh_2RW6zgKB",
"outputId": "afdb0285-dbc2-4793-fb90-489f4f11ed22"
},
"source": [
"show_samples_rgb(incorrect_indices, predicted, X_test, y_test, 5, cifar_names)"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 900x900 with 25 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2cAAAN5CAYAAACMl0OOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOydd3gV1dbG39NLTnoPIQm9I1IFQaqAoGBBrogIiHgFUVFRUVSKigXECuqnXlEU9YqKDQtNROUC0pEiLZRAek9OP/v7gydnZu2TBgQ5kPV7Hh7Onj1lz8zae/Zk1ruWRgghwDAMwzAMwzAMw1xQtBe6AQzDMAzDMAzDMAy/nDEMwzAMwzAMwwQF/HLGMAzDMAzDMAwTBPDLGcMwDMMwDMMwTBDAL2cMwzAMwzAMwzBBAL+cMQzDMAzDMAzDBAH8csYwDMMwDMMwDBME8MsZwzAMwzAMwzBMEMAvZwzDMAzDMAzDMEEAv5xVw7hx42Cz2Wq1rkajwaxZs+q8DUOGDMHEiRPrfL91yVtvvYWUlBQ4nc4L3ZR6zeLFi6HRaJCenn6hmxLApk2bYDQacfTo0QvdlCpxu91o2LAhFi1adKGbwuDC23NpaSni4uLw8ccfX5Dj15ZbbrkFI0eOvNDNYKqA5xG1Y/r06ejWrduFbka95cCBAxg4cCDCw8Oh0WiwfPnyC90kP8ePH4fZbMbvv/9+oZtSLXU5Fl8UL2cVD+mKf2azGc2bN8eUKVOQlZV1oZt33vj999/x888/49FHHyXLfT4fXnzxRTRq1Ahmsxnt27fHJ598Uqt9ytdS/S8zM5Os+8ADD6Bjx46IioqC1WpFq1atMGvWLJSWlpL1xo0bB5fLhbfffvvcTvgSpL7arsyMGTMwatQopKamkuV79+7F4MGDYbPZEBUVhTFjxiAnJ6fG/eXl5WHevHm46qqrEBsbi4iICFxxxRX47LPPKl3/wIEDuOWWW5CcnAyr1YqWLVtizpw5KC8v969jMBjw4IMP4tlnn4XD4Ti3E75EqU/2/OqrryI0NBS33HILWV5YWIi77roLsbGxCAkJQd++fbF169Ya9+fz+bB48WIMGzYMDRs2REhICNq2bYtnnnkmwN6qG6c1Gg15YXz00UfxxRdfYMeOHXVz4pco9cl21ZyPeYTMxIkTodFocO211wbUORwOPPfcc2jdujWsVisaNGiAm2++GX/99RdZb+rUqdixYwe++eabs2pDfeF82fHYsWOxa9cuPPvss1iyZAk6d+5ch60+N+bMmYNu3brhyiuvJMszMjIwcuRIREREICwsDMOHD8fhw4drtc+5c+fiiiuuQGxsLMxmM5o1a4apU6dWOv949tlnMWzYMMTHx1f7B5S6HIv157yHf5A5c+agUaNGcDgc+O233/Dmm29ixYoV2L17N6xW6wVtm91uh15ft5dz3rx56N+/P5o2bUqWz5gxA88//zwmTpyILl264Ouvv8att94KjUYTMJGoioprqSYiIoKUN2/ejF69emH8+PEwm83Ytm0bnn/+eaxatQq//vortNrT7/Zmsxljx47FggULcO+990Kj0Zz9SV+iBLPtnm+2b9+OVatW4Y8//iDLT5w4gauuugrh4eGYO3cuSktLMX/+fOzatcv/pa0qNmzYgBkzZmDIkCF44oknoNfr8cUXX+CWW27Bnj17MHv2bP+6x48fR9euXREeHo4pU6YgKioKGzZswMyZM7FlyxZ8/fXX/nXHjx+P6dOnY+nSpbjjjjvq/mJcIlzq9ux2u/Hqq6/igQcegE6n8y/3+XwYOnQoduzYgYcffhgxMTFYtGgR+vTpgy1btqBZs2ZV7rO8vBzjx4/HFVdcgbvvvhtxcXF+O1y9ejXWrFnjHzuvuuoqLFmyJGAfL7/8Mnbs2IH+/fv7l11++eXo3LkzXnrpJXz44Yd1eBUuTYLZdi+2eQQA/Pnnn1i8eDHMZnOl9aNHj8Y333yDiRMnom
PHjjh58iQWLlyI7t27Y9euXf4/2CUkJGD48OGYP38+hg0bdvYnXE+oSzu22+3+Z+qUKVPOU4vPjpycHHzwwQf44IMPyPLS0lL07dsXRUVFePzxx2EwGPDyyy+jd+/e2L59O6Kjo6vd75YtW9ChQwfccsstCA0Nxd69e/HOO+/g+++/x/bt2xESEuJf94knnkBCQgIuv/xy/PTTT1Xus07HYnER8P777wsAYvPmzWT5gw8+KACIpUuXnpfjjh07VoSEhJyXfddEVlaW0Ov14t133yXLT5w4IQwGg7jnnnv8y3w+n+jVq5dITk4WHo+n2v1WdS1ry/z58wUAsWHDBrL8zz//FADE6tWrz2q/lyr/pO1WHOvIkSN1ts/aUFpaWm39fffdJ1JSUoTP5yPLJ02aJCwWizh69Kh/2cqVKwUA8fbbb1e7z8OHD4v09HSyzOfziX79+gmTyUTa9OyzzwoAYvfu3WT922+/XQAQ+fn5ZPm1114revXqVe3x6yv1wZ6FEOLLL78UAMTBgwfJ8s8++0wAEJ9//rl/WXZ2toiIiBCjRo2qdp9Op1P8/vvvActnz54tAIiVK1dWu315ebkIDQ0VV199dUDd/PnzRUhIiCgpKal2H/UZnkconOs8Qr1N9+7dxR133CFSU1PF0KFDA44DQEybNo0sX7NmjQAgFixYQJYvW7ZMaDQacejQoTM5zXrF+bDjo0ePCgBi3rx5Na5b0/O+rlmwYIGwWCwBY9sLL7wgAIhNmzb5l+3du1fodDrx2GOPndWxli1bJgCITz75hCyveAbl5OQIAGLmzJlV7qOuxuKLwq2xKvr16wcAOHLkiH/ZoUOHcOjQoRq3dbvdmD17Npo1awaz2Yzo6Gj07NkTK1euDFg3IyMD119/PWw2G2JjYzFt2jR4vV6yjvypc9asWdBoNNi3bx9GjhyJsLAwREdH4/7776+Vy9T3338Pj8eDAQMGkOVff/013G43Jk+eTI49adIknDhxAhs2bKhx3xWUlJQEnEdNpKWlATjt2qOmU6dOiIqKIl8hmKo5F9sFgL/++gv9+vWDxWJBcnIynnnmGfh8vkrX/eGHH9CrVy+EhIQgNDQUQ4cODXApAYB9+/ZhxIgRiIqKgtlsRufOnQNcTCpcKtatW4fJkycjLi4OycnJ1bZ1+fLl6NevX8AX1S+++ALXXnstUlJS/MsGDBiA5s2b47///W+1+2zUqFGAi6RGo8H1118Pp9NJXBuKi4sBAPHx8WT9xMREaLXagC90V199NX777Tfk5+dX2wZG4Z+050WLFqFNmzYwmUxISkrCPffcEzAeAcDChQvRuHFjWCwWdO3aFevXr0efPn3Qp0+fGtuzfPlypKWloUmTJmT5smXLEB8fjxtvvNG/LDY2FiNHjsTXX39dre7WaDSiR48eActvuOEGAKddfKvj22+/RUlJCUaPHh1Qd/XVV6OsrKzS5xdTPTyPUI59pvOIJUuWYPfu3Xj22WcrrS8pKQFQ+dgLABaLhSyvaCfPI86cs7XjWbNm+Z+lDz/8MDQajX+eV2F/e/bswa233orIyEj07NkTAODxePD000+jSZMmMJlMSEtLw+OPPx4wBvp8PsyaNQtJSUmwWq3o27cv9uzZg7S0NIwbN67G81q+fDm6desWoNtctmwZunTpgi5duviXtWzZEv37969x/lAVVc1vK5bXhroaiy/ql7MKo1N/vuzfvz9x+aiKWbNmYfbs2ejbty/eeOMNzJgxAykpKQHaAa/Xi0GDBiE6Ohrz589H79698dJLL+H//u//atXGkSNH+n2uhwwZgtdeew133XVXjdv98ccfiI6ODpiAbtu2DSEhIWjVqhVZ3rVrV399bejbty/CwsJgtVoxbNgwHDhwoNL1PB4PcnNzcfLkSfz888944oknEBoa6j+emo4dOwa9YDNYOBfbzczMRN++fbF9+3ZMnz4dU6dOxYcffohXX301YN0lS5Zg6NChsNlseOGFF/
Dkk09iz5496NmzJwm08Ndff+GKK67A3r17MX36dLz00ksICQnB9ddfj6+++ipgv5MnT8aePXvw1FNPYfr06
},
"metadata": {}
}
]
}
]
}