{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "fb1065df-3051-47df-97d4-b8ce2d9b34b7",
      "metadata": {
        "id": "fb1065df-3051-47df-97d4-b8ce2d9b34b7"
      },
      "source": [
        "<<div style=\"color:blue\">\n",
        "    \n",
        "# **General Instructions for the ML Coding Problems**\n",
        "\n",
        "Please follow these instructions carefully to ensure a smooth evaluation process.\n",
        "\n",
        "## **1. Notebook Submission**\n",
        "- You **must** make a copy of this notebook and append your **full name** to the filename before submitting (e.g., `[OriginalNotebookName]_[YourName].ipynb`).\n",
        "- Share  your notebook copy with inaio@acmindia.org [This is for your own safety so that you do not accidentally lose any changes while editing the notebook]\n",
        "- After solving the questions, ensure you mention the correct URL of your  modified notebook in the test form\n",
        "- Also answer questions on external resources used and link to LLM chats used for each problem in the main test form\n",
        "\n",
        "## **2. Attempting the Questions**\n",
        "- Carefully **read each problem statement** before attempting.\n",
        "- **Attempt all parts** of each question.\n",
        "- Each question is organized into the following parts\n",
        "   - **DATA**, **TASK**, **HELPER CODE [Optional]** and **ANSWER**\n",
        "- **Follow the function signatures** provided. Do not modify them.\n",
        "- You only need to edit the cells in the **ANSWER** sections\n",
        "- If required, you may also add other modules under **IMPORTS** and **INSTALLATION INSTRUCTIONS**\n",
        "- Do not edit the other cells, especially those marked with **DO NOT MODIFY** which are meant for evaluation\n",
        "- You may add new cells to the notebook with extra code as desired\n",
        "  \n",
        "\n",
        "## **3. Scoring Criteria**\n",
        "Your score will be based on the following factors with distribution varying across each problem.\n",
        "- **Soundness & Creativity** of your approach.  \n",
        "  - Include a clear description and rationale of your solution methodology in the notebook (in markdown cells)\n",
        "  - Solutions that showcase your understanding of data and ML will garner more points\n",
        "- **Code Implementation & Readability**\n",
        "  - Ensure your implementation is correct and works\n",
        "  - Incomplete non-working code will be awarded  partial marks based on problem-wise rubric\n",
        "  - In case you have a solution but are unsure about some aspect, you can define a function that solves that aspect and present the rest of the solution\n",
        "  - Use comments to explain important parts of your code.\n",
        "- **Performance of Your Model**:\n",
        "  - Each task will be assessed based on specified performance metrics both on shared datasets and secret datasets\n",
        "  - Different performance ranges will receive different scores.\n",
        "  - Secret datasets used for last section will be shared along with the final results\n",
        "\n",
        "**Points associated with cells are marked at the beginning of the cell**\n",
        "    \n",
        "## **4. Dataset Usage**\n",
        "- **Only use the datasets provided** in this test.\n",
        "- Do **not** use the provided test data set for training.\n",
        "- Do **not** use external datasets for training or testing.\n",
        "- If the submitted performance metrics cannot be reproduced with your code and original datasets, then you will lose all the points associated with model performance.\n",
        "\n",
        "\n",
        "</div>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "65cb0461-0656-4d25-a67b-3385821ac8d9",
      "metadata": {
        "id": "65cb0461-0656-4d25-a67b-3385821ac8d9"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "markdown",
      "id": "016204c1-8ed0-4105-8a3d-ea3f7757b05b",
      "metadata": {
        "id": "016204c1-8ed0-4105-8a3d-ea3f7757b05b"
      },
      "source": [
        "<div style=\"color:blue\">\n",
        "\n",
        "## Problem 6:  Drone Patrol: From Vivid Landscapes to Monochrome Missions [14 pts]\n",
        "You are an AI military expert assigned to modernize army surveillance by setting up an unmanned AI drone patrol system. The drones are equipped with cameras and microphones to detect unusual activity. During patrol, whenever an unusual sound is detected, the drone captures an image of the source location. The challenge is to classify the source location to determine if it is on the **coast** or **desert** and inform the appropriate responding authority.\n",
        "\n",
        "**Training Data:**\n",
        "- We have a large corpus of previous images that can be used for training\n",
        "- During the day, the drone captures **color images**, and these images are all labeled by human experts.\n",
        "- At night, the drone captures only **grayscale images**, but these remain **unlabeled** since it takes more effort.\n",
        "\n",
        "\n",
        "During  patrol, the new images to be classified can be  **colored or grayscale** based on the whether it is daytime or night.\n",
        "\n",
        "Your task is to develop an ML model that can learn from the training data (labeled color images and  the unlabeled grayscale images) to perform accurate classification during  patrol.\n",
        "\n",
        "This problem has 3 questions (2 to be attempted, 3rd one private INAIO evaluation)\n",
        "-  **Q1: Build an Image LocationType Classifier** [9 pts]\n",
        "-  **Q2: Test Your Classifier - Public Set** [2 pts]\n",
        "-  **Q3: Test Your Classifier - Secret Set** [3 pts] [NOT FOR STUDENTS TO ATTEMPT]\n",
        "\n",
        "\n",
        "\n",
        "</div>"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "da1df294-8c85-4cf4-be92-5a70a2060489",
      "metadata": {
        "id": "da1df294-8c85-4cf4-be92-5a70a2060489"
      },
      "source": [
        "<div style=\"color:blue\">\n",
        "    \n",
        "### INSTALLATION  \n",
        "\n",
        "</div>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b9529e3c-ac68-4ae1-b9e4-f5afa5bbcc3e",
      "metadata": {
        "id": "b9529e3c-ac68-4ae1-b9e4-f5afa5bbcc3e",
        "outputId": "4b0bd854-d9e9-4fd7-b865-856b8f3e6fbe",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: uv in /usr/local/lib/python3.11/dist-packages (0.6.3)\n",
            "\u001b[2mUsing Python 3.11.11 environment at: /usr\u001b[0m\n",
            "\u001b[2mAudited \u001b[1m7 packages\u001b[0m \u001b[2min 101ms\u001b[0m\u001b[0m\n"
          ]
        }
      ],
      "source": [
        "!pip install uv\n",
        "!uv pip install tensorflow pandas numpy opencv-python-headless scikit-learn matplotlib seaborn\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "151a3114-b7ef-4d68-a94a-9febabd724f9",
      "metadata": {
        "id": "151a3114-b7ef-4d68-a94a-9febabd724f9"
      },
      "source": [
        "<div style=\"color:blue\">\n",
        "    \n",
        "### IMPORTS\n",
        "</div>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b2593443-65dd-4769-a76a-8628d5cffdea",
      "metadata": {
        "id": "b2593443-65dd-4769-a76a-8628d5cffdea"
      },
      "outputs": [],
      "source": [
        "# EDIT:\n",
        "# You may add any other free python packages along with comments\n",
        "\n",
        "# Data Types\n",
        "from typing import Any\n",
        "\n",
        "# System, File I/O\n",
        "import os\n",
        "import cv2\n",
        "\n",
        "# Data handling\n",
        "import pandas as pd  # Data manipulation and analysis\n",
        "import numpy as np  # Numerical computations and array handling\n",
        "\n",
        "\n",
        "# Machine Learning models and process\n",
        "import tensorflow as tf\n",
        "from tensorflow.keras.preprocessing.image import load_img, img_to_array\n",
        "from sklearn.model_selection import train_test_split  # Split data into training and testing sets\n",
        "from sklearn.preprocessing import LabelEncoder\n",
        "\n",
        "\n",
        "# Model evaluation\n",
        "from sklearn.metrics import accuracy_score, classification_report  # Performance metrics\n",
        "from sklearn.metrics import roc_curve, auc  # ROC curve and AUC score\n",
        "\n",
        "\n",
        "# Visualization\n",
        "import matplotlib.pyplot as plt  # Plotting graphs\n",
        "import seaborn as sns  # Enhanced data visualization\n",
        "\n",
        "# Image processing\n",
        "from skimage.color import rgb2gray\n",
        "from skimage.io import imread\n",
        "from skimage.metrics import structural_similarity as ssim"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "**COPY DATA**"
      ],
      "metadata": {
        "id": "RolLAcDXKvpW"
      },
      "id": "RolLAcDXKvpW"
    },
    {
      "cell_type": "code",
      "source": [
        "# Copy data\n",
        "!mkdir /content/data\n",
        "!wget https://raw.githubusercontent.com/inaiogit/stage2test/main/test/color_images_test_public.csv\n",
        "!wget https://raw.githubusercontent.com/inaiogit/stage2test/main/test/color_images_train.csv\n",
        "!wget https://raw.githubusercontent.com/inaiogit/stage2test/main/test/grayscale_images_test_public.csv\n",
        "!wget https://raw.githubusercontent.com/inaiogit/stage2test/main/test/grayscale_images_train.csv\n",
        "!wget https://raw.githubusercontent.com/inaiogit/stage2test/main/test/processed_images.zip\n",
        "!mv color_images_test_public.csv color_images_train.csv grayscale_images_test_public.csv grayscale_images_train.csv data/\n",
        "!unzip processed_images.zip\n",
        "!rm processed_images.zip"
      ],
      "metadata": {
        "id": "XXvOB0oq_s5O",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "42a533db-0ab1-4ba8-8bc5-0c46943df92a"
      },
      "id": "XXvOB0oq_s5O",
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "mkdir: cannot create directory ‘/content/data’: File exists\n",
            "--2025-03-02 06:27:57--  https://raw.githubusercontent.com/inaiogit/stage2test/main/test/color_images_test_public.csv\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 1684 (1.6K) [text/plain]\n",
            "Saving to: ‘color_images_test_public.csv’\n",
            "\n",
            "color_images_test_p 100%[===================>]   1.64K  --.-KB/s    in 0.003s  \n",
            "\n",
            "2025-03-02 06:27:58 (639 KB/s) - ‘color_images_test_public.csv’ saved [1684/1684]\n",
            "\n",
            "--2025-03-02 06:27:58--  https://raw.githubusercontent.com/inaiogit/stage2test/main/test/color_images_train.csv\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.110.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 2794 (2.7K) [text/plain]\n",
            "Saving to: ‘color_images_train.csv’\n",
            "\n",
            "color_images_train. 100%[===================>]   2.73K  --.-KB/s    in 0s      \n",
            "\n",
            "2025-03-02 06:27:58 (70.2 MB/s) - ‘color_images_train.csv’ saved [2794/2794]\n",
            "\n",
            "--2025-03-02 06:27:58--  https://raw.githubusercontent.com/inaiogit/stage2test/main/test/grayscale_images_test_public.csv\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.110.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 4254 (4.2K) [text/plain]\n",
            "Saving to: ‘grayscale_images_test_public.csv’\n",
            "\n",
            "grayscale_images_te 100%[===================>]   4.15K  --.-KB/s    in 0s      \n",
            "\n",
            "2025-03-02 06:27:59 (75.2 MB/s) - ‘grayscale_images_test_public.csv’ saved [4254/4254]\n",
            "\n",
            "--2025-03-02 06:27:59--  https://raw.githubusercontent.com/inaiogit/stage2test/main/test/grayscale_images_train.csv\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.108.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 8113 (7.9K) [text/plain]\n",
            "Saving to: ‘grayscale_images_train.csv’\n",
            "\n",
            "grayscale_images_tr 100%[===================>]   7.92K  --.-KB/s    in 0.01s   \n",
            "\n",
            "2025-03-02 06:28:00 (817 KB/s) - ‘grayscale_images_train.csv’ saved [8113/8113]\n",
            "\n",
            "--2025-03-02 06:28:00--  https://raw.githubusercontent.com/inaiogit/stage2test/main/test/processed_images.zip\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.108.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 2189957 (2.1M) [application/zip]\n",
            "Saving to: ‘processed_images.zip’\n",
            "\n",
            "processed_images.zi 100%[===================>]   2.09M  3.53MB/s    in 0.6s    \n",
            "\n",
            "2025-03-02 06:28:01 (3.53 MB/s) - ‘processed_images.zip’ saved [2189957/2189957]\n",
            "\n",
            "Archive:  processed_images.zip\n",
            "replace data/processed_images/small_image_f3fe7449.jpeg? [y]es, [n]o, [A]ll, [N]one, [r]ename: "
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "id": "9169d222-16c9-4aa6-baa8-daacc2092350",
      "metadata": {
        "id": "9169d222-16c9-4aa6-baa8-daacc2092350"
      },
      "source": [
        "<div style=\"color:blue\">\n",
        "    \n",
        "## **Q1: Build an Image Location Type Classifier** [9 pts]\n",
        "\n",
        "</div>"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "542e062d-7d45-43dd-af2c-e46afcd773ab",
      "metadata": {
        "id": "542e062d-7d45-43dd-af2c-e46afcd773ab"
      },
      "source": [
        "<div style=\"color:blue\">\n",
        "\n",
        "### DATA\n",
        "\n",
        "You are provided with two datasets:\n",
        "\n",
        "- color_images_train: one image per row (image_path, label:\"Coast\" vs. \"Desert\")\n",
        "- grayscale_images_train: one image per row (image_path)\n",
        "- images/ : folder with image files\n",
        "  \n",
        "</div>\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "00efe6b6-4565-4020-a963-d864adbc9fc4",
      "metadata": {
        "id": "00efe6b6-4565-4020-a963-d864adbc9fc4"
      },
      "outputs": [],
      "source": [
        "# Training datasets\n",
        "color_train_path = \"data/color_images_train.csv\"  # Labeled colored images - two columns image_path, label\n",
        "grayscale_train_path = \"data/grayscale_images_train.csv\"  # Unlabeled grayscale images - single column image_path\n",
        "\n"
      ]
    },
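    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Optional exploration sketch: verify that the CSVs load and the image paths resolve.\n",
        "# Illustrative only, not part of the graded answer; assumes the COPY DATA cell above has been run.\n",
        "color_df = pd.read_csv(color_train_path)\n",
        "gray_df = pd.read_csv(grayscale_train_path)\n",
        "print(\"color:\", color_df.shape, \"| grayscale:\", gray_df.shape)\n",
        "print(color_df[\"label\"].value_counts())\n",
        "print(\"missing files:\", (~color_df[\"image_path\"].apply(os.path.exists)).sum())\n",
        "\n",
        "# Show one sample image from each class\n",
        "fig, axes = plt.subplots(1, 2, figsize=(6, 3))\n",
        "for ax, (label, grp) in zip(axes, color_df.groupby(\"label\")):\n",
        "    ax.imshow(imread(grp[\"image_path\"].iloc[0]))\n",
        "    ax.set_title(label)\n",
        "    ax.axis(\"off\")\n",
        "plt.show()"
      ]
    },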
    {
      "cell_type": "markdown",
      "id": "c65c0081-5125-4773-ab03-54c232c8ffdf",
      "metadata": {
        "id": "c65c0081-5125-4773-ab03-54c232c8ffdf"
      },
      "source": [
        "<div style=\"color:blue\">\n",
        "    \n",
        "### TASK\n",
        "\n",
        "Create two functions **learn_location_type_classifier** and  **classify_image** as per the signatures defined below:\n",
        "If you scroll down, you will see cells with the skeletal code that you need to flesh out.\n",
        "\n",
        "---\n",
        "#### **Function 1:`learn_location_type_classifier`**\n",
        "def learn_location_type_classifier(color_train_path: str, grayscale_train_path: str) -> Any:\n",
        "\n",
        "    \"\"\"\n",
        "    Train a classifier to distinguish between Coast and Desert images.\n",
        "    \n",
        "    Parameters:\n",
        "    - color_train_path (str): csv file containing file path and label ofcolor images.\n",
        "    - grayscale_train_path (str): csv file containing file paths of unlabeled grayscale images.\n",
        "    \n",
        "    Returns:\n",
        "    - model (Any): Trained classification model that includes preprocessing.\n",
        "    \"\"\"\n",
        "    pass\n",
        "\n",
        "\n",
        "#### **Function 2: `classify_image`**\n",
        "def classify_image(image_path: str, model: Any) -> str:\n",
        "\n",
        "    \"\"\"\n",
        "    Classify an image as either \"Coast\" or \"Desert\".\n",
        "    \n",
        "    Parameters:\n",
        "    - image_path (str): Path to the image file.\n",
        "    - model (Any): Trained classification model that includes preprocessing.\n",
        "    \n",
        "    Returns:\n",
        "    - label (str): Predicted label, either \"Coast\" or \"Desert\".\n",
        "    \n",
        "    - float: A probability score where:\n",
        "        - values close to 1 indicates high confidence the image is \"Coast\".\n",
        "        - values close to 0  indicates high confidence the image is \"Desert\".\n",
        "    \"\"\"\n",
        "    pass\n",
        "\n",
        "**Hints:**\n",
        "- Load and explore the data\n",
        "- Identify potentially discriminating features\n",
        "- Come with a strategy for building a classifier\n",
        "- Train models, test, and iterate\n",
        "\n",
        "</div>"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "4b2f9023-f431-440c-91d7-40bdb42362a7",
      "metadata": {
        "id": "4b2f9023-f431-440c-91d7-40bdb42362a7"
      },
      "source": [
        "<div style=\"color:blue\">\n",
        "    \n",
        "### HELPER CODE\n",
        "\n",
        "</div>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e2e0cf27-7edd-45ba-a809-d4ad84df8873",
      "metadata": {
        "id": "e2e0cf27-7edd-45ba-a809-d4ad84df8873"
      },
      "outputs": [],
      "source": [
        "# HELPER CODE\n",
        "# You may choose to use or modify any of the below code in your solution, but it is NOT mandatory\n",
        "def convert_to_grayscale(image_path: str, output_dir: str = None) -> str:\n",
        "    \"\"\"\n",
        "    Converts an image to grayscale and saves it with 'grayscale_' prefixed to the filename.\n",
        "\n",
        "    Parameters:\n",
        "    - image_path (str): Path to the original image file.\n",
        "    - output_dir (str, optional): Directory where the grayscale image will be saved.\n",
        "                                  If None, saves in the same directory as the input image.\n",
        "\n",
        "    Returns:\n",
        "    - str: Path to the saved grayscale image.\n",
        "    \"\"\"\n",
        "    # Read the image\n",
        "    image = cv2.imread(image_path)\n",
        "    if image is None:\n",
        "        raise ValueError(f\"Could not read image: {image_path}\")\n",
        "\n",
        "    # Convert to grayscale\n",
        "    grayscale_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\n",
        "\n",
        "    # Construct output file name with prefix\n",
        "    base_name, ext = os.path.splitext(os.path.basename(image_path))\n",
        "    grayscale_filename = f\"grayscale_{base_name}{ext}\"\n",
        "\n",
        "    # Determine output directory\n",
        "    if output_dir is None:\n",
        "        output_dir = os.path.dirname(image_path)\n",
        "    os.makedirs(output_dir, exist_ok=True)\n",
        "\n",
        "    # Save grayscale image\n",
        "    output_path = os.path.join(output_dir, grayscale_filename)\n",
        "    cv2.imwrite(output_path, grayscale_image)\n",
        "\n",
        "    print(f\"Grayscale image saved to: {output_path}\")\n",
        "    return output_path\n",
        "\n",
        "\n",
        "\n",
        "def train_image_classifier(df: pd.DataFrame, epochs: int = 10, test_size: float = 0.2) -> Any:\n",
        "    IMG_SIZE = (128, 128)\n",
        "    BATCH_SIZE = 32\n",
        "\n",
        "    def load_images_labels(df):\n",
        "        images = []\n",
        "        labels = []\n",
        "        for _, row in df.iterrows():\n",
        "            img_path = row[\"image_path\"]\n",
        "            label = row[\"label\"]\n",
        "            if os.path.exists(img_path):\n",
        "                image = load_img(img_path, target_size=IMG_SIZE)\n",
        "                image = img_to_array(image) / 255.0\n",
        "                images.append(image)\n",
        "                labels.append(label)\n",
        "            else:\n",
        "                print(f\"Warning: {img_path} does not exist.\")\n",
        "        return np.array(images), np.array(labels)\n",
        "\n",
        "    images, labels = load_images_labels(df)\n",
        "\n",
        "    # Check if any images were loaded\n",
        "    if len(images) == 0:\n",
        "        raise ValueError(\"No images loaded. Please verify your CSV file and image paths.\")\n",
        "\n",
        "    label_encoder = LabelEncoder()\n",
        "    labels = label_encoder.fit_transform(labels)\n",
        "\n",
        "    X_train, X_val, y_train, y_val = train_test_split(\n",
        "        images, labels, test_size=test_size, random_state=42\n",
        "    )\n",
        "\n",
        "    model = tf.keras.models.Sequential([\n",
        "        tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(IMG_SIZE[0], IMG_SIZE[1], 3)),\n",
        "        tf.keras.layers.MaxPooling2D(2, 2),\n",
        "        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),\n",
        "        tf.keras.layers.MaxPooling2D(2, 2),\n",
        "        tf.keras.layers.Flatten(),\n",
        "        tf.keras.layers.Dense(128, activation='relu'),\n",
        "        tf.keras.layers.Dropout(0.5),\n",
        "        tf.keras.layers.Dense(1, activation='sigmoid')\n",
        "    ])\n",
        "\n",
        "    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])\n",
        "    model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=epochs, batch_size=BATCH_SIZE)\n",
        "\n",
        "    test_loss, test_acc = model.evaluate(X_val, y_val)\n",
        "    print(f\"Validation Accuracy: {test_acc:.4f}\")\n",
        "\n",
        "    return model, label_encoder\n",
        "\n",
        "\n",
        "def classify_image(image_path: str, model: tf.keras.Model) -> float:\n",
        "    \"\"\"\n",
        "    Classifies a new image and returns a probability score (0 to 1).\n",
        "\n",
        "    Parameters:\n",
        "    - image_path (str): Path to the image to be classified.\n",
        "    - model (tf.keras.Model): Trained CNN model.\n",
        "\n",
        "    Returns:\n",
        "    - float: Probability score (0 to 1), where closer to 1 means one class and closer to 0 means the other.\n",
        "    \"\"\"\n",
        "    IMG_SIZE = (128, 128)\n",
        "\n",
        "    if not os.path.exists(image_path):\n",
        "        raise ValueError(f\"Image path does not exist: {image_path}\")\n",
        "\n",
        "    image = load_img(image_path, target_size=IMG_SIZE)\n",
        "    image = img_to_array(image) / 255.0\n",
        "    image = np.expand_dims(image, axis=0)\n",
        "\n",
        "    score = model.predict(image)[0][0]\n",
        "    return score\n",
        "\n",
        "\n",
        "def compute_image_similarity(image_path1: str, image_path2: str) -> dict:\n",
        "    \"\"\"\n",
        "    Computes similarity between two images using SSIM and MSE.\n",
        "\n",
        "    Parameters:\n",
        "    - image_path1 (str): Path to the first image.\n",
        "    - image_path2 (str): Path to the second image.\n",
        "\n",
        "    Returns:\n",
        "    - dict: Dictionary containing SSIM and MSE scores.\n",
        "    \"\"\"\n",
        "    # Load the images\n",
        "    image1 = cv2.imread(image_path1, cv2.IMREAD_GRAYSCALE)\n",
        "    image2 = cv2.imread(image_path2, cv2.IMREAD_GRAYSCALE)\n",
        "\n",
        "    if image1 is None or image2 is None:\n",
        "        raise ValueError(\"One or both image paths are invalid.\")\n",
        "\n",
        "    # Resize images to the same size\n",
        "    image1 = cv2.resize(image1, (128, 128))\n",
        "    image2 = cv2.resize(image2, (128, 128))\n",
        "\n",
        "    # Compute SSIM (Structural Similarity Index)\n",
        "    ssim_score = ssim(image1, image2)\n",
        "\n",
        "    # Compute MSE (Mean Squared Error); cast to float first to avoid uint8 wraparound\n",
        "    mse_score = np.mean((image1.astype(np.float64) - image2.astype(np.float64)) ** 2)\n",
        "\n",
        "    return {\"SSIM\": ssim_score, \"MSE\": mse_score}"
      ]
    },
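    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative usage of the helper functions above (a sketch, not part of the graded answer).\n",
        "# Assumes the training CSV has been downloaded to data/ by the COPY DATA cell.\n",
        "color_df = pd.read_csv(\"data/color_images_train.csv\")\n",
        "\n",
        "# Convert the first labeled color image to grayscale for inspection\n",
        "sample_path = color_df[\"image_path\"].iloc[0]\n",
        "gray_path = convert_to_grayscale(sample_path, output_dir=\"data/gray_preview\")\n",
        "\n",
        "# Compare the original with its grayscale copy; SSIM near 1 and MSE near 0\n",
        "# indicate the two images are structurally almost identical\n",
        "print(compute_image_similarity(sample_path, gray_path))"
      ]
    },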
    {
      "cell_type": "markdown",
      "id": "458f567c-e991-495d-aed5-24cbec9a9bab",
      "metadata": {
        "id": "458f567c-e991-495d-aed5-24cbec9a9bab"
      },
      "source": [
        "<div style=\"color:red\">\n",
        "    \n",
        "### ANSWER\n",
        "\n",
        "</div>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "492e37a4-5ee8-4f66-a14a-32168ec40399",
      "metadata": {
        "id": "492e37a4-5ee8-4f66-a14a-32168ec40399"
      },
      "source": [
        "#### **EDIT: [4 pts]**\n",
        "#### You can jot down initial notes here and flesh them out in more detail after the implementation.\n",
        "\n",
        "### **Describe Your Solution Approach**\n",
        "\n",
        "#### **• Data Exploration Approach** [0.5 pt]\n",
        "  - Review the provided CSV files to verify that image paths for both labeled (color) and unlabeled (grayscale) images are correct.\n",
        "  - Visualize a few samples from each set to understand image quality and class distinctions.\n",
        "\n",
        "#### **• Instance Representation/Preprocessing** [0.5 pt]\n",
        "  - Resize all images to a consistent size (128×128) and normalize pixel values to the range [0, 1].\n",
        "  - Maintain color images as is, while converting grayscale images as needed using helper functions.\n",
        "\n",
        "#### **• Modeling Strategy** [2 pt]\n",
        "  - Train an initial CNN classifier using the 50 labeled color images.\n",
        "  - Apply a self-training strategy: use the initial model to pseudo-label the 150 unlabeled grayscale images when predictions are highly confident.\n",
        "  - Combine the pseudo-labeled images with the original labeled dataset and retrain the CNN model to improve overall performance.\n",
        "\n",
        "#### **• Specific Modeling Choices** [1 pt]\n",
        "  - Use a simple CNN architecture with two convolutional layers followed by max-pooling, a flattening layer, dense layers, and dropout for regularization.\n",
        "  - Employ a sigmoid activation in the final layer and binary cross-entropy as the loss function.\n",
        "  - Set high confidence thresholds (e.g., ≥0.9 for \"Coast\" and ≤0.1 for \"Desert\") to ensure reliable pseudo-labeling.\n",
        "\n"
      ]
    },
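    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sketch of the self-training (pseudo-labeling) strategy described above; illustrative only.\n",
        "# Assumes train_image_classifier and the float-returning classify_image from the\n",
        "# HELPER CODE cell; hi/lo are the confidence cutoffs mentioned in the notes above.\n",
        "def pseudo_label_grayscale(model, label_encoder, grayscale_df, hi=0.9, lo=0.1):\n",
        "    \"\"\"Pseudo-label unlabeled grayscale images where the model is confident.\"\"\"\n",
        "    rows = []\n",
        "    for path in grayscale_df[\"image_path\"]:\n",
        "        if not os.path.exists(path):\n",
        "            continue\n",
        "        score = classify_image(path, model)\n",
        "        # Sigmoid output near 1 corresponds to label_encoder.classes_[1],\n",
        "        # near 0 to label_encoder.classes_[0]\n",
        "        if score >= hi:\n",
        "            rows.append({\"image_path\": path, \"label\": label_encoder.classes_[1]})\n",
        "        elif score <= lo:\n",
        "            rows.append({\"image_path\": path, \"label\": label_encoder.classes_[0]})\n",
        "        # images with uncertain scores remain unlabeled\n",
        "    return pd.DataFrame(rows)\n",
        "\n",
        "# Round 1: train on labeled color images; Round 2: retrain on labeled plus\n",
        "# confidently pseudo-labeled grayscale images.\n",
        "# color_df = pd.read_csv(color_train_path)\n",
        "# gray_df = pd.read_csv(grayscale_train_path)\n",
        "# base_model, enc = train_image_classifier(color_df)\n",
        "# pseudo_df = pseudo_label_grayscale(base_model, enc, gray_df)\n",
        "# final_model, enc = train_image_classifier(pd.concat([color_df, pseudo_df], ignore_index=True))"
      ]
    },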
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "924b0f7e-e532-4d53-bb26-457c14685e49",
      "metadata": {
        "id": "924b0f7e-e532-4d53-bb26-457c14685e49",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        },
        "outputId": "3077138c-8652-4d4f-e890-9aaffb34365e"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "                                        image_path  label\n",
              "0  data/processed_images/small_image_f3f35850.jpeg  Coast\n",
              "1  data/processed_images/small_image_69362bc2.jpeg  Coast\n",
              "2  data/processed_images/small_image_a6916309.jpeg  Coast\n",
              "3  data/processed_images/small_image_9ea6f817.jpeg  Coast\n",
              "4  data/processed_images/small_image_6466e46c.jpeg  Coast"
            ],
            "text/html": [
              "\n",
              "  <div id=\"df-28fa8c9f-80be-4012-86f6-3cb6dd69fdb6\" class=\"colab-df-container\">\n",
              "    <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>image_path</th>\n",
              "      <th>label</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>data/processed_images/small_image_f3f35850.jpeg</td>\n",
              "      <td>Coast</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>data/processed_images/small_image_69362bc2.jpeg</td>\n",
              "      <td>Coast</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>data/processed_images/small_image_a6916309.jpeg</td>\n",
              "      <td>Coast</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>data/processed_images/small_image_9ea6f817.jpeg</td>\n",
              "      <td>Coast</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>data/processed_images/small_image_6466e46c.jpeg</td>\n",
              "      <td>Coast</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "    <div class=\"colab-df-buttons\">\n",
              "\n",
              "  <div class=\"colab-df-container\">\n",
              "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-28fa8c9f-80be-4012-86f6-3cb6dd69fdb6')\"\n",
              "            title=\"Convert this dataframe to an interactive table.\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
              "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
              "  </svg>\n",
              "    </button>\n",
              "\n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    .colab-df-buttons div {\n",
              "      margin-bottom: 4px;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "    <script>\n",
              "      const buttonEl =\n",
              "        document.querySelector('#df-28fa8c9f-80be-4012-86f6-3cb6dd69fdb6 button.colab-df-convert');\n",
              "      buttonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "      async function convertToInteractive(key) {\n",
              "        const element = document.querySelector('#df-28fa8c9f-80be-4012-86f6-3cb6dd69fdb6');\n",
              "        const dataTable =\n",
              "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                    [key], {});\n",
              "        if (!dataTable) return;\n",
              "\n",
              "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "          + ' to learn more about interactive tables.';\n",
              "        element.innerHTML = '';\n",
              "        dataTable['output_type'] = 'display_data';\n",
              "        await google.colab.output.renderOutput(dataTable, element);\n",
              "        const docLink = document.createElement('div');\n",
              "        docLink.innerHTML = docLinkHtml;\n",
              "        element.appendChild(docLink);\n",
              "      }\n",
              "    </script>\n",
              "  </div>\n",
              "\n",
              "\n",
              "<div id=\"df-23f6e94f-4c49-4051-8ea9-2127f61f7536\">\n",
              "  <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-23f6e94f-4c49-4051-8ea9-2127f61f7536')\"\n",
              "            title=\"Suggest charts\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 0 24 24\"\n",
              "     width=\"24px\">\n",
              "    <g>\n",
              "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
              "    </g>\n",
              "</svg>\n",
              "  </button>\n",
              "\n",
              "<style>\n",
              "  .colab-df-quickchart {\n",
              "      --bg-color: #E8F0FE;\n",
              "      --fill-color: #1967D2;\n",
              "      --hover-bg-color: #E2EBFA;\n",
              "      --hover-fill-color: #174EA6;\n",
              "      --disabled-fill-color: #AAA;\n",
              "      --disabled-bg-color: #DDD;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart {\n",
              "      --bg-color: #3B4455;\n",
              "      --fill-color: #D2E3FC;\n",
              "      --hover-bg-color: #434B5C;\n",
              "      --hover-fill-color: #FFFFFF;\n",
              "      --disabled-bg-color: #3B4455;\n",
              "      --disabled-fill-color: #666;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart {\n",
              "    background-color: var(--bg-color);\n",
              "    border: none;\n",
              "    border-radius: 50%;\n",
              "    cursor: pointer;\n",
              "    display: none;\n",
              "    fill: var(--fill-color);\n",
              "    height: 32px;\n",
              "    padding: 0;\n",
              "    width: 32px;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart:hover {\n",
              "    background-color: var(--hover-bg-color);\n",
              "    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "    fill: var(--hover-fill-color);\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart-complete:disabled,\n",
              "  .colab-df-quickchart-complete:disabled:hover {\n",
              "    background-color: var(--disabled-bg-color);\n",
              "    fill: var(--disabled-fill-color);\n",
              "    box-shadow: none;\n",
              "  }\n",
              "\n",
              "  .colab-df-spinner {\n",
              "    border: 2px solid var(--fill-color);\n",
              "    border-color: transparent;\n",
              "    border-bottom-color: var(--fill-color);\n",
              "    animation:\n",
              "      spin 1s steps(1) infinite;\n",
              "  }\n",
              "\n",
              "  @keyframes spin {\n",
              "    0% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "      border-left-color: var(--fill-color);\n",
              "    }\n",
              "    20% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    30% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    40% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    60% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    80% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "    90% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "  }\n",
              "</style>\n",
              "\n",
              "  <script>\n",
              "    async function quickchart(key) {\n",
              "      const quickchartButtonEl =\n",
              "        document.querySelector('#' + key + ' button');\n",
              "      quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n",
              "      quickchartButtonEl.classList.add('colab-df-spinner');\n",
              "      try {\n",
              "        const charts = await google.colab.kernel.invokeFunction(\n",
              "            'suggestCharts', [key], {});\n",
              "      } catch (error) {\n",
              "        console.error('Error during call to suggestCharts:', error);\n",
              "      }\n",
              "      quickchartButtonEl.classList.remove('colab-df-spinner');\n",
              "      quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
              "    }\n",
              "    (() => {\n",
              "      let quickchartButtonEl =\n",
              "        document.querySelector('#df-23f6e94f-4c49-4051-8ea9-2127f61f7536 button');\n",
              "      quickchartButtonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "    })();\n",
              "  </script>\n",
              "</div>\n",
              "\n",
              "    </div>\n",
              "  </div>\n"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "dataframe",
              "variable_name": "color_data",
              "summary": "{\n  \"name\": \"color_data\",\n  \"rows\": 50,\n  \"fields\": [\n    {\n      \"column\": \"image_path\",\n      \"properties\": {\n        \"dtype\": \"string\",\n        \"num_unique_values\": 50,\n        \"samples\": [\n          \"data/processed_images/small_image_2711b89e.jpeg\",\n          \"data/processed_images/small_image_98f77472.jpeg\",\n          \"data/processed_images/small_image_66a98f82.jpeg\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"label\",\n      \"properties\": {\n        \"dtype\": \"category\",\n        \"num_unique_values\": 2,\n        \"samples\": [\n          \"Desert\",\n          \"Coast\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    }\n  ]\n}"
            }
          },
          "metadata": {},
          "execution_count": 19
        }
      ],
      "source": [
        "# EDIT: [0.5 pt]\n",
        "# Add your data exploration code here\n",
        "color_data = pd.read_csv(color_train_path)\n",
        "color_data.head()"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "color_data.label.unique()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5clKagpUf1Fr",
        "outputId": "fa8be928-6702-4148-d26f-b96b2b89d13b"
      },
      "id": "5clKagpUf1Fr",
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array(['Coast', 'Desert'], dtype=object)"
            ]
          },
          "metadata": {},
          "execution_count": 20
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "color_data.info()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "BhfJJFjwgGoi",
        "outputId": "029b7708-f05b-450d-e6dd-9045f0703961"
      },
      "id": "BhfJJFjwgGoi",
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "<class 'pandas.core.frame.DataFrame'>\n",
            "RangeIndex: 50 entries, 0 to 49\n",
            "Data columns (total 2 columns):\n",
            " #   Column      Non-Null Count  Dtype \n",
            "---  ------      --------------  ----- \n",
            " 0   image_path  50 non-null     object\n",
            " 1   label       50 non-null     object\n",
            "dtypes: object(2)\n",
            "memory usage: 932.0+ bytes\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "grayscale_data = pd.read_csv(grayscale_train_path)\n",
        "grayscale_data.head()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        },
        "id": "wzhqc5KSf3cx",
        "outputId": "cdbd5dad-dead-449f-c4de-ac617cd60ec7"
      },
      "id": "wzhqc5KSf3cx",
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "                                          image_path\n",
              "0  data/processed_images/small_gray_image_3b65eec...\n",
              "1  data/processed_images/small_gray_image_0231789...\n",
              "2  data/processed_images/small_gray_image_044d8f4...\n",
              "3  data/processed_images/small_gray_image_6371e1b...\n",
              "4  data/processed_images/small_gray_image_2168812..."
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "dataframe",
              "variable_name": "grayscale_data",
              "summary": "{\n  \"name\": \"grayscale_data\",\n  \"rows\": 150,\n  \"fields\": [\n    {\n      \"column\": \"image_path\",\n      \"properties\": {\n        \"dtype\": \"string\",\n        \"num_unique_values\": 150,\n        \"samples\": [\n          \"data/processed_images/small_gray_image_3d5c8b17.jpeg\",\n          \"data/processed_images/small_gray_image_c207a555.jpeg\",\n          \"data/processed_images/small_gray_image_490b1233.jpeg\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    }\n  ]\n}"
            }
          },
          "metadata": {},
          "execution_count": 8
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "grayscale_data.info()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "iUFV41vkgI98",
        "outputId": "1fa0fcd4-8793-400e-b627-98301a09fc86"
      },
      "id": "iUFV41vkgI98",
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "<class 'pandas.core.frame.DataFrame'>\n",
            "RangeIndex: 150 entries, 0 to 149\n",
            "Data columns (total 1 columns):\n",
            " #   Column      Non-Null Count  Dtype \n",
            "---  ------      --------------  ----- \n",
            " 0   image_path  150 non-null    object\n",
            "dtypes: object(1)\n",
            "memory usage: 1.3+ KB\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c5154341-d86c-457d-a291-eeb0f0cdeca7",
      "metadata": {
        "id": "c5154341-d86c-457d-a291-eeb0f0cdeca7"
      },
      "outputs": [],
      "source": [
        "# EDIT: [1 pt]\n",
        "# Add any instance representation/preprocessing code here\n",
        "\n",
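        "# Sketch (assumed approach): represent each image as a fixed-size, normalised\n",
        "# pixel array. PIL and the 64x64 target size are illustrative assumptions,\n",
        "# not requirements of the problem.\n",
        "from PIL import Image\n",
        "import numpy as np\n",
        "\n",
        "def image_to_array(image_path, size=(64, 64)):\n",
        "    # Resize to a fixed shape and scale pixel values to [0, 1].\n",
        "    img = Image.open(image_path).convert(\"RGB\").resize(size)\n",
        "    return np.asarray(img, dtype=np.float32) / 255.0\n",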
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "339aaeb6-e2d6-40cb-9790-6845357b2ab8",
      "metadata": {
        "id": "339aaeb6-e2d6-40cb-9790-6845357b2ab8"
      },
      "outputs": [],
      "source": [
        "# EDIT: [0 pts]\n",
        "# Add any additional code that you need for your classifier implementation\n",
        "# Code in this cell will be evaluated with points assigned to the classifier training and prediction implementation cells\n",
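        "# NOTE (editor sketch): learn_location_type_classifier() in the next cell calls\n",
        "# train_image_classifier() and _predict_probability(), which are not defined\n",
        "# elsewhere in this notebook. The following minimal Keras CNN sketch is an\n",
        "# assumption about those helpers: the architecture, the 64x64 input size, and the\n",
        "# convention that the returned probability is P(\"Coast\") are all hypothetical.\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "from sklearn.preprocessing import LabelEncoder\n",
        "\n",
        "IMG_SIZE = (64, 64)  # assumed input resolution\n",
        "\n",
        "def _load_image_array(image_path):\n",
        "    # Load an image (grayscale JPEGs are expanded to RGB) and scale to [0, 1].\n",
        "    img = tf.keras.utils.load_img(image_path, target_size=IMG_SIZE)\n",
        "    return tf.keras.utils.img_to_array(img) / 255.0\n",
        "\n",
        "def train_image_classifier(df, epochs=10):\n",
        "    # Fit a small CNN on (image_path, label) rows; returns (model, label_encoder).\n",
        "    label_encoder = LabelEncoder()\n",
        "    y = label_encoder.fit_transform(df[\"label\"])  # \"Coast\" -> 0, \"Desert\" -> 1\n",
        "    X = np.stack([_load_image_array(p) for p in df[\"image_path\"]])\n",
        "    model = tf.keras.Sequential([\n",
        "        tf.keras.layers.Input(shape=IMG_SIZE + (3,)),\n",
        "        tf.keras.layers.Conv2D(16, 3, activation=\"relu\"),\n",
        "        tf.keras.layers.MaxPooling2D(),\n",
        "        tf.keras.layers.Conv2D(32, 3, activation=\"relu\"),\n",
        "        tf.keras.layers.GlobalAveragePooling2D(),\n",
        "        tf.keras.layers.Dense(1, activation=\"sigmoid\"),  # outputs P(\"Desert\")\n",
        "    ])\n",
        "    model.compile(optimizer=\"adam\", loss=\"binary_crossentropy\", metrics=[\"accuracy\"])\n",
        "    model.fit(X, y, epochs=epochs, verbose=0)\n",
        "    return model, label_encoder\n",
        "\n",
        "def _predict_probability(image_path, model):\n",
        "    # Return P(\"Coast\") for one image, matching the thresholds used below.\n",
        "    x = np.expand_dims(_load_image_array(image_path), axis=0)\n",
        "    return 1.0 - float(model.predict(x, verbose=0)[0][0])\n",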
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b4c6671f-655b-4d12-8d76-29fba5ff992f",
      "metadata": {
        "id": "b4c6671f-655b-4d12-8d76-29fba5ff992f"
      },
      "outputs": [],
      "source": [
        "# EDIT: [2.5 pts]\n",
        "# Implement the classifier training\n",
        "# You may edit the partial implementation, but keep the signature the same\n",
        "\n",
        "\n",
        "def learn_location_type_classifier(color_train_path: str, grayscale_train_path: str) -> dict:\n",
        "    \"\"\"\n",
        "    Train a classifier to distinguish between Coast and Desert images.\n",
        "\n",
        "    We first train on the 50 labeled color images. Then we use the trained model to\n",
        "    pseudo-label the 150 unlabeled grayscale images when the prediction confidence is high.\n",
        "    We then retrain the model on the combined dataset.\n",
        "\n",
        "    Parameters:\n",
        "        - color_train_path (str): CSV file containing file path and label of color images.\n",
        "          (CSV is assumed to have columns: \"image_path\" and \"label\")\n",
        "        - grayscale_train_path (str): CSV file containing file paths of unlabeled grayscale images.\n",
        "          (CSV is assumed to have a column: \"image_path\")\n",
        "\n",
        "    Returns:\n",
        "        - dict: Contains the final trained model and the label encoder, e.g.,\n",
        "                {\"model\": cnn_model, \"label_encoder\": label_encoder}\n",
        "    \"\"\"\n",
        "    # Load CSV files\n",
        "    df_color = pd.read_csv(color_train_path)\n",
        "    df_gray = pd.read_csv(grayscale_train_path)\n",
        "\n",
        "    # Train an initial classifier on the labeled (color) images.\n",
        "    print(\"Training initial model on labeled color images...\")\n",
        "    cnn_model, label_encoder = train_image_classifier(df_color, epochs=10)\n",
        "\n",
        "    # Self-training: pseudo-label the grayscale images\n",
        "    pseudo_data = []\n",
        "    high_conf_threshold = 0.9  # high confidence for \"Coast\"\n",
        "    low_conf_threshold = 0.1   # high confidence for \"Desert\"\n",
        "\n",
        "    print(\"Pseudo-labeling unlabeled grayscale images...\")\n",
        "    for _, row in df_gray.iterrows():\n",
        "        image_path = row[\"image_path\"]\n",
        "        try:\n",
        "            prob = _predict_probability(image_path, cnn_model)\n",
        "        except Exception as e:\n",
        "            print(f\"Skipping {image_path} due to error: {e}\")\n",
        "            continue\n",
        "\n",
        "        # If the model is very confident, assign a pseudo-label\n",
        "        if prob >= high_conf_threshold:\n",
        "            pseudo_label = \"Coast\"\n",
        "            pseudo_data.append({\"image_path\": image_path, \"label\": pseudo_label})\n",
        "        elif prob <= low_conf_threshold:\n",
        "            pseudo_label = \"Desert\"\n",
        "            pseudo_data.append({\"image_path\": image_path, \"label\": pseudo_label})\n",
        "        # Otherwise, do not add the image to the training set.\n",
        "\n",
        "    if pseudo_data:\n",
        "        df_pseudo = pd.DataFrame(pseudo_data)\n",
        "        df_combined = pd.concat([df_color, df_pseudo], ignore_index=True)\n",
        "        print(f\"Retraining with {len(df_pseudo)} pseudo-labeled grayscale images (total training samples: {len(df_combined)})...\")\n",
        "        cnn_model, label_encoder = train_image_classifier(df_combined, epochs=10)\n",
        "    else:\n",
        "        print(\"No pseudo-labeled grayscale images met the confidence threshold.\")\n",
        "\n",
        "    # Return a dictionary with the model and label encoder\n",
        "    return {\"model\": cnn_model, \"label_encoder\": label_encoder}\n"
      ]
    },
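    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e3b1c2d4-1111-4f00-9abc-000000000001",
      "metadata": {
        "id": "e3b1c2d4-1111-4f00-9abc-000000000001"
      },
      "outputs": [],
      "source": [
        "# OPTIONAL ILLUSTRATION (extra cell, not graded): a minimal sketch of the\n",
        "# pseudo-labeling step above, using made-up paths and probabilities in place\n",
        "# of real model predictions, to show how confident images become new rows.\n",
        "import pandas as pd\n",
        "\n",
        "fake_probs = {\"img1.jpg\": 0.97, \"img2.jpg\": 0.55, \"img3.jpg\": 0.03}  # hypothetical\n",
        "pseudo = []\n",
        "for path, p in fake_probs.items():\n",
        "    if p >= 0.9:        # high confidence for \"Coast\"\n",
        "        pseudo.append({\"image_path\": path, \"label\": \"Coast\"})\n",
        "    elif p <= 0.1:      # high confidence for \"Desert\"\n",
        "        pseudo.append({\"image_path\": path, \"label\": \"Desert\"})\n",
        "    # mid-range scores are left unlabeled and excluded from retraining\n",
        "\n",
        "print(pd.DataFrame(pseudo))\n"
      ]
    },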
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ab367e1c-b9ab-4f8e-92e6-d5d0d60f7bd6",
      "metadata": {
        "id": "ab367e1c-b9ab-4f8e-92e6-d5d0d60f7bd6"
      },
      "outputs": [],
      "source": [
        "# EDIT: [1 pt]\n",
        "# Implement the classifier prediction function\n",
        "# You may edit the partial implementation, but keep the signature the same\n",
        "\n",
        "def classify_image(image_path: str, classifier: dict) -> (str, float):\n",
        "    \"\"\"\n",
        "    Classify an image as either \"Coast\" or \"Desert\".\n",
        "\n",
        "    This function loads the image, preprocesses it, and uses the trained CNN model to predict\n",
        "    a probability score. If the probability is >= 0.5, the image is labeled as \"Coast\";\n",
        "    otherwise it is labeled as \"Desert\".\n",
        "\n",
        "    Parameters:\n",
        "        - image_path (str): Path to the image file.\n",
        "        - classifier (dict): A dictionary containing the trained model and label encoder,\n",
        "                             as returned by learn_location_type_classifier.\n",
        "\n",
        "    Returns:\n",
        "        - label (str): Predicted label, either \"Coast\" or \"Desert\".\n",
        "        - prob (float): Probability score where values close to 1 indicate \"Coast\" and values\n",
        "                        close to 0 indicate \"Desert\".\n",
        "    \"\"\"\n",
        "    IMG_SIZE = (128, 128)\n",
        "    cnn_model = classifier[\"model\"]\n",
        "    # Here we assume that load_img loads the image in RGB mode even if it is grayscale.\n",
        "    image = load_img(image_path, target_size=IMG_SIZE)\n",
        "    image = img_to_array(image) / 255.0\n",
        "    image = np.expand_dims(image, axis=0)\n",
        "\n",
        "    prob = cnn_model.predict(image)[0][0]\n",
        "    label = \"Coast\" if prob >= 0.5 else \"Desert\"\n",
        "    return label, prob\n",
        "\n",
        "\n"
      ]
    },
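    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e3b1c2d4-1111-4f00-9abc-000000000002",
      "metadata": {
        "id": "e3b1c2d4-1111-4f00-9abc-000000000002"
      },
      "outputs": [],
      "source": [
        "# OPTIONAL ILLUSTRATION (extra cell, not graded): the preprocessing used in\n",
        "# classify_image, with a synthetic array standing in for a decoded image,\n",
        "# to show the shape and scale the model expects.\n",
        "import numpy as np\n",
        "\n",
        "image = np.random.randint(0, 256, size=(128, 128, 3)).astype(\"float32\")\n",
        "image = image / 255.0                  # scale pixel values to [0, 1]\n",
        "batch = np.expand_dims(image, axis=0)  # add batch dimension for model.predict\n",
        "print(batch.shape)  # (1, 128, 128, 3)\n"
      ]
    },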
    {
      "cell_type": "markdown",
      "id": "48467cc6-38ed-4507-bd17-33ff7d7a800d",
      "metadata": {
        "id": "48467cc6-38ed-4507-bd17-33ff7d7a800d"
      },
      "source": [
        "<div style=\"color:blue\">\n",
        "    \n",
        "## **Q2: Test Your Classifier - Public Set** [2 pts]\n",
        "\n",
        "</div>"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "b45f6169-64d8-465c-beda-1c1580e557ad",
      "metadata": {
        "id": "b45f6169-64d8-465c-beda-1c1580e557ad"
      },
      "source": [
        "<div style=\"color:blue\">\n",
        "\n",
        "### DATA\n",
        "\n",
        "We now have datasets for location type classification\n",
        "\n",
        "**Use this as your test set. Do not use this for training or validation!**\n",
        "\n",
        "- color_test_public_path: Labeled color images  (one image per row, two columns - image path, label)\n",
        "- grayscale_test_public_path:  Labeled grayscale images (one image per row, two columns - image path, label)\n",
        "\n",
        "</div>\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b03b418c-c878-439a-b4b7-bdbc443314a7",
      "metadata": {
        "id": "b03b418c-c878-439a-b4b7-bdbc443314a7"
      },
      "outputs": [],
      "source": [
        "# Public Test datasets\n",
        "\n",
        "color_test_public_path =\"data/color_images_test_public.csv\" # Labeled color images  (one image per row, two columns - image path, label)\n",
        "grayscale_test_public_path=\"data/grayscale_images_test_public.csv\" # Labeled grayscale images (one image per row, two columns - image path, label)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a19f3773-c3ef-4508-b48d-6fc2653ead17",
      "metadata": {
        "id": "a19f3773-c3ef-4508-b48d-6fc2653ead17"
      },
      "source": [
        "<div style=\"color:blue\">\n",
        "    \n",
        "### TASK\n",
        "\n",
        "Execute the code below as is with your implementation of **learn_location_type_classifier** and **classify_image** to test your classifier\n",
        "\n",
        "- Evaluate your model on the two public test splits (color_test_public and grayscale_test_public)\n",
        "- Generate the ROC (Receiver Operating Characteristic) curve to assess your classifier's performance.\n",
        "- Compute the AUC (Area Under Curve) value as a performance metric.\n",
        "  \n",
        "</div>"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "3f3a89bf-c211-4c69-9fb7-a041a7331ce4",
      "metadata": {
        "id": "3f3a89bf-c211-4c69-9fb7-a041a7331ce4"
      },
      "source": [
        "<div style=\"color:blue\">\n",
        "    \n",
        "### HELPER CODE\n",
        "\n",
        "</div>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2b861f33-3a7d-480c-8478-7c16b5294cc3",
      "metadata": {
        "id": "2b861f33-3a7d-480c-8478-7c16b5294cc3"
      },
      "outputs": [],
      "source": [
        "# DO NOT MODIFY\n",
        "\n",
        "# Use these functions directly since these are meant for evaluation\n",
        "\n",
        "def plot_roc_curve(y_true, y_scores):\n",
        "    fpr, tpr, _ = roc_curve(y_true, y_scores)\n",
        "    roc_auc = auc(fpr, tpr)\n",
        "\n",
        "    plt.figure(figsize=(8, 6))\n",
        "    plt.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')\n",
        "    plt.plot([0, 1], [0, 1], color='gray', linestyle='--')  # Random classifier line\n",
        "    plt.xlabel('False Positive Rate')\n",
        "    plt.ylabel('True Positive Rate')\n",
        "    plt.title('ROC Curve')\n",
        "    plt.legend(loc='lower right')\n",
        "    plt.show()\n",
        "\n",
        "    print(f\"AUC Value: {roc_auc:.4f}\")\n",
        "    return roc_auc\n",
        "\n",
        "def evaluate_classifier(dataset_path: str, model) -> float:\n",
        "    \"\"\"\n",
        "    Loads a labeled dataset, applies the classify_image function with the given model,\n",
        "    and computes the ROC curve & AUC.\n",
        "\n",
        "    Parameters:\n",
        "    - dataset_path (str): Path to the CSV dataset containing 'image_path' and 'label' columns.\n",
        "    - model (Any): A trained classifier (as returned by learn_location_type_classifier) used by classify_image.\n",
        "\n",
        "    Outputs:\n",
        "    - Plots the ROC curve\n",
        "    - Prints the AUC value\n",
        "\n",
        "    Returns:\n",
        "    - float: The AUC (Area Under the Curve) value.\n",
        "    \"\"\"\n",
        "\n",
        "    # Load the dataset\n",
        "    df = pd.read_csv(dataset_path)\n",
        "\n",
        "    # Ensure required columns exist\n",
        "    if \"image_path\" not in df.columns or \"label\" not in df.columns:\n",
        "        raise ValueError(\"Dataset must contain 'image_path' and 'label' columns\")\n",
        "\n",
        "    # Apply the classifier; classify_image returns (label, prob), so keep the probability\n",
        "    df[\"predicted_score\"] = df[\"image_path\"].apply(lambda x: classify_image(x, model)[1])\n",
        "\n",
        "    # Extract true labels and predicted scores\n",
        "    y_true = df[\"label\"].map({\"Coast\": 1, \"Desert\": 0}).values  # Ground truth (1 = Coast, 0 = Desert)\n",
        "    y_scores = df[\"predicted_score\"].values  # Model output scores\n",
        "\n",
        "    # Compute ROC curve and AUC\n",
        "    roc_auc = plot_roc_curve(y_true, y_scores)\n",
        "\n",
        "    return roc_auc"
      ]
    },
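    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e3b1c2d4-1111-4f00-9abc-000000000003",
      "metadata": {
        "id": "e3b1c2d4-1111-4f00-9abc-000000000003"
      },
      "outputs": [],
      "source": [
        "# OPTIONAL ILLUSTRATION (extra cell, not graded): a tiny synthetic example of\n",
        "# what roc_curve/auc compute, with hand-picked labels and scores.\n",
        "from sklearn.metrics import roc_curve, auc\n",
        "\n",
        "y_true = [1, 1, 0, 0]\n",
        "y_scores = [0.9, 0.8, 0.3, 0.1]  # scores that rank both positives first\n",
        "fpr, tpr, _ = roc_curve(y_true, y_scores)\n",
        "print(auc(fpr, tpr))  # 1.0: a perfect ranking gives maximal AUC\n"
      ]
    },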
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a114f9f5-f1b3-4a1a-ba55-7ce0c573ef08",
      "metadata": {
        "id": "a114f9f5-f1b3-4a1a-ba55-7ce0c573ef08",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "3f4975dd-687d-42a5-b678-9dcbd69a1cb4"
      },
      "outputs": [],
      "source": [
        "# DO NOT MODIFY\n",
        "# Run this code and observe the ROC Curve  and AUC [pts depend on performance range]\n",
        "model = learn_location_type_classifier(color_train_path,grayscale_train_path)\n",
        "auc_color_test_public=evaluate_classifier(color_test_public_path, model)\n",
        "auc_grayscale_test_public=evaluate_classifier(grayscale_test_public_path, model)\n",
        "\n",
        "print(auc_color_test_public)\n",
        "print(auc_grayscale_test_public)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "d2093bb7-9e04-4f7e-9e0f-6bd3cd162bab",
      "metadata": {
        "id": "d2093bb7-9e04-4f7e-9e0f-6bd3cd162bab"
      },
      "source": [
        "<div style=\"color:red\">\n",
        "    \n",
        "### ANSWER\n",
        "\n",
        "</div>"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "40a1cb92-1a56-4277-b6ad-51dda9dcbd9d",
      "metadata": {
        "id": "40a1cb92-1a56-4277-b6ad-51dda9dcbd9d"
      },
      "source": [
        "#### **EDIT: [2 pts]**\n",
        "#### You can jot down the AUC values for the test datasets below\n",
        "\n",
        "### **Color Image Public Test Set AUC** [varying points for different range]\n",
        "  -   \n",
        "  -   \n",
        "\n",
        "### **Grayscale Image Public Test Set AUC** [varying points for different range]\n",
        "  -   \n",
        "  -   \n",
        "\n",
        "### **Any Additional Observations**\n",
        "  -   \n",
        "  -   \n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "15e0f606-9220-4647-9165-2077c2c9edc5",
      "metadata": {
        "id": "15e0f606-9220-4647-9165-2077c2c9edc5"
      },
      "source": [
        "<div style=\"color:red\">\n",
        "    \n",
        "## YOU CAN STOP THE TEST HERE -- BELOW EVALUATION TO BE PERFORMED BY INAIO\n",
        "\n",
        "</div>"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a66dfd61-0eaa-405e-9ead-8fb4d0f5cbb2",
      "metadata": {
        "id": "a66dfd61-0eaa-405e-9ead-8fb4d0f5cbb2"
      },
      "source": [
        "<div style=\"color:blue\">\n",
        "    \n",
        "## **Q3: Test Your Classifier - Secret Set** [3 pts]\n",
        "\n",
        "- Same metrics as Public Dataset\n",
        "</div>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "7be6072f-ea21-48f3-b905-2eeb2d36fdfb",
      "metadata": {
        "id": "7be6072f-ea21-48f3-b905-2eeb2d36fdfb"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.18"
    },
    "colab": {
      "provenance": []
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}