{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "P_JIQE4w9xbB" }, "source": [ "# HW1: Logistic Regression\n", "\n", "This class is about models and algorithms for discrete data. This homework will have all three ingredients:\n", "* **Data**: the results from all college football games in the 2023 season\n", "* **Model**: The *Bradley-Terry* model for predicting the winners of football games. The Bradley-Terry model is just logistic regression.\n", "* **Algorithm**: We will implement two ways of fitting logistic regression: gradient descent and Newton's method" ] }, { "cell_type": "markdown", "metadata": { "id": "oi2v2m5yCJE9" }, "source": [ "## The Bradley-Terry Model\n", "\n", "In the Bradley-Terry model, we give team $k$ a team effect $\\beta_k$. A higher $\\beta_k$, relative to the other teams' effects, means that team $k$ is a better team.\n", "The Bradley-Terry model formalizes this intuition by modeling the log odds of team $k$ beating team $k'$ as the difference in their team effects, $\\beta_k - \\beta_{k'}$.\n", "\n", "Let $i = 1,\\ldots, n$ index games, and let $h(i) \\in \\{1,\\ldots,K\\}$ and $a(i) \\in \\{1,\\ldots,K\\}$ denote the indices of the home and away teams, respectively.\n", "Let $Y_i \\in \\{0,1\\}$ indicate whether the home team won ($Y_i = 1$) or lost ($Y_i = 0$).\n", "Under the Bradley-Terry model,\n", "\\begin{equation*}\n", " Y_i \\sim \\mathrm{Bern}\\big(\\sigma(\\beta_{h(i)} - \\beta_{a(i)}) \\big),\n", "\\end{equation*}\n", "where $\\sigma(u) = 1 / (1 + e^{-u})$ is the sigmoid function. We can view this model as a logistic regression model with covariates $x_i \\in \\mathbb{R}^K$, where\n", "\\begin{align*}\n", "x_{i,k} &=\n", "\\begin{cases}\n", "+1 &\\text{if } h(i) = k \\\\\n", "-1 &\\text{if } a(i) = k \\\\\n", "0 &\\text{otherwise},\n", "\\end{cases}\n", "\\end{align*}\n", "and parameters $\\beta \\in \\mathbb{R}^K$, so that $x_i^\\top \\beta = \\beta_{h(i)} - \\beta_{a(i)}$."
] }, { "cell_type": "markdown", "metadata": { "id": "toIIF0ej-a7I" }, "source": [ "## Data\n", "\n", "We use the results of college football games in the fall 2023 season, which are available from the course GitHub page and loaded for you below.\n", "\n", "The data come as a list of the outcomes of individual games. You'll need to wrangle the data into a format that you can feed into the Bradley-Terry model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qvTw_232nr-v" }, "outputs": [], "source": [ "import torch\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "from tqdm import tqdm" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 638 }, "id": "WIYCdEBqnvJG", "outputId": "00e407b9-75af-46de-be25-bec38f06f02d" }, "outputs": [ { "data": { "text/html": [ "\n", "
\n",
"| | Id | Season | Week | Season Type | Start Date | Start Time Tbd | Completed | Neutral Site | Conference Game | Attendance | ... | Away Conference | Away Division | Away Points | Away Line Scores | Away Post Win Prob | Away Pregame Elo | Away Postgame Elo | Excitement Index | Highlights | Notes |\n",
"|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n",
"| 0 | 401550883 | 2023 | 1 | regular | 2023-08-26T17:00:00.000Z | False | True | False | False | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |\n",
"| 1 | 401525434 | 2023 | 1 | regular | 2023-08-26T18:30:00.000Z | False | True | True | False | 49000.0 | ... | American Athletic | fbs | 3.0 | NaN | 0.001042 | 1471.0 | 1385.0 | 1.346908 | NaN | NaN |\n",
"| 2 | 401540199 | 2023 | 1 | regular | 2023-08-26T19:30:00.000Z | False | True | True | False | NaN | ... | UAC | fcs | 7.0 | NaN | 0.025849 | NaN | NaN | 6.896909 | NaN | NaN |\n",
"| 3 | 401520145 | 2023 | 1 | regular | 2023-08-26T21:30:00.000Z | False | True | False | True | 17982.0 | ... | Conference USA | fbs | 14.0 | NaN | 0.591999 | 1369.0 | 1370.0 | 6.821333 | NaN | NaN |\n",
"| 4 | 401525450 | 2023 | 1 | regular | 2023-08-26T23:00:00.000Z | False | True | False | False | 15356.0 | ... | FBS Independents | fbs | 41.0 | NaN | 0.760751 | 1074.0 | 1122.0 | 5.311493 | NaN | NaN |\n",
"| 5 | 401532392 | 2023 | 1 | regular | 2023-08-26T23:00:00.000Z | False | True | False | False | 23867.0 | ... | Mid-American | fbs | 13.0 | NaN | 0.045531 | 1482.0 | 1473.0 | 6.547378 | NaN | NaN |\n",
"| 6 | 401540628 | 2023 | 1 | regular | 2023-08-26T23:00:00.000Z | False | True | False | False | NaN | ... | Patriot | fcs | 13.0 | NaN | 0.077483 | NaN | NaN | 5.608758 | NaN | NaN |\n",
"| 7 | 401520147 | 2023 | 1 | regular | 2023-08-26T23:30:00.000Z | False | True | False | False | 21407.0 | ... | Mountain West | fbs | 28.0 | NaN | 0.819154 | 1246.0 | 1241.0 | 5.282033 | NaN | NaN |\n",
"| 8 | 401539999 | 2023 | 1 | regular | 2023-08-26T23:30:00.000Z | False | True | True | False | NaN | ... | MEAC | fcs | 7.0 | NaN | 0.001097 | NaN | NaN | 3.122344 | NaN | NaN |\n",
"| 9 | 401523986 | 2023 | 1 | regular | 2023-08-27T00:00:00.000Z | False | True | False | False | 63411.0 | ... | Mountain West | fbs | 28.0 | NaN | 0.001769 | 1462.0 | 1412.0 | 1.698730 | NaN | NaN |\n",
"\n",
"10 rows × 33 columns\n", "