{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "KN4PS8Phd7Ea" }, "source": [ "# \\[4/22/2024 Updated\\]\n", "## 578hw3_Check2_CoLab.ipynb (Spring 2024)\n", "\n", "### **NOTE**: This is a version for Google CoLab.\n", "\n", "### Run this file (all cells) **AFTER** you have made all required modifications in the Network class (in \"NN578_network.ipynb\").\n", "\n", "### You should get the same output as what's shown in this file." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "wn3pTCf-eSDZ", "outputId": "f23c6c11-f19e-4a91-de1a-9d0a015a299d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mounted at /content/drive\n" ] } ], "source": [ "## nt: Code piece to mount my Google Drive\n", "from google.colab import drive\n", "drive.mount(\"/content/drive\") # my Google Drive root directory will be mapped here" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "9RuRK-BmeSx2", "outputId": "f5642db6-6a69-40ac-aeae-4dff596f2424" }, "outputs": [], "source": [ "# nt: Change the working directory to the work directory (where the code file is).\n", "import os\n", "thisdir = '/content/drive/My Drive/CSC578_Spring2024/HW#3'\n", "os.chdir(thisdir)\n", "\n", "# Ensure the files are there (in the folder)\n", "!pwd" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "LIJYLmOed7Eg", "outputId": "9cd8a3d7-d0ad-41b4-c880-ee2ad757ca35" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting import-ipynb\n", " Downloading import_ipynb-0.1.4-py3-none-any.whl (4.1 kB)\n", "Requirement already satisfied: IPython in /usr/local/lib/python3.10/dist-packages (from import-ipynb) (7.34.0)\n", "Requirement already satisfied: nbformat in /usr/local/lib/python3.10/dist-packages (from import-ipynb) (5.10.4)\n", "Requirement 
already satisfied: setuptools>=18.5 in /usr/local/lib/python3.10/dist-packages (from IPython->import-ipynb) (67.7.2)\n", "Collecting jedi>=0.16 (from IPython->import-ipynb)\n", " Downloading jedi-0.19.1-py2.py3-none-any.whl (1.6 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.10/dist-packages (from IPython->import-ipynb) (4.4.2)\n", "Requirement already satisfied: pickleshare in /usr/local/lib/python3.10/dist-packages (from IPython->import-ipynb) (0.7.5)\n", "Requirement already satisfied: traitlets>=4.2 in /usr/local/lib/python3.10/dist-packages (from IPython->import-ipynb) (5.7.1)\n", "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from IPython->import-ipynb) (3.0.43)\n", "Requirement already satisfied: pygments in /usr/local/lib/python3.10/dist-packages (from IPython->import-ipynb) (2.16.1)\n", "Requirement already satisfied: backcall in /usr/local/lib/python3.10/dist-packages (from IPython->import-ipynb) (0.2.0)\n", "Requirement already satisfied: matplotlib-inline in /usr/local/lib/python3.10/dist-packages (from IPython->import-ipynb) (0.1.7)\n", "Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.10/dist-packages (from IPython->import-ipynb) (4.9.0)\n", "Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.10/dist-packages (from nbformat->import-ipynb) (2.19.1)\n", "Requirement already satisfied: jsonschema>=2.6 in /usr/local/lib/python3.10/dist-packages (from nbformat->import-ipynb) (4.19.2)\n", "Requirement already satisfied: jupyter-core!=5.0.*,>=4.12 in /usr/local/lib/python3.10/dist-packages (from nbformat->import-ipynb) (5.7.2)\n", "Requirement already satisfied: parso<0.9.0,>=0.8.3 in /usr/local/lib/python3.10/dist-packages (from 
jedi>=0.16->IPython->import-ipynb) (0.8.4)\n", "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=2.6->nbformat->import-ipynb) (23.2.0)\n", "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=2.6->nbformat->import-ipynb) (2023.12.1)\n", "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=2.6->nbformat->import-ipynb) (0.34.0)\n", "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=2.6->nbformat->import-ipynb) (0.18.0)\n", "Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.10/dist-packages (from jupyter-core!=5.0.*,>=4.12->nbformat->import-ipynb) (4.2.0)\n", "Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.10/dist-packages (from pexpect>4.3->IPython->import-ipynb) (0.7.0)\n", "Requirement already satisfied: wcwidth in /usr/local/lib/python3.10/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->IPython->import-ipynb) (0.2.13)\n", "Installing collected packages: jedi, import-ipynb\n", "Successfully installed import-ipynb-0.1.4 jedi-0.19.1\n" ] } ], "source": [ "# First install this library so that we can import code from other Notebooks\n", "!pip install import-ipynb\n", "import import_ipynb" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "cVG630W4d7Ei", "outputId": "de4161c7-585c-413a-e031-7beb16f7eec8" }, "outputs": [], "source": [ "# import the class Network from \"NN578_network.ipynb\"\n", "#import NN578_network as network_nb\n", "import NN578_network as network_nb\n", "import numpy as np\n", "\n", "# Load the dataset\n", "iris_data = network_nb.my_load_csv('iris-3.csv', 4, 3)\n", "iris_train = iris_data[:105]\n", "iris_test = iris_data[105:]" ] }, { "cell_type": "code", 
"execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "l_Im7Dscd7Ej", "outputId": "5ebc0949-853d-4887-e709-e45000bcccc6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[Epoch 0] Train: Count= 50, Accuracy=0.3333, MSE=0.3335, CE=1.9120, LL=1.1041\n", "[Epoch 1] Train: Count=100, Accuracy=0.6667, MSE=0.2436, CE=1.5128, LL=0.7782\n", "[Epoch 2] Train: Count=100, Accuracy=0.6667, MSE=0.2098, CE=1.3345, LL=0.6645\n", "\n", "------ Returned Results --------------\n", "** train: [{'Count': 50, 'Accuracy': 0.3333333333333333, 'MSE': 0.3335158842115605, 'CE': 1.9119831603934248, 'LL': 1.1041273241055265}, {'Count': 100, 'Accuracy': 0.6666666666666666, 'MSE': 0.24357648203481708, 'CE': 1.5127940376356965, 'LL': 0.778221019446114}, {'Count': 100, 'Accuracy': 0.6666666666666666, 'MSE': 0.20978951763761794, 'CE': 1.334470423907723, 'LL': 0.6645348071506488}]\n", "** valid: []\n", "\n", "*************************************************\n", "Initial activations shape: [(4, 1), (2, 1), (3, 1)]\n", "*************************************************\n" ] } ], "source": [ "# Create a network from the saved network\n", "net1 = network_nb.Network.load_network(\"iris-423.dat\")\n", "\n", "# Train the network for 3 epochs, with minibatch size 8, eta=2.5 and no testset.\n", "trn_results, val_results = net1.SGD(iris_data, 3, 8, 2.5)\n", "print ('\\n------ Returned Results --------------')\n", "print ('** train: {}'.format(trn_results))\n", "print ('** valid: {}'.format(val_results))\n", "print ('\\n*************************************************')\n", "print ('Initial activations shape: {}'.format(net1.init_acts_shape))\n", "print ('*************************************************')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "jHHfH5x-d7Ek", "outputId": "aebe0eb2-fa67-4569-f142-340c6b172bfd" }, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "[Epoch 0] Train: Count= 36, Accuracy=0.3429, MSE=0.3370, CE=1.9248, LL=1.0320\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3407, CE=1.9408, LL=1.0413\n", "[Epoch 1] Train: Count= 36, Accuracy=0.3429, MSE=0.3347, CE=1.9158, LL=1.0944\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3374, CE=1.9273, LL=1.1012\n", "[Epoch 2] Train: Count= 36, Accuracy=0.3429, MSE=0.3331, CE=1.9081, LL=1.1037\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3354, CE=1.9184, LL=1.1098\n", "[Epoch 3] Train: Count= 36, Accuracy=0.3429, MSE=0.3214, CE=1.8516, LL=1.0707\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3233, CE=1.8595, LL=1.0754\n", "[Epoch 4] Train: Count= 71, Accuracy=0.6762, MSE=0.2837, CE=1.6817, LL=0.9288\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.2848, CE=1.6861, LL=0.9305\n", "[Epoch 5] Train: Count= 71, Accuracy=0.6762, MSE=0.2572, CE=1.5678, LL=0.8249\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.2578, CE=1.5703, LL=0.8251\n", "[Epoch 6] Train: Count= 71, Accuracy=0.6762, MSE=0.2383, CE=1.4828, LL=0.7573\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.2390, CE=1.4852, LL=0.7573\n", "[Epoch 7] Train: Count= 71, Accuracy=0.6762, MSE=0.2242, CE=1.4152, LL=0.7093\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.2251, CE=1.4186, LL=0.7095\n", "[Epoch 8] Train: Count= 71, Accuracy=0.6762, MSE=0.2135, CE=1.3606, LL=0.6728\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.2148, CE=1.3653, LL=0.6734\n", "[Epoch 9] Train: Count= 71, Accuracy=0.6762, MSE=0.2051, CE=1.3143, LL=0.6431\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.2068, CE=1.3209, LL=0.6443\n", "[Epoch 10] Train: Count= 71, Accuracy=0.6762, MSE=0.1977, CE=1.2704, LL=0.6160\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.2000, CE=1.2793, LL=0.6180\n", "[Epoch 11] Train: Count= 71, Accuracy=0.6762, MSE=0.1914, CE=1.2284, LL=0.5892\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1940, CE=1.2388, LL=0.5917\n", "[Epoch 12] Train: Count= 71, Accuracy=0.6762, MSE=0.1871, CE=1.1967, LL=0.5678\n", " Valid: 
Count= 29, Accuracy=0.6444, MSE=0.1899, CE=1.2075, LL=0.5703\n", "[Epoch 13] Train: Count= 71, Accuracy=0.6762, MSE=0.1841, CE=1.1730, LL=0.5528\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1870, CE=1.1840, LL=0.5552\n", "[Epoch 14] Train: Count= 71, Accuracy=0.6762, MSE=0.1817, CE=1.1540, LL=0.5421\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1846, CE=1.1650, LL=0.5442\n", "[Epoch 15] Train: Count= 71, Accuracy=0.6762, MSE=0.1798, CE=1.1381, LL=0.5338\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1828, CE=1.1491, LL=0.5356\n", "[Epoch 16] Train: Count= 71, Accuracy=0.6762, MSE=0.1783, CE=1.1246, LL=0.5271\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1812, CE=1.1355, LL=0.5285\n", "[Epoch 17] Train: Count= 71, Accuracy=0.6762, MSE=0.1769, CE=1.1130, LL=0.5214\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1798, CE=1.1237, LL=0.5223\n", "[Epoch 18] Train: Count= 71, Accuracy=0.6762, MSE=0.1757, CE=1.1029, LL=0.5165\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1786, CE=1.1133, LL=0.5169\n", "[Epoch 19] Train: Count= 71, Accuracy=0.6762, MSE=0.1746, CE=1.0940, LL=0.5123\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1774, CE=1.1040, LL=0.5120\n", "[Epoch 20] Train: Count= 71, Accuracy=0.6762, MSE=0.1736, CE=1.0860, LL=0.5086\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1763, CE=1.0955, LL=0.5076\n", "[Epoch 21] Train: Count= 71, Accuracy=0.6762, MSE=0.1726, CE=1.0788, LL=0.5056\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1751, CE=1.0875, LL=0.5038\n", "[Epoch 22] Train: Count= 71, Accuracy=0.6762, MSE=0.1715, CE=1.0724, LL=0.5035\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1738, CE=1.0799, LL=0.5009\n", "[Epoch 23] Train: Count= 71, Accuracy=0.6762, MSE=0.1704, CE=1.0668, LL=0.5029\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1723, CE=1.0726, LL=0.4996\n", "[Epoch 24] Train: Count= 71, Accuracy=0.6762, MSE=0.1694, CE=1.0625, LL=0.5040\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1707, CE=1.0662, LL=0.5004\n", "[Epoch 25] Train: Count= 69, 
Accuracy=0.6571, MSE=0.1685, CE=1.0596, LL=0.5060\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1692, CE=1.0611, LL=0.5028\n", "[Epoch 26] Train: Count= 69, Accuracy=0.6571, MSE=0.1674, CE=1.0553, LL=0.5047\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1676, CE=1.0557, LL=0.5027\n", "[Epoch 27] Train: Count= 69, Accuracy=0.6571, MSE=0.1645, CE=1.0431, LL=0.4932\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1648, CE=1.0440, LL=0.4921\n", "[Epoch 28] Train: Count= 71, Accuracy=0.6762, MSE=0.1598, CE=1.0223, LL=0.4686\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1605, CE=1.0245, LL=0.4671\n", "[Epoch 29] Train: Count= 95, Accuracy=0.9048, MSE=0.1563, CE=1.0049, LL=0.4394\n", " Valid: Count= 38, Accuracy=0.8444, MSE=0.1576, CE=1.0087, LL=0.4365\n", "[Epoch 30] Train: Count=103, Accuracy=0.9810, MSE=0.1557, CE=0.9986, LL=0.4163\n", " Valid: Count= 45, Accuracy=1.0000, MSE=0.1573, CE=1.0034, LL=0.4123\n", "[Epoch 31] Train: Count=100, Accuracy=0.9524, MSE=0.1557, CE=0.9961, LL=0.4026\n", " Valid: Count= 43, Accuracy=0.9556, MSE=0.1573, CE=1.0009, LL=0.3977\n", "[Epoch 32] Train: Count= 97, Accuracy=0.9238, MSE=0.1537, CE=0.9872, LL=0.3957\n", " Valid: Count= 43, Accuracy=0.9556, MSE=0.1551, CE=0.9911, LL=0.3905\n", "[Epoch 33] Train: Count= 99, Accuracy=0.9429, MSE=0.1487, CE=0.9675, LL=0.3924\n", " Valid: Count= 43, Accuracy=0.9556, MSE=0.1497, CE=0.9701, LL=0.3877\n", "[Epoch 34] Train: Count=100, Accuracy=0.9524, MSE=0.1414, CE=0.9388, LL=0.3905\n", " Valid: Count= 43, Accuracy=0.9556, MSE=0.1420, CE=0.9403, LL=0.3872\n", "[Epoch 35] Train: Count=102, Accuracy=0.9714, MSE=0.1333, CE=0.9074, LL=0.3887\n", " Valid: Count= 44, Accuracy=0.9778, MSE=0.1339, CE=0.9093, LL=0.3882\n", "[Epoch 36] Train: Count=103, Accuracy=0.9810, MSE=0.1264, CE=0.8800, LL=0.3873\n", " Valid: Count= 45, Accuracy=1.0000, MSE=0.1277, CE=0.8852, LL=0.3907\n", "[Epoch 37] Train: Count=102, Accuracy=0.9714, MSE=0.1215, CE=0.8600, LL=0.3868\n", " Valid: Count= 43, Accuracy=0.9556, MSE=0.1243, 
CE=0.8714, LL=0.3953\n", "[Epoch 38] Train: Count=100, Accuracy=0.9524, MSE=0.1183, CE=0.8464, LL=0.3872\n", " Valid: Count= 41, Accuracy=0.9111, MSE=0.1230, CE=0.8656, LL=0.4012\n", "[Epoch 39] Train: Count= 98, Accuracy=0.9333, MSE=0.1162, CE=0.8364, LL=0.3878\n", " Valid: Count= 41, Accuracy=0.9111, MSE=0.1228, CE=0.8636, LL=0.4070\n", "[Epoch 40] Train: Count= 97, Accuracy=0.9238, MSE=0.1146, CE=0.8283, LL=0.3883\n", " Valid: Count= 41, Accuracy=0.9111, MSE=0.1230, CE=0.8631, LL=0.4124\n", "[Epoch 41] Train: Count= 98, Accuracy=0.9333, MSE=0.1134, CE=0.8217, LL=0.3891\n", " Valid: Count= 40, Accuracy=0.8889, MSE=0.1235, CE=0.8639, LL=0.4180\n", "[Epoch 42] Train: Count= 98, Accuracy=0.9333, MSE=0.1124, CE=0.8155, LL=0.3895\n", " Valid: Count= 39, Accuracy=0.8667, MSE=0.1240, CE=0.8646, LL=0.4228\n", "[Epoch 43] Train: Count= 97, Accuracy=0.9238, MSE=0.1109, CE=0.8072, LL=0.3876\n", " Valid: Count= 39, Accuracy=0.8667, MSE=0.1237, CE=0.8619, LL=0.4244\n", "[Epoch 44] Train: Count= 97, Accuracy=0.9238, MSE=0.1087, CE=0.7956, LL=0.3827\n", " Valid: Count= 39, Accuracy=0.8667, MSE=0.1224, CE=0.8544, LL=0.4220\n", "[Epoch 45] Train: Count= 97, Accuracy=0.9238, MSE=0.1058, CE=0.7805, LL=0.3749\n", " Valid: Count= 39, Accuracy=0.8667, MSE=0.1200, CE=0.8420, LL=0.4156\n", "[Epoch 46] Train: Count= 98, Accuracy=0.9333, MSE=0.1021, CE=0.7622, LL=0.3645\n", " Valid: Count= 39, Accuracy=0.8667, MSE=0.1166, CE=0.8248, LL=0.4057\n", "[Epoch 47] Train: Count= 98, Accuracy=0.9333, MSE=0.0978, CE=0.7409, LL=0.3518\n", " Valid: Count= 40, Accuracy=0.8889, MSE=0.1122, CE=0.8033, LL=0.3926\n", "[Epoch 48] Train: Count= 97, Accuracy=0.9238, MSE=0.0930, CE=0.7176, LL=0.3377\n", " Valid: Count= 41, Accuracy=0.9111, MSE=0.1070, CE=0.7782, LL=0.3771\n", "[Epoch 49] Train: Count= 98, Accuracy=0.9333, MSE=0.0882, CE=0.6938, LL=0.3231\n", " Valid: Count= 41, Accuracy=0.9111, MSE=0.1015, CE=0.7514, LL=0.3603\n", "\n", "------ Returned Results --------------\n", "** train: {'Count': 98, 
'Accuracy': 0.9333333333333333, 'MSE': 0.08817276167025205, 'CE': 0.6937547467420072, 'LL': 0.32310317781762193}\n", "** valid: {'Count': 41, 'Accuracy': 0.9111111111111111, 'MSE': 0.10151042106871096, 'CE': 0.7514016925963276, 'LL': 0.36031495908338235}\n", "\n", "*************************************************\n", "Initial activations shape: [(4, 1), (2, 1), (3, 1)]\n", "*************************************************\n" ] } ], "source": [ "# Re-load the saved network and run it again for a maximum of 50 epochs, this\n", "# time with a test set. With eta=1.0 the early-exit (patience 3) condition is\n", "# never met, so the run completes all 50 epochs (Epoch 0 through Epoch 49).\n", "net2 = network_nb.Network.load_network(\"iris-423.dat\")\n", "trn_results, val_results = net2.SGD(iris_train, 50, 8, 1.0, iris_test)\n", "print ('\\n------ Returned Results --------------')\n", "print ('** train: {}'.format(trn_results[-1]))\n", "print ('** valid: {}'.format(val_results[-1]))\n", "print ('\\n*************************************************')\n", "print ('Initial activations shape: {}'.format(net2.init_acts_shape))\n", "print ('*************************************************')" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "hvr5GhkEwWmj", "outputId": "6dad3d45-0c36-465f-bcdf-38778c34beb3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[Epoch 0] Train: Count= 36, Accuracy=0.3429, MSE=0.3374, CE=1.9283, LL=1.1200\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3409, CE=1.9433, LL=1.1290\n", "[Epoch 1] Train: Count= 36, Accuracy=0.3429, MSE=0.3300, CE=1.8918, LL=1.1142\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3330, CE=1.9047, LL=1.1220\n", "[Epoch 2] Train: Count= 71, Accuracy=0.6762, MSE=0.2676, CE=1.6061, LL=0.8720\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.2695, CE=1.6139, LL=0.8753\n", "[Epoch 3] Train: Count= 71, Accuracy=0.6762, MSE=0.2305, CE=1.4393, LL=0.7326\n", " 
Valid: Count= 29, Accuracy=0.6444, MSE=0.2326, CE=1.4477, LL=0.7357\n", "[Epoch 4] Train: Count= 71, Accuracy=0.6762, MSE=0.2078, CE=1.3248, LL=0.6528\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.2109, CE=1.3372, LL=0.6570\n", "[Epoch 5] Train: Count= 71, Accuracy=0.6762, MSE=0.1920, CE=1.2281, LL=0.5890\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1958, CE=1.2429, LL=0.5939\n", "[Epoch 6] Train: Count= 71, Accuracy=0.6762, MSE=0.1848, CE=1.1750, LL=0.5544\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1887, CE=1.1901, LL=0.5589\n", "[Epoch 7] Train: Count= 71, Accuracy=0.6762, MSE=0.1805, CE=1.1402, LL=0.5369\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1843, CE=1.1550, LL=0.5405\n", "[Epoch 8] Train: Count= 71, Accuracy=0.6762, MSE=0.1775, CE=1.1155, LL=0.5267\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1813, CE=1.1295, LL=0.5291\n", "[Epoch 9] Train: Count= 71, Accuracy=0.6762, MSE=0.1754, CE=1.0979, LL=0.5226\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1787, CE=1.1102, LL=0.5234\n", "[Epoch 10] Train: Count= 71, Accuracy=0.6762, MSE=0.1744, CE=1.0895, LL=0.5285\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1767, CE=1.0976, LL=0.5277\n", "[Epoch 11] Train: Count= 69, Accuracy=0.6571, MSE=0.1783, CE=1.1050, LL=0.5539\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1787, CE=1.1069, LL=0.5545\n", "\n", "------ Returned Results --------------\n", "** train: {'Count': 69, 'Accuracy': 0.6571428571428571, 'MSE': 0.17831844953424905, 'CE': 1.104950985615117, 'LL': 0.553866411203669}\n", "** valid: {'Count': 29, 'Accuracy': 0.6444444444444445, 'MSE': 0.1787218094339433, 'CE': 1.106865617722989, 'LL': 0.5545264874362218}\n", "\n", "*************************************************\n", "Initial activations shape: [(4, 1), (2, 1), (3, 1)]\n", "*************************************************\n" ] } ], "source": [ "## 4/22/2024 ADDITION\n", "# This code tests the early exit based on the 'MSE loss of the latest epoch was\n", "# larger than or equal to the maximum 
of the MSE losses of the THREE preceding\n", "# epochs (if exist)' (i.e., patience 3) condition.\n", "net2 = network_nb.Network.load_network(\"iris-423.dat\")\n", "trn_results, val_results = net2.SGD(iris_train, 50, 8, 2.0, iris_test)\n", "print ('\\n------ Returned Results --------------')\n", "print ('** train: {}'.format(trn_results[-1]))\n", "print ('** valid: {}'.format(val_results[-1]))\n", "print ('\\n*************************************************')\n", "print ('Initial activations shape: {}'.format(net2.init_acts_shape))\n", "print ('*************************************************')" ] }, { "cell_type": "markdown", "metadata": { "id": "YZ6Rivbzd7Ek" }, "source": [ "## Further check with a 3-layer network (\"iris4-20-7-3.dat\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "NJBGcxVzd7El", "outputId": "86660a65-c9de-4b36-cb95-b0cf1de0acc3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[Epoch 0] Train: Count= 36, Accuracy=0.3429, MSE=0.3350, CE=1.9171, LL=1.0974\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3368, CE=1.9249, LL=1.1024\n", "[Epoch 1] Train: Count= 36, Accuracy=0.3429, MSE=0.3354, CE=1.9191, LL=1.1178\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3376, CE=1.9289, LL=1.1241\n", "[Epoch 2] Train: Count= 36, Accuracy=0.3429, MSE=0.3353, CE=1.9185, LL=1.1189\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3376, CE=1.9286, LL=1.1254\n", "[Epoch 3] Train: Count= 36, Accuracy=0.3429, MSE=0.3350, CE=1.9172, LL=1.1172\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3373, CE=1.9273, LL=1.1237\n", "[Epoch 4] Train: Count= 36, Accuracy=0.3429, MSE=0.3347, CE=1.9157, LL=1.1149\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3370, CE=1.9256, LL=1.1213\n", "[Epoch 5] Train: Count= 36, Accuracy=0.3429, MSE=0.3343, CE=1.9138, LL=1.1122\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3365, CE=1.9236, LL=1.1185\n", "[Epoch 6] Train: Count= 36, 
Accuracy=0.3429, MSE=0.3337, CE=1.9112, LL=1.1089\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3360, CE=1.9210, LL=1.1151\n", "[Epoch 7] Train: Count= 36, Accuracy=0.3429, MSE=0.3330, CE=1.9077, LL=1.1045\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3352, CE=1.9173, LL=1.1106\n", "[Epoch 8] Train: Count= 36, Accuracy=0.3429, MSE=0.3319, CE=1.9025, LL=1.0985\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3340, CE=1.9119, LL=1.1045\n", "[Epoch 9] Train: Count= 36, Accuracy=0.3429, MSE=0.3303, CE=1.8949, LL=1.0900\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3324, CE=1.9039, LL=1.0958\n", "[Epoch 10] Train: Count= 36, Accuracy=0.3429, MSE=0.3279, CE=1.8837, LL=1.0782\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3299, CE=1.8924, LL=1.0837\n", "[Epoch 11] Train: Count= 36, Accuracy=0.3429, MSE=0.3245, CE=1.8676, LL=1.0619\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3264, CE=1.8758, LL=1.0670\n", "[Epoch 12] Train: Count= 36, Accuracy=0.3429, MSE=0.3194, CE=1.8440, LL=1.0390\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3212, CE=1.8516, LL=1.0436\n", "[Epoch 13] Train: Count= 36, Accuracy=0.3429, MSE=0.3119, CE=1.8093, LL=1.0064\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3135, CE=1.8160, LL=1.0104\n", "[Epoch 14] Train: Count= 36, Accuracy=0.3429, MSE=0.3009, CE=1.7584, LL=0.9606\n", " Valid: Count= 14, Accuracy=0.3111, MSE=0.3022, CE=1.7639, LL=0.9637\n", "[Epoch 15] Train: Count= 70, Accuracy=0.6667, MSE=0.2851, CE=1.6873, LL=0.9000\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.2861, CE=1.6912, LL=0.9019\n", "[Epoch 16] Train: Count= 71, Accuracy=0.6762, MSE=0.2650, CE=1.5973, LL=0.8278\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.2657, CE=1.5996, LL=0.8285\n", "[Epoch 17] Train: Count= 71, Accuracy=0.6762, MSE=0.2428, CE=1.4976, LL=0.7540\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.2433, CE=1.4987, LL=0.7536\n", "[Epoch 18] Train: Count= 71, Accuracy=0.6762, MSE=0.2232, CE=1.4045, LL=0.6919\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.2236, 
CE=1.4050, LL=0.6907\n", "[Epoch 19] Train: Count= 71, Accuracy=0.6762, MSE=0.2083, CE=1.3279, LL=0.6455\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.2088, CE=1.3284, LL=0.6437\n", "[Epoch 20] Train: Count= 71, Accuracy=0.6762, MSE=0.1979, CE=1.2686, LL=0.6125\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1985, CE=1.2693, LL=0.6102\n", "[Epoch 21] Train: Count= 71, Accuracy=0.6762, MSE=0.1906, CE=1.2230, LL=0.5894\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1913, CE=1.2239, LL=0.5866\n", "[Epoch 22] Train: Count= 71, Accuracy=0.6762, MSE=0.1853, CE=1.1876, LL=0.5733\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1862, CE=1.1887, LL=0.5701\n", "[Epoch 23] Train: Count= 71, Accuracy=0.6762, MSE=0.1814, CE=1.1595, LL=0.5625\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1823, CE=1.1609, LL=0.5589\n", "[Epoch 24] Train: Count= 71, Accuracy=0.6762, MSE=0.1783, CE=1.1369, LL=0.5557\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1793, CE=1.1384, LL=0.5520\n", "[Epoch 25] Train: Count= 71, Accuracy=0.6762, MSE=0.1758, CE=1.1183, LL=0.5525\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1768, CE=1.1199, LL=0.5487\n", "[Epoch 26] Train: Count= 71, Accuracy=0.6762, MSE=0.1737, CE=1.1031, LL=0.5523\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1748, CE=1.1049, LL=0.5489\n", "[Epoch 27] Train: Count= 71, Accuracy=0.6762, MSE=0.1721, CE=1.0908, LL=0.5550\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1732, CE=1.0931, LL=0.5523\n", "[Epoch 28] Train: Count= 71, Accuracy=0.6762, MSE=0.1708, CE=1.0812, LL=0.5598\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1722, CE=1.0843, LL=0.5584\n", "[Epoch 29] Train: Count= 71, Accuracy=0.6762, MSE=0.1701, CE=1.0742, LL=0.5659\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1717, CE=1.0787, LL=0.5662\n", "[Epoch 30] Train: Count= 71, Accuracy=0.6762, MSE=0.1699, CE=1.0696, LL=0.5727\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1719, CE=1.0758, LL=0.5750\n", "[Epoch 31] Train: Count= 71, Accuracy=0.6762, MSE=0.1700, CE=1.0671, 
LL=0.5795\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1726, CE=1.0755, LL=0.5841\n", "[Epoch 32] Train: Count= 71, Accuracy=0.6762, MSE=0.1704, CE=1.0658, LL=0.5850\n", " Valid: Count= 29, Accuracy=0.6444, MSE=0.1735, CE=1.0768, LL=0.5921\n", "\n", "------ Returned Results --------------\n", "** train: {'Count': 71, 'Accuracy': 0.6761904761904762, 'MSE': 0.17042619766053743, 'CE': 1.065787120205457, 'LL': 0.585027153325647}\n", "** valid: {'Count': 29, 'Accuracy': 0.6444444444444445, 'MSE': 0.17351439367370464, 'CE': 1.0768200510121895, 'LL': 0.5921240809600145}\n", "\n", "*************************************************\n", "Initial activations shape: [(4, 1), (20, 1), (7, 1), (3, 1)]\n", "*************************************************\n" ] } ], "source": [ "# Load a deeper network (from a saved file). This time the execution does\n", "# terminate earlier than 50 epochs (after Epoch 32) because of the 'MSE loss of\n", "# the latest epoch >= max of the three preceding MSEs' (patience 3) condition.\n", "net3 = network_nb.Network.load_network(\"iris4-20-7-3.dat\")\n", "trn_results, val_results = net3.SGD(iris_train, 50, 8, 1.0, iris_test)\n", "print ('\\n------ Returned Results --------------')\n", "print ('** train: {}'.format(trn_results[-1]))\n", "print ('** valid: {}'.format(val_results[-1]))\n", "print ('\\n*************************************************')\n", "print ('Initial activations shape: {}'.format(net3.init_acts_shape))\n", "print ('*************************************************')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bkVIRuEtSEVx" }, "outputs": [], "source": [] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" } }, "nbformat": 4, 
"nbformat_minor": 1 }