Calcium Deconvolution

In this lab you’ll write your own code for demixing and deconvolving calcium imaging videos. Demixing refers to the problem of identifying potentially overlapping neurons in the video and separating their fluorescence traces. Deconvolving refers to taking those traces and finding the times of spiking activity, which produce exponentially decaying transients in fluorescence. We’ll frame the combined problem as a constrained and (partially) non-negative matrix factorization, inspired by the CNMF model of Pnevmatikakis et al. (2016), which is implemented in CaImAn (Giovannucci et al., 2019). More details and further references are in the course notes. We’ll use CVXPY to solve the convex optimization problems at the heart of this approach.
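One way to write the model described above, as a sketch in our own notation: the symbols are chosen to match the `U`, `C`, `u0`, `c0` names used in `plot_problem_3`, and we assume an AR(1) model for the calcium decay with a rank-one background. CNMF constrains footprints and spike amplitudes to be non-negative; exactly which factors are constrained in this lab is an assumption here, and the course notes give the precise formulation. Here P = height × width pixels, N neurons, T = num_frames frames, and τ = `GCAMP_TIME_CONST_SEC`.

```latex
\mathbf{Y} \approx \mathbf{U}^\top \mathbf{C} + \mathbf{u}_0 \mathbf{c}_0^\top,
\qquad \mathbf{Y} \in \mathbb{R}^{P \times T},\;
\mathbf{U} \in \mathbb{R}_{+}^{N \times P},\;
\mathbf{C} \in \mathbb{R}^{N \times T},\;
\mathbf{u}_0 \in \mathbb{R}^{P},\;
\mathbf{c}_0 \in \mathbb{R}^{T}

c_{n,t} = \gamma \, c_{n,t-1} + s_{n,t}, \qquad s_{n,t} \ge 0,
\qquad \gamma = \exp\!\left(-\frac{1}{\mathrm{FPS} \cdot \tau}\right)
```

Demixing estimates the footprints U and traces C; deconvolution recovers the non-negative spike amplitudes s from each trace.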

References

  • Pnevmatikakis, Eftychios A., Daniel Soudry, Yuanjun Gao, Timothy A. Machado, Josh Merel, David Pfau, Thomas Reardon, et al. 2016. “Simultaneous Denoising, Deconvolution, and Demixing of Calcium Imaging Data.” Neuron 89 (2): 285–99.

  • Giovannucci, Andrea, Johannes Friedrich, Pat Gunn, Jérémie Kalfon, Brandon L. Brown, Sue Ann Koay, Jiannis Taxidis, et al. 2019. “CaImAn an Open Source Tool for Scalable Calcium Imaging Data Analysis.” eLife.

Setup

%%capture
try:
    import jaxtyping
except ImportError:
    !pip install jaxtyping

from typing import Optional, Any, Dict, Union, Tuple
from jaxtyping import Float, Int, Array
from torch import Tensor
import torch
import torch.nn.functional as F
import torch.distributions as dist
import numpy as np

# We'll use a few SciPy functions too
import scipy.sparse
from scipy.signal import butter, sosfilt
from scipy.ndimage import gaussian_filter
from skimage.feature import peak_local_max

# We'll use CVXPY to solve convex optimization problems
import cvxpy as cvx

# Plotting stuff
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
from matplotlib.patches import Circle
import seaborn as sns

# Helpers
from tqdm.auto import trange
import warnings
device = "cuda" if torch.cuda.is_available() else "cpu"

Download example data

This demo data was contributed by Sue Ann Koay and David Tank (Princeton University). It is also used in the CaImAn demo notebook. We used CaImAn and NoRMCorre to correct for motion artifacts.

!wget -nc https://github.com/slinderman/ml4nd/raw/refs/heads/main/data/02_calcium_imaging/lab02_data.pt
# Load the data
data = torch.load("lab02_data.pt")
height, width, num_frames = data.shape

# Set some constants
FPS = 30                        # frames per second in the movie
NEURON_WIDTH = 10               # approximate width (in pixels) of a neuron
GCAMP_TIME_CONST_SEC = 0.300    # reasonable guess for calcium decay time const.
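The decay time constant determines how quickly each transient fades from one frame to the next. Assuming an AR(1) model of the calcium dynamics (CNMF-style models use a low-order autoregressive process), the trace shrinks by a factor γ = exp(−1/(FPS·τ)) per frame; with FPS = 30 and τ = 0.3 s it falls by a factor of e after 9 frames. The sketch below uses hypothetical helper names `ar1_trace` and `ar1_inverse` (not part of the lab) to build a noiseless trace from spikes and to invert it exactly.

```python
import numpy as np

# Same values as the constants above, repeated so this sketch runs on its own
FPS = 30
GCAMP_TIME_CONST_SEC = 0.300

# Per-frame decay factor: exp(-(1 / FPS) / tau) = exp(-1 / (FPS * tau))
gamma = np.exp(-1.0 / (FPS * GCAMP_TIME_CONST_SEC))


def ar1_trace(spikes: np.ndarray, gamma: float) -> np.ndarray:
    """Noiseless calcium trace c_t = gamma * c_{t-1} + s_t, with c_{-1} = 0."""
    trace = np.zeros(len(spikes))
    prev = 0.0
    for t, s in enumerate(spikes):
        prev = gamma * prev + s
        trace[t] = prev
    return trace


def ar1_inverse(trace: np.ndarray, gamma: float) -> np.ndarray:
    """Recover spikes exactly in the noiseless case: s_t = c_t - gamma * c_{t-1}."""
    trace = np.asarray(trace, dtype=float)
    spikes = trace.copy()
    spikes[1:] -= gamma * trace[:-1]
    return spikes


# Usage: three spikes on a 60-frame timeline
spikes = np.zeros(60)
spikes[[5, 20, 22]] = [1.0, 0.5, 2.0]
trace = ar1_trace(spikes, gamma)
print(round(float(gamma), 3))  # 0.895
assert np.allclose(ar1_inverse(trace, gamma), spikes)
```

With noisy fluorescence, this simple difference is no longer sparse or non-negative, which is one motivation for posing deconvolution as a constrained convex optimization problem instead.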

Helper functions for plotting

#@title Helper functions for movies and plotting { display-mode: "form" }
from matplotlib import animation
from IPython.display import HTML
from tempfile import NamedTemporaryFile
import base64

# Set some plotting defaults
sns.set_context("talk")

# initialize a color palette for plotting
palette = sns.xkcd_palette(["windows blue",
                            "red",
                            "medium green",
                            "dusty purple",
                            "orange",
                            "amber",
                            "clay",
                            "pink",
                            "greyish"])

_VIDEO_TAG = """<video controls>
 <source src="data:video/x-m4v;base64,{0}" type="video/mp4">
 Your browser does not support the video tag.
</video>"""

def _anim_to_html(anim: animation.FuncAnimation,
                  fps: int = 20) -> str:
    """
    Convert a Matplotlib animation object to an HTML video snippet
    with an embedded base64-encoded video.

    Parameters:
    -----------
    anim : matplotlib.animation.FuncAnimation
        The Matplotlib animation object to encode.
    fps : int, optional
        Frames per second to use when encoding the animation. Default is 20.

    Returns:
    --------
    str
        An HTML string containing the video tag with the encoded video.
    """
    if not hasattr(anim, '_encoded_video'):
        with NamedTemporaryFile(suffix='.mp4') as f:
            anim.save(f.name, fps=fps, extra_args=['-vcodec', 'libx264'])
            # Read the encoded file back in, closing the handle when done
            with open(f.name, "rb") as video_file:
                video = video_file.read()
        anim._encoded_video = base64.b64encode(video)

    return _VIDEO_TAG.format(anim._encoded_video.decode('ascii'))

def _display_animation(anim: animation.FuncAnimation,
                       fps: int = 30,
                       start: int = 0,
                       stop: Optional[int] = None) -> HTML:
    """
    Display a Matplotlib animation by converting it to an HTML video snippet
    and closing its figure.

    Parameters:
    -----------
    anim : matplotlib.animation.FuncAnimation
        The animation object to display.
    fps : int, optional
        Frames per second for the displayed animation. Default is 30.
    start : int, optional
        Starting frame index (currently not used #TODO). Default is 0.
    stop : Optional[int], optional
        Ending frame index (currently not used #TODO). Default is None.

    Returns:
    --------
    IPython.display.HTML
        An HTML snippet containing the video for display.
    """
    plt.close(anim._fig)
    return HTML(_anim_to_html(anim, fps=fps))

def play(movie: Int[Tensor, "height width num_frames"],
         fps: int = FPS,
         speedup: int = 1,
         fig_height: int = 6):
    """
    Create an animation from a movie tensor and return
    an HTML snippet for embedding.

    Parameters:
    -----------
    movie : torch.Tensor
        A 3D tensor with shape (height, width, num_frames) for the movie frames.
    fps : int, optional
        Frames per second of the movie. Default is the global FPS constant.
    speedup : int, optional
        Factor to speed up the animation. Default is 1 (real-time).
    fig_height : int, optional
        Height of the figure (in inches) for the plot. Default is 6.

    Returns:
    --------
    IPython.display.HTML
        An HTML video snippet displaying the animation.

    Notes:
    ------
    This function uses Matplotlib's FuncAnimation to create the animation
    and embeds it as an HTML video.
    """
    # First set up the figure, the axis, and the plot element we want to animate
    Py, Px, T = movie.shape
    fig, ax = plt.subplots(1, 1, figsize=(fig_height * Px/Py, fig_height))
    im = plt.imshow(movie[..., 0], interpolation='None', cmap=plt.cm.gray)
    tx = plt.text(0.75, 0.05, 't={:.3f}s'.format(0),
                  color='white',
                  fontdict=dict(size=12),
                  horizontalalignment='left',
                  verticalalignment='center',
                  transform=ax.transAxes)
    plt.axis('off')

    def animate(i: int):
        """
        Update function for FuncAnimation.

        Parameters:
        -----------
        i : int
            The frame index.

        Returns:
        --------
        tuple
            A tuple containing the modified image;
            used for efficient animation updates.
        """
        im.set_data(movie[..., i * speedup])
        tx.set_text("t={:.3f}s".format(i * speedup / fps))
        return im,

    # call the animator.  blit=True means only re-draw the parts that changed.
    anim = animation.FuncAnimation(fig, animate,
                                   frames=T // speedup,
                                   interval=1,
                                   blit=True)
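    plt.close(fig)

    # return an HTML video snippet; encoding at `fps` while showing every
    # `speedup`-th frame plays the movie `speedup` times faster than real time
    print("Preparing animation. This may take a minute...")
    return HTML(_anim_to_html(anim, fps=fps))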

def plot_problem_1d(local_correlations: Float[torch.Tensor, "height width"],
                    filtered_correlations: Float[torch.Tensor, "height width"],
                    peaks: Int[torch.Tensor, "num_peaks 2"]) -> None:
    """
    Plot local correlations, filtered correlations,
    and candidate neurons side-by-side.

    Parameters
    ----------
    local_correlations : torch.Tensor
        2D tensor representing local correlation values,
        with shape (height, width).
    filtered_correlations : torch.Tensor
        2D tensor representing filtered correlation values,
        with shape (height, width).
    peaks : torch.Tensor
        Tensor of shape (num_peaks, 2) where each row is [y, x] coordinates
        for peak locations.

    Returns
    -------
    None

    Notes
    -----
    The function displays three panels and overlays circles and labels
    on the candidate neuron panel.
    """
    def _plot_panel(ax: plt.Axes,
                    im: Float[torch.Tensor, "height width"],
                    title: str) -> None:
        """
        Plot an individual panel with an image and a colorbar.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            The axis on which to plot.
        im : torch.Tensor
            The image data to display, with shape (height, width).
        title : str
            Title for the panel.
        """
        h = ax.imshow(im, cmap="Greys_r")
        ax.set_title(title)
        ax.set_xlim(0, width)
        ax.set_ylim(height, 0)
        ax.set_axis_off()

        # add a colorbar of the same height
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad="2%")
        plt.colorbar(h, cax=cax)

    fig, axs = plt.subplots(1, 3, figsize=(15, 6))
    _plot_panel(axs[0], local_correlations, "local correlations")
    _plot_panel(axs[1], filtered_correlations, "filtered correlations")
    _plot_panel(axs[2], local_correlations, "candidate neurons")

    # Draw circles around the peaks
    for n, yx in enumerate(peaks):
        y, x = yx
        axs[2].add_patch(Circle((x, y),
                                radius=NEURON_WIDTH/2,
                                facecolor='none',
                                edgecolor='red',
                                linewidth=1))

        axs[2].text(x, y, "{}".format(n),
                    horizontalalignment="center",
                    verticalalignment="center",
                    fontdict=dict(size=10, weight="bold"),
                    color='r')

def plot_problem_2(traces: Float[torch.Tensor, "num_neurons num_frames"],
                   denoised_traces: Float[torch.Tensor, "num_neurons num_frames"],
                   amplitudes: Float[torch.Tensor, "num_neurons num_frames"]) -> None:
    """
    Plot raw and denoised fluorescence traces for neurons
    along with amplitude markers.

    Parameters
    ----------
    traces : torch.Tensor
        2D tensor of raw fluorescence traces with shape (num_neurons, num_frames).
    denoised_traces : torch.Tensor
        2D tensor of denoised fluorescence traces, same shape as 'traces'.
    amplitudes : torch.Tensor
        2D tensor of estimated amplitudes, same shape as 'traces'.

    Returns
    -------
    None

    Notes
    -----
    Normalizes the traces by their 99.5th percentile,
    offsets the traces for clarity, and highlights amplitude markers.
    """
    num_neurons, num_frames = traces.shape

    # Plot the traces and our denoised estimates
    scale = torch.quantile(traces, .995, dim=1, keepdim=True)
    offset = -torch.arange(num_neurons)

    # Plot points at the time frames where the (normalized) amplitudes are >= 0.05;
    # NaN entries are not drawn, so markers appear only at those frames
    sparse_amplitudes = amplitudes / scale
    sparse_amplitudes[sparse_amplitudes < 0.05] = torch.nan
    sparse_amplitudes[sparse_amplitudes >= 0.05] = 0.0

    plt.figure(figsize=(12, 8))
    plt.plot((traces / scale).T + offset, color=palette[0], lw=1, alpha=0.5)
    plt.plot((denoised_traces / scale).T + offset, color=palette[0], lw=2)
    plt.plot((sparse_amplitudes).T + offset, color=palette[1], marker='o', markersize=2)
    plt.xlabel("time (frames)")
    plt.xlim(0, num_frames)
    plt.ylabel("neuron")
    plt.yticks(-torch.arange(0, num_neurons, step=5),
               labels=torch.arange(0, num_neurons, step=5).numpy())
    plt.ylim(-num_neurons, 2)
    plt.title("raw and denoised fluorescence traces")


def plot_problem_3(flat_data: Any,
                   params: Dict[str, torch.Tensor],
                   hypers: Any,
                   plot_bkgd: bool = True,
                   indices: Optional[Int[torch.Tensor, "num_indices"]] = None) -> None:
    """
    Plot neuron footprints and their corresponding fluorescence traces.

    Parameters
    ----------
    flat_data : Any
        Data used for plotting (currently not used within the function).
    params : dict[str, torch.Tensor]
        Dictionary containing model parameters with the following keys:
          - 'footprints': Neuron footprints, expected to be reshaped to (-1, height, width).
          - 'bkgd_footprint': Background footprint reshaped to (height, width).
          - 'traces': Fluorescence traces for each neuron.
          - 'bkgd_trace': Background trace.
    hypers : Any
        Hyperparameters related to the model (not directly used in plotting).
    plot_bkgd : bool, optional
        Whether to include background plots. Default is True.
    indices : Optional[torch.Tensor], optional
        Indices of neurons to plot; if None, all neurons are plotted.

    Returns
    -------
    None

    Notes
    -----
    Generates a separate plot for each neuron (and background, if requested) showing both the spatial footprint
    and its corresponding fluorescence trace.
    """
    U = params["footprints"].reshape(-1, height, width)
    u0 = params["bkgd_footprint"].reshape(height, width)
    C = params["traces"]
    c0 = params["bkgd_trace"]
    N, T = C.shape

    if indices is None:
        indices = torch.arange(N)

    def _plot_factor(footprint: Float[torch.Tensor, "height width"],
                     trace: Float[torch.Tensor, "num_frames"],
                     title: str) -> None:
        """
        Plot a neuron's footprint and its corresponding fluorescence trace.

        Parameters
        ----------
        footprint : torch.Tensor
            2D tensor representing the spatial footprint of a neuron,
            with shape (height, width).
        trace : torch.Tensor
            1D tensor containing the fluorescence trace over time,
            with shape (num_frames,).
        title : str
            Title for the subplot.
        """
        fig, ax1 = plt.subplots(1, 1, figsize=(12, 6))
        vlim = abs(footprint).max()
        h = ax1.imshow(footprint, vmin=-vlim, vmax=vlim, cmap="RdBu")
        ax1.set_title(title)
        ax1.set_axis_off()

        # add a colorbar of the same height
        divider = make_axes_locatable(ax1)
        cax = divider.append_axes("right", size="5%", pad="2%")
        plt.colorbar(h, cax=cax)

        ax2 = divider.append_axes("right", size="150%", pad="75%")
        ts = torch.arange(T) / FPS
        ax2.plot(ts, trace, color=palette[0], lw=2)
        ax2.set_xlabel("time (sec)")
        ax2.set_ylabel("fluorescence trace")

    if plot_bkgd:
        _plot_factor(u0, c0, "background")

    for k in indices:
        _plot_factor(U[k], C[k], "neuron {}".format(k))
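The parameter dictionary passed to `plot_problem_3` holds the factors of the model. A minimal sketch of how those factors combine into a predicted movie, assuming the shapes implied by that helper (footprints that reshape with `reshape(-1, height, width)`, one trace per neuron, a single background footprint and trace) and inputs that are NumPy arrays or CPU tensors that do not require gradients. `reconstruct` and `fraction_variance_explained` are hypothetical helper names, not part of the lab's code.

```python
import numpy as np


def reconstruct(params: dict, height: int, width: int) -> np.ndarray:
    """Model prediction U^T C + u0 c0^T, reshaped to (height, width, num_frames).

    Assumed shapes (inferred from plot_problem_3):
      footprints     (num_neurons, height * width)
      traces         (num_neurons, num_frames)
      bkgd_footprint (height * width,)
      bkgd_trace     (num_frames,)
    """
    U = np.asarray(params["footprints"]).reshape(-1, height * width)
    C = np.asarray(params["traces"])
    u0 = np.asarray(params["bkgd_footprint"]).reshape(height * width)
    c0 = np.asarray(params["bkgd_trace"])
    # (P, N) @ (N, T) + (P, T): one row per pixel, one column per frame
    flat = U.T @ C + np.outer(u0, c0)
    # Row-major reshape: row i * width + j becomes pixel (i, j),
    # matching the reshape(height, width) calls in plot_problem_3
    return flat.reshape(height, width, -1)


def fraction_variance_explained(movie: np.ndarray, prediction: np.ndarray) -> float:
    """1 - Var(residual) / Var(movie); equals 1 for a perfect fit."""
    residual = movie - prediction
    return float(1.0 - residual.var() / movie.var())
```

Comparing the prediction with the motion-corrected `data` (for example `fraction_variance_explained(data.numpy(), reconstruct(params, height, width))`) gives one rough check of how much of the movie the fitted factors account for.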

Movie of the data

It takes a minute to render the animation…

# Play the motion corrected movie.
play(data, speedup=5)
Preparing animation. This may take a minute...
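The movie is easier to search for neurons once it is collapsed into a static summary image, and `plot_problem_1d` expects such a "local correlations" image. As a minimal sketch (our own implementation rather than CaImAn's code, assuming 4-connected neighbors), each pixel's value is the average Pearson correlation between its time series and those of its immediate neighbors. Pixels inside an active neuron share a common fluorescence trace and score high; pixels dominated by independent noise score near zero.

```python
import numpy as np


def local_correlations(movie: np.ndarray) -> np.ndarray:
    """Mean Pearson correlation between each pixel's time series and those of
    its 4-connected neighbors. `movie` has shape (height, width, num_frames)."""
    movie = np.asarray(movie, dtype=float)
    # z-score each pixel over time; constant pixels get sd = 1 to avoid 0 / 0
    sd = movie.std(axis=-1, keepdims=True)
    z = (movie - movie.mean(axis=-1, keepdims=True)) / np.where(sd > 0, sd, 1.0)

    corr_sum = np.zeros(movie.shape[:2])
    num_neighbors = np.zeros(movie.shape[:2])

    # Correlation between vertically adjacent pixels (i, j) and (i + 1, j)
    vert = (z[:-1] * z[1:]).mean(axis=-1)
    corr_sum[:-1] += vert
    corr_sum[1:] += vert
    num_neighbors[:-1] += 1
    num_neighbors[1:] += 1

    # Correlation between horizontally adjacent pixels (i, j) and (i, j + 1)
    horz = (z[:, :-1] * z[:, 1:]).mean(axis=-1)
    corr_sum[:, :-1] += horz
    corr_sum[:, 1:] += horz
    num_neighbors[:, :-1] += 1
    num_neighbors[:, 1:] += 1

    # Corner pixels have 2 neighbors, edge pixels 3, interior pixels 4
    return corr_sum / num_neighbors
```

On this dataset it could be called as `local_correlations(data.numpy())`, assuming `data` is a CPU tensor. The `gaussian_filter` and `peak_local_max` imports in the Setup cell suggest how the smoothed image and candidate peaks passed to `plot_problem_1d` might then be obtained.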