{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "IoYdLutn7dUa" }, "source": [ "# Latent Dirichlet Allocation\n", "\n", "Recall the following generatize model for LDA. Suppose we have $K$ topics and $N$ documents.\n", "\n", "For each topic $k =1,\\ldots,K$, draw a topic \n", "\n", "$$\\theta_k \\sim \\text{Dir}(\\phi)$$\n", "\n", "Then, for each document $n = 1,\\ldots, N$, draw topic proportions \n", "\n", "$$\\pi_n \\sim \\text{Dir}(\\alpha)$$\n", "\n", "Finally, for each word $l$ in document $n$, first draw a topic assignment \n", "\n", "$$\n", "z_{n,d} \\mid \\pi_n \\sim \\text{Cat}(\\pi_n)\n", "$$\n", "\n", "and draw a word\n", "\n", "$$\n", "x_{n,d} \\mid z_{n,d} \\sim \\text{Cat}(\\theta_{z_{n,d}})\n", "$$\n", "\n", "As mentioned in the lecture notes, while this formulation is easier to present, it's more efficient to represent the documents as sparse vectors of _word counts_, $\\mathbf{y}_n \\in \\mathbb{N}^V$ where $y_{n,v} = \\sum_{d=1}^D \\mathbb{I}[x_{n,d} = v]$. \n", "\n", "This notebook studies Federalist papers in their entirety. We've provided a $N \\times V$ dataframe of the essays represented as word counts. The rows of the data frame correspond to the 85 individual essays and the columns correspond to the 5320 words in the vocabulary. We have already preprocessed the raw essays to remove very common and very infrequent words.\n", "\n", "Using this data, we will fit a topic model and do some analysis." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "GFN1-wvIue3K" }, "outputs": [], "source": [ "import torch\n", "from torch.distributions import Dirichlet, Multinomial, Categorical\n", "import pandas as pd \n", "\n", "import matplotlib.pyplot as plt\n", "from tqdm.auto import trange" ] }, { "cell_type": "markdown", "metadata": { "id": "DP4qzh_Kue3L" }, "source": [ "## Load the data\n", "\n", "We've already tokenized the text and created a bag-of-words representation of the corpus. We removed words from the vocabulary that occur in more than 95% of the essays or only appear in 1 essay." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "2kDPPoqpvQ-x", "outputId": "4a14bcf1-05a1-441b-838b-68c69cb8a241", "tags": [ "hide-output" ] }, "outputs": [], "source": [ "# Download the data\n", "!wget -nc https://www.dropbox.com/s/p6jb2cw5w5626pl/tokenized_fed.csv\n", "!wget -nc https://www.dropbox.com/s/ftedra0jyk1j3hx/authorship.csv" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 487 }, "id": "Ig2BB1-Eue3M", "outputId": "8d8cf501-7c6e-42ce-875d-0e04e337c164" }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " | unequivocal | \n", "experience | \n", "inefficacy | \n", "subsisting | \n", "federal | \n", "called | \n", "deliberate | \n", "new | \n", "constitution | \n", "united | \n", "... | \n", "chancery | \n", "jurisprudence | \n", "reexamination | \n", "writ | \n", "commonlaw | \n", "intent | \n", "refutation | \n", "habeas | \n", "corpus | \n", "clerks | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "5.0 | \n", "7.0 | \n", "1.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
1 | \n", "0.0 | \n", "2.0 | \n", "0.0 | \n", "0.0 | \n", "2.0 | \n", "1.0 | \n", "0.0 | \n", "2.0 | \n", "0.0 | \n", "3.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
2 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "2.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "4.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
3 | \n", "0.0 | \n", "2.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "2.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
4 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
80 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "6.0 | \n", "1.0 | \n", "1.0 | \n", "8.0 | \n", "12.0 | \n", "8.0 | \n", "... | \n", "1.0 | \n", "1.0 | \n", "5.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
81 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "12.0 | \n", "0.0 | \n", "0.0 | \n", "2.0 | \n", "4.0 | \n", "5.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
82 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "7.0 | \n", "2.0 | \n", "2.0 | \n", "9.0 | \n", "13.0 | \n", "6.0 | \n", "... | \n", "7.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "5.0 | \n", "2.0 | \n", "2.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "
83 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "2.0 | \n", "1.0 | \n", "0.0 | \n", "9.0 | \n", "26.0 | \n", "12.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "2.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "3.0 | \n", "3.0 | \n", "1.0 | \n", "
84 | \n", "0.0 | \n", "2.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "4.0 | \n", "13.0 | \n", "2.0 | \n", "... | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
85 rows × 5320 columns
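The dataframe shown above has already had this filtering applied. For reference, the cell below is a minimal sketch of how the document-frequency rule described under "Load the data" could be implemented with pandas; the variable names and the `index_col=0` argument are illustrative assumptions, not the exact preprocessing script that produced the CSV.

```python
import pandas as pd

# Load the provided word-count matrix: rows are the 85 essays, columns are vocabulary words.
# (Assumes the first column of the CSV stores the row index.)
counts = pd.read_csv("tokenized_fed.csv", index_col=0)

# Illustrative document-frequency filter matching the rule described above:
# drop words that occur in more than 95% of the essays or in only one essay.
doc_freq = (counts > 0).sum(axis=0)              # number of essays containing each word
num_docs = counts.shape[0]
keep = (doc_freq <= 0.95 * num_docs) & (doc_freq > 1)
filtered = counts.loc[:, keep]

print(counts.shape, "->", filtered.shape)
```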
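Before fitting the model, it helps to connect the generative story at the top of the notebook to code. The cell below is a minimal sketch of ancestral sampling from the LDA model with `torch.distributions`; the sizes `K`, `V`, `N`, `doc_len` and the symmetric concentrations are placeholder assumptions for illustration, not values used in the analysis.

```python
import torch
from torch.distributions import Dirichlet, Categorical

torch.manual_seed(0)
K, V, N, doc_len = 3, 20, 5, 50          # illustrative sizes, not the Federalist dimensions

phi = 0.1 * torch.ones(V)                # concentration for the topics theta_k
alpha = 1.0 * torch.ones(K)              # concentration for the topic proportions pi_n

theta = Dirichlet(phi).sample((K,))      # theta_k ~ Dir(phi): K distributions over V words
pi = Dirichlet(alpha).sample((N,))       # pi_n ~ Dir(alpha): per-document topic proportions

X = torch.zeros(N, doc_len, dtype=torch.long)      # token-level words x_{n,d}
Y = torch.zeros(N, V)                              # word counts y_{n,v}
for n in range(N):
    z = Categorical(pi[n]).sample((doc_len,))      # z_{n,d} ~ Cat(pi_n)
    x = Categorical(theta[z]).sample()             # x_{n,d} ~ Cat(theta_{z_{n,d}})
    X[n] = x
    Y[n] = torch.bincount(x, minlength=V).float()  # y_{n,v} = sum_d I[x_{n,d} = v]
```

The count matrix `Y` plays the same role as the dataframe above: it is the $N \times V$ word-count representation that the rest of the notebook works with.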
\n", "