{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "IoYdLutn7dUa" }, "source": [ "# Latent Dirichlet Allocation\n", "\n", "Recall the following generative model for LDA. Suppose we have $K$ topics and $N$ documents.\n", "\n", "For each topic $k = 1,\\ldots,K$, draw a topic (a distribution over words)\n", "\n", "$$\\theta_k \\sim \\text{Dir}(\\phi)$$\n", "\n", "Then, for each document $n = 1,\\ldots, N$, draw topic proportions \n", "\n", "$$\\pi_n \\sim \\text{Dir}(\\alpha)$$\n", "\n", "Finally, for each word $d = 1,\\ldots,D$ in document $n$, first draw a topic assignment \n", "\n", "$$\n", "z_{n,d} \\mid \\pi_n \\sim \\text{Cat}(\\pi_n)\n", "$$\n", "\n", "and then draw a word\n", "\n", "$$\n", "x_{n,d} \\mid z_{n,d} \\sim \\text{Cat}(\\theta_{z_{n,d}})\n", "$$\n", "\n", "As mentioned in the lecture notes, while this formulation is easier to present, it's more efficient to represent the documents as sparse vectors of _word counts_, $\\mathbf{y}_n \\in \\mathbb{N}^V$ where $y_{n,v} = \\sum_{d=1}^D \\mathbb{I}[x_{n,d} = v]$. \n", "\n", "This notebook studies the Federalist Papers in their entirety. We've provided an $N \\times V$ data frame of the essays represented as word counts. The rows of the data frame correspond to the 85 individual essays and the columns correspond to the 5320 words in the vocabulary. We have already preprocessed the raw essays to remove very common and very infrequent words.\n", "\n", "Using this data, we will fit a topic model and do some analysis." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "GFN1-wvIue3K" }, "outputs": [], "source": [ "import torch\n", "from torch.distributions import Dirichlet, Multinomial, Categorical\n", "import pandas as pd \n", "\n", "import matplotlib.pyplot as plt\n", "from tqdm.auto import trange" ] }, { "cell_type": "markdown", "metadata": { "id": "DP4qzh_Kue3L" }, "source": [ "## Load the data\n", "\n", "We've already tokenized the text and created a bag-of-words representation of the corpus. 
We removed words from the vocabulary that occur in more than 95% of the essays or only appear in 1 essay." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "2kDPPoqpvQ-x", "outputId": "4a14bcf1-05a1-441b-838b-68c69cb8a241", "tags": [ "hide-output" ] }, "outputs": [], "source": [ "# Download the data\n", "!wget -nc https://www.dropbox.com/s/p6jb2cw5w5626pl/tokenized_fed.csv\n", "!wget -nc https://www.dropbox.com/s/ftedra0jyk1j3hx/authorship.csv" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 487 }, "id": "Ig2BB1-Eue3M", "outputId": "8d8cf501-7c6e-42ce-875d-0e04e337c164" }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " " ], "text/plain": [ " unequivocal experience inefficacy subsisting federal called \\\n", "0 1.0 1.0 1.0 1.0 1.0 1.0 \n", "1 0.0 2.0 0.0 0.0 2.0 1.0 \n", "2 0.0 1.0 0.0 0.0 2.0 0.0 \n", "3 0.0 2.0 0.0 0.0 0.0 1.0 \n", "4 0.0 1.0 0.0 0.0 0.0 0.0 \n", ".. ... ... ... ... ... ... \n", "80 0.0 0.0 0.0 0.0 6.0 1.0 \n", "81 0.0 0.0 0.0 0.0 12.0 0.0 \n", "82 0.0 1.0 0.0 0.0 7.0 2.0 \n", "83 0.0 0.0 0.0 0.0 2.0 1.0 \n", "84 0.0 2.0 0.0 0.0 0.0 1.0 \n", "\n", " deliberate new constitution united ... chancery jurisprudence \\\n", "0 1.0 5.0 7.0 1.0 ... 0.0 0.0 \n", "1 0.0 2.0 0.0 3.0 ... 0.0 0.0 \n", "2 0.0 1.0 0.0 4.0 ... 0.0 0.0 \n", "3 0.0 0.0 0.0 2.0 ... 0.0 0.0 \n", "4 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n", ".. ... ... ... ... ... ... ... \n", "80 1.0 8.0 12.0 8.0 ... 1.0 1.0 \n", "81 0.0 2.0 4.0 5.0 ... 0.0 0.0 \n", "82 2.0 9.0 13.0 6.0 ... 7.0 1.0 \n", "83 0.0 9.0 26.0 12.0 ... 0.0 0.0 \n", "84 0.0 4.0 13.0 2.0 ... 0.0 0.0 \n", "\n", " reexamination writ commonlaw intent refutation habeas corpus clerks \n", "0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", "1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", "2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", "3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", "4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", ".. ... ... ... ... ... ... ... ... 
\n", "80 5.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 \n", "81 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 \n", "82 1.0 0.0 5.0 2.0 2.0 1.0 1.0 1.0 \n", "83 0.0 2.0 0.0 0.0 0.0 3.0 3.0 1.0 \n", "84 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 \n", "\n", "[85 rows x 5320 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the data\n", "df = pd.read_csv('tokenized_fed.csv', index_col = 0)\n", "docs = torch.tensor(df.to_numpy()).type(torch.int)\n", "vocab = df.columns.to_list()\n", "df" ] }, { "cell_type": "markdown", "metadata": { "id": "7Nyg-s9Yx1JN" }, "source": [ "## Write some helper functions for Dirichlet distributions\n", "\n", "Specifically, we need the expected log of a Dirichlet random vector and the KL divergence between two Dirichlet distributions." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "UoqrOOfvue3N" }, "outputs": [], "source": [ "def dirichlet_expected_log(dirichlet):\n", "    \"\"\"Helper function to compute the expected log under a Dirichlet distribution.\n", "\n", "    Args:\n", "        dirichlet: A torch.distributions.Dirichlet object with a batch shape of\n", "            (...,) and an event shape of (K,).\n", "\n", "    Returns:\n", "        (...,K) tensor of expected logs, E[\\log \\pi], under the Dirichlet.\n", "    \"\"\"\n", "    concentration = dirichlet.concentration\n", "    return torch.special.digamma(concentration) - \\\n", "        torch.special.digamma(concentration.sum(dim=-1, keepdims=True))\n", "\n", "\n", "def dirichlet_log_normalizer(concentration):\n", "    \"\"\"Compute the log normalizing constant of a Dirichlet distribution with\n", "    the specified concentration.\n", "\n", "    Args:\n", "        concentration: (...,K) tensor of concentration parameters\n", "\n", "    Returns:\n", "        (...,) batch of log normalizers\n", "    \"\"\"\n", "    return torch.special.gammaln(concentration).sum(dim=-1) - \\\n", "        torch.special.gammaln(concentration.sum(dim=-1))\n", "\n", "\n", "def dirichlet_kl(q, p):\n", "    \"\"\"Compute the KL divergence between two Dirichlet distributions.\n", "\n", "    Args:\n", "        q: A torch.distributions.Dirichlet object\n", "        p: A torch.distributions.Dirichlet object over the same domain\n", "\n", "    Returns:\n", "        A (batch of) KL divergence(s) between q and p.\n", "    \"\"\"\n", "    kl = -dirichlet_log_normalizer(q.concentration)\n", "    kl += dirichlet_log_normalizer(p.concentration)\n", "    kl += torch.sum((q.concentration - p.concentration) * \\\n", "                    dirichlet_expected_log(q), dim=-1)\n", "    return kl" ] }, { "cell_type": "markdown", "metadata": { "id": "-bCID4jque3N" }, "source": [ "## Implement Coordinate Ascent Variational Inference (CAVI)\n", "\n", "_Note: The `torch.distributions.Multinomial` object doesn't work well when you have a batch with different numbers of counts. We hijack this object by not giving it a count so that it defaults to 1, which is equivalent to a categorical distribution. Then we multiply by the total counts to get the necessary expectations under the multinomial posterior._" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "x-TIiIlxue3O" }, "outputs": [], "source": [ "def cavi(docs,\n", "         num_topics=10,\n", "         num_iters=200,\n", "         tol=1e-5,\n", "         alpha=20.0,\n", "         phi=1.0,\n", "         seed=305 + ord('c'),\n", "         ):\n", "    \"\"\"Run coordinate ascent VI for LDA.\n", "\n", "    Args:\n", "        docs: (N, V) tensor of word counts, one row per document.\n", "        num_topics: number of topics, K.\n", "        num_iters: maximum number of CAVI iterations.\n", "        tol: stop once the per-word ELBO improves by less than this amount.\n", "        alpha: concentration of the symmetric Dirichlet prior on topic proportions.\n", "        phi: concentration of the symmetric Dirichlet prior on topics.\n", "        seed: random seed for the initialization.\n", "\n", "    Returns:\n", "        A tensor of per-word ELBO values and the tuple (q_c, q_pi, q_theta).\n", "    \"\"\"\n", "    docs = docs.type(torch.float)\n", "    N, V = docs.shape\n", "    K = num_topics # short hand\n", "\n", "    def cavi_step(q_c, q_pi, q_theta):\n", "        \"\"\"One step of CAVI.\n", "        \"\"\"\n", "        # Update the topic assignment counts\n", "        E_logpi = dirichlet_expected_log(q_pi)\n", "        E_logtheta = dirichlet_expected_log(q_theta)\n", "        q_c = Multinomial(logits=E_logpi[:, None, :] + E_logtheta.T)\n", "\n", "        # Compute the mean of q(c) since we'll use it twice below\n", "        E_c = docs.unsqueeze(2) * q_c.mean\n", "\n", "        # Update the topic proportions\n", "        q_pi = Dirichlet(alpha + E_c.sum(axis=1))\n", "\n", "        # Update the topic word probabilities\n", "        q_theta = Dirichlet(phi + E_c.sum(axis=0).T)\n", "\n", "        return q_c, q_pi, q_theta\n", "\n", "    def elbo(q_c, q_pi, q_theta):\n", "        \"\"\"Compute the evidence lower bound.\n", "        \"\"\"\n", "        elbo = 0\n", "\n", "        # KL to prior\n", "        elbo -= dirichlet_kl(q_pi, Dirichlet(alpha * torch.ones(K))).sum()\n", "        elbo -= dirichlet_kl(q_theta, Dirichlet(phi * torch.ones(V))).sum()\n", "\n", "        # Entropy of q(c) [a little different from multinomial entropy]\n", "        E_c = docs.unsqueeze(2) * q_c.mean\n", "        elbo -= torch.sum(E_c * torch.log(q_c.probs))\n", "\n", "        # Expected log-likelihood terms: E[log p(c | \\pi)] and E[log p(y | c, \\theta)]\n", "        E_logpi = dirichlet_expected_log(q_pi)\n", "        E_logtheta = dirichlet_expected_log(q_theta)\n", "        elbo += torch.sum(E_c * E_logpi[:, None, :])\n", "        elbo += torch.sum(E_c * E_logtheta.T)\n", "\n", "        return elbo / torch.sum(docs)\n", "\n", "    # Initialize the topics by randomly clustering the documents\n", "    # and using their word counts\n", "    torch.manual_seed(seed)\n", "    clusters = Categorical(logits=torch.zeros(K)).sample((N,))\n", "    q_pi = Dirichlet(alpha * torch.ones((N, K)))\n", "    q_theta = Dirichlet(phi + torch.row_stack([docs[clusters == k].sum(axis=0)\n", "                                               for k in range(K)]))\n", "    q_c = Multinomial(logits=torch.zeros((N, V, K)))\n", "\n", "    # Run CAVI\n", "    elbos = [elbo(q_c, q_pi, q_theta)]\n", "    for itr in trange(num_iters):\n", "        q_c, q_pi, q_theta = cavi_step(q_c, q_pi, q_theta)\n", "        elbos.append(elbo(q_c, q_pi, q_theta))\n", "\n", "        if elbos[-1] - elbos[-2] < -1e-4:\n", "            raise Exception(\"ELBO is going down!\")\n", "        elif elbos[-1] - elbos[-2] < tol:\n", "            print(\"Converged!\")\n", "            break\n", "\n", "    return torch.tensor(elbos), (q_c, q_pi, q_theta)\n", "    " ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 347, "referenced_widgets": [ "678444446c6b403fa165ecf9abbb7624", "eb674660b30041baa6de6eddc6926cf6", "10bfdab627d24288bffa6b00d1214955", "02f48be17ec246409841be7c4429dff9", "b14a145cf2154b1a9366a34350b0d2c6", 
"eb963e8602844150b388f15247e9c90d", "50f0438726d74985bb222d18e2844996", "1b30e510e3de423d9427e14306e4a54f", "d54823f5694b40f69c97eb47f948d916", "4e68636d615c443f9be9a18071ed02d8", "6189223a819546c9bcf080bf463f9fe8" ] }, "id": "7iZzXB-9ue3O", "outputId": "34da3019-f44d-45c6-fe68-1554460b3d46" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "678444446c6b403fa165ecf9abbb7624", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/200 [00:00" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "elbos, (q_c, q_pi, q_theta) = cavi(docs)\n", "\n", "plt.plot(elbos)\n", "plt.xlabel(\"Iteration\")\n", "plt.ylabel(\"ELBO per word\")" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 301 }, "id": "VKgIDiPBue3O", "outputId": "f38e196a-fb59-4b0f-f3e3-e188f37601dc" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "topic_usage = q_pi.mean\n", "plt.imshow(topic_usage, aspect=\"auto\", vmin=0, vmax=1)\n", "plt.xticks(torch.arange(topic_usage.shape[1]))\n", "plt.xlabel(\"topic\")\n", "plt.ylabel(\"document\")\n", "plt.colorbar()" ] }, { "cell_type": "code", "execution_count": 72, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "GICzRqWWue3O", "outputId": "589371f5-6eaa-439f-88e9-9a7f23b5f149" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "topic 0 usage : tensor([0.3847, 0.0888, 0.0834, 0.0931, 0.0613, 0.0777, 0.0585, 0.0508, 0.0508,\n", " 0.0508])\n", "\t people : tensor(0.0149)\n", "\t federal : tensor(0.0077)\n", "\t national : tensor(0.0070)\n", "\t union : tensor(0.0062)\n", "\t representatives : tensor(0.0058)\n", "\n", "topic 1 usage : tensor([0.4278, 0.0656, 0.0646, 0.0685, 0.0777, 0.1051, 0.0534, 0.0457, 0.0457,\n", " 0.0457])\n", "\t executive : tensor(0.0141)\n", "\t legislative : tensor(0.0097)\n", "\t senate : tensor(0.0074)\n", "\t body : tensor(0.0074)\n", "\t president : tensor(0.0068)\n", "\n", "topic 2 usage : tensor([0.2959, 0.1060, 0.1093, 0.0754, 0.0586, 0.1825, 0.0463, 0.0420, 0.0420,\n", " 0.0420])\n", "\t cases : tensor(0.0093)\n", "\t courts : tensor(0.0088)\n", "\t jurisdiction : tensor(0.0065)\n", "\t trial : tensor(0.0060)\n", "\t court : tensor(0.0059)\n", "\n", "topic 3 usage : tensor([0.1641, 0.0610, 0.0538, 0.0989, 0.0484, 0.4167, 0.0448, 0.0374, 0.0374,\n", " 0.0374])\n", "\t peace : tensor(0.0058)\n", "\t military : tensor(0.0055)\n", "\t time : tensor(0.0049)\n", "\t militia : tensor(0.0047)\n", "\t war : tensor(0.0045)\n", "\n", "topic 4 usage : tensor([0.1576, 0.0593, 0.0542, 0.0715, 0.0478, 0.4437, 0.0471, 0.0396, 0.0396,\n", " 0.0396])\n", "\t powers : tensor(0.0105)\n", "\t constitution : tensor(0.0075)\n", "\t congress : tensor(0.0068)\n", "\t authority : tensor(0.0065)\n", "\t confederation : 
tensor(0.0062)\n", "\n", "topic 5 usage : tensor([0.1520, 0.0574, 0.0452, 0.0517, 0.0385, 0.5027, 0.0535, 0.0330, 0.0330,\n", " 0.0330])\n", "\t nations : tensor(0.0073)\n", "\t us : tensor(0.0062)\n", "\t commerce : tensor(0.0044)\n", "\t war : tensor(0.0041)\n", "\t foreign : tensor(0.0034)\n", "\n", "topic 6 usage : tensor([0.1873, 0.0498, 0.0726, 0.0520, 0.0590, 0.4437, 0.0382, 0.0325, 0.0325,\n", " 0.0325])\n", "\t confederacy : tensor(0.0045)\n", "\t members : tensor(0.0044)\n", "\t empire : tensor(0.0034)\n", "\t cities : tensor(0.0027)\n", "\t among : tensor(0.0021)\n", "\n", "topic 7 usage : tensor([0.1334, 0.0510, 0.0439, 0.5379, 0.0393, 0.0635, 0.0378, 0.0310, 0.0310,\n", " 0.0310])\n", "\t kind : tensor(0.0003)\n", "\t reasons : tensor(0.0003)\n", "\t equally : tensor(0.0003)\n", "\t union : tensor(0.0003)\n", "\t connected : tensor(0.0003)\n", "\n", "topic 8 usage : tensor([0.3348, 0.0920, 0.0871, 0.0862, 0.0592, 0.0634, 0.1420, 0.0451, 0.0451,\n", " 0.0451])\n", "\t kind : tensor(0.0003)\n", "\t reasons : tensor(0.0003)\n", "\t equally : tensor(0.0003)\n", "\t union : tensor(0.0003)\n", "\t connected : tensor(0.0003)\n", "\n", "topic 9 usage : tensor([0.7449, 0.0341, 0.0359, 0.0272, 0.0263, 0.0312, 0.0260, 0.0248, 0.0248,\n", " 0.0248])\n", "\t kind : tensor(0.0003)\n", "\t reasons : tensor(0.0003)\n", "\t equally : tensor(0.0003)\n", "\t union : tensor(0.0003)\n", "\t connected : tensor(0.0003)\n", "\n" ] } ], "source": [ "# Analyze the topics\n", "usage = q_pi.mean\n", "topics = q_theta.mean\n", "\n", "# Sort the topics by usage\n", "topic_perm = torch.argsort(usage.sum(0), descending=True)\n", "usage = usage[:, topic_perm]\n", "topics = topics[topic_perm]\n", "\n", "keywords = []\n", "for k, topic in enumerate(topics):\n", "    # if torch.allclose(usage[k], usage.min()):\n", "    #     continue\n", "    # Note: usage has shape (N, K), so usage[k] is row k of that matrix,\n", "    # i.e. the (sorted) topic proportions of document k.\n", "    print(\"topic \", k, \"usage : \", usage[k])\n", "\n", "    # Print the five most probable words under this topic\n", "    inds = torch.argsort(topic, descending=True)\n", "    keywords.append(vocab[inds[0]])\n", "    for i, ind in enumerate(inds[:5]):\n", "        print(\"\\t\", vocab[ind], \":\", topic[ind])\n", "    print(\"\")" ] }, { "cell_type": "markdown", "metadata": { "id": "a4BTtABuue3P" }, "source": [ "## Exploring topics by author usage\n", "\n", "Using the model, plot the average topic usage for each author." ] }, { "cell_type": "code", "execution_count": 73, "metadata": { "id": "MC1TVd-hue3P" }, "outputs": [], "source": [ "# Load the authorship labels and convert them to a boolean indicator tensor\n", "# with one column per author name\n", "authors_df = pd.read_csv('authorship.csv', index_col = 0)\n", "to_tensor = lambda df: torch.tensor(df.to_numpy())\n", "author_names = [\"HAMILTON\", \"JAY\", \"MADISON\", \"DISPUTED\"]\n", "authors = torch.column_stack([to_tensor(authors_df == name)\n", "                              for name in author_names])" ] }, { "cell_type": "code", "execution_count": 74, "metadata": { "id": "z2mzQVVD43c5" }, "outputs": [], "source": [ "# Average the per-document topic usage over each author's essays\n", "author_usage = torch.row_stack([usage[authors[:, i]].mean(dim=0)\n", "                                for i in range(len(author_names))])" ] }, { "cell_type": "code", "execution_count": 75, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 347 }, "id": "wG4j2LI05JNW", "outputId": "270bbce3-2432-4561-94c7-f4a01131d83a" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig, axs = plt.subplots(1, 4, figsize=(12, 4), sharey=True)\n", "# Use a separate loop variable so the `usage` matrix is not overwritten\n", "for ax, name, per_author_usage in zip(axs, author_names, author_usage):\n", "    ax.bar(torch.arange(10), per_author_usage)\n", "    ax.set_title(name)\n", "    ax.set_xlabel(\"topic\")\n", "    ax.set_xticks(torch.arange(10))\n", "    ax.set_xticklabels(keywords, rotation=90)\n", "    # ax.set_ylabel(\"usage\")" ] }, { "cell_type": "markdown", "metadata": { "id": "9cDAAFhl7bto" }, "source": [ "## Conclusion\n", "\n", "This notebook demonstrates a very naive implementation of CAVI for LDA and applies it to the Federalist Papers, a collection of essays by Hamilton, Jay, and Madison that argued for ratifying the US Constitution. \n", "\n", "There are many ways this implementation could be improved. For example,\n", "- While it does work with word counts, it does not take advantage of the sparsity of the data matrix. It explicitly instantiates parameters for the posterior over $c_{n,v}$ even when $y_{n,v}=0$. We could improve performance by leveraging this sparsity.\n", "\n", "- It operates in \"batch mode,\" which is fine for small datasets like this one, but can become intractable for massive corpora, like all the pages of Wikipedia. For those regimes, it is better to work with _stochastic variational inference_ (Hoffman et al., 2011)." 
] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "Lecture 9: Latent Dirichlet Allocation and CAVI", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.11" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {}, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 4 }