Calcium Deconvolution#
In this lab you’ll write your own code for demixing and deconvolving calcium imaging videos. Demixing refers to identifying potentially overlapping neurons in the video and separating their fluorescence traces. Deconvolving refers to taking those traces and finding the times of spiking activity, each of which produces an exponentially decaying transient in fluorescence. We’ll frame the task as a constrained and (partially) non-negative matrix factorization problem, inspired by the CNMF model of Pnevmatikakis et al. (2016), which is implemented in CaImAn (Giovannucci et al., 2019). More details and further references are in the course notes. We’ll use CVXpy to solve the convex optimization problems at the heart of this approach.
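To make the deconvolution idea concrete, here is a minimal sketch of a forward model in one dimension. It assumes first-order autoregressive (AR(1)) calcium dynamics, in which each spike adds its amplitude to the trace and the trace then shrinks by a constant factor every frame. The function names and the decay value `gamma = 0.9` are illustrative assumptions, not part of the lab code.

```python
def simulate_trace(amplitudes, gamma):
    """Turn spike amplitudes into a trace of exponentially decaying transients."""
    trace, c_prev = [], 0.0
    for a in amplitudes:
        c_prev = gamma * c_prev + a  # decay the previous value, then add the spike
        trace.append(c_prev)
    return trace


def deconvolve_noiseless(trace, gamma):
    """Invert the recursion: a_t = c_t - gamma * c_{t-1} (exact only without noise)."""
    amplitudes, c_prev = [], 0.0
    for c in trace:
        amplitudes.append(c - gamma * c_prev)
        c_prev = c
    return amplitudes


# Illustrative spike train: one spike at frame 1 and a larger one at frame 4.
spikes = [0.0, 1.0, 0.0, 0.0, 2.0, 0.0]
trace = simulate_trace(spikes, gamma=0.9)
recovered = deconvolve_noiseless(trace, gamma=0.9)
```

On noiseless data the differencing step recovers the spike amplitudes exactly. On real fluorescence it would amplify measurement noise, which is one motivation for posing deconvolution as a constrained optimization problem instead.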
References
Pnevmatikakis, Eftychios A., Daniel Soudry, Yuanjun Gao, Timothy A. Machado, Josh Merel, David Pfau, Thomas Reardon, et al. 2016. “Simultaneous Denoising, Deconvolution, and Demixing of Calcium Imaging Data.” Neuron 89 (2): 285–99.
Giovannucci, Andrea, Johannes Friedrich, Pat Gunn, Jérémie Kalfon, Brandon L. Brown, Sue Ann Koay, Jiannis Taxidis, et al. 2019. “CaImAn an Open Source Tool for Scalable Calcium Imaging Data Analysis.” eLife.
Setup#
%%capture
try:
    import jaxtyping
except ImportError:
    !pip install jaxtyping
from typing import Optional, Any, Dict, Union, Tuple
from jaxtyping import Float, Int, Array
from torch import Tensor
import torch
import torch.nn.functional as F
import torch.distributions as dist
import numpy as np
# We'll use a few SciPy functions too
import scipy.sparse
from scipy.signal import butter, sosfilt
from scipy.ndimage import gaussian_filter
from skimage.feature import peak_local_max
# We'll use CVXpy to solve convex optimization problems
import cvxpy as cvx
# Plotting stuff
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
from matplotlib.patches import Circle
import seaborn as sns
# Helpers
from tqdm.auto import trange
import warnings
device = "cuda" if torch.cuda.is_available() else "cpu"
Download example data#
This demo data was contributed by Sue Ann Koay and David Tank (Princeton University). It is also used in the CaImAn demo notebook. We used CaImAn and NoRMCorre to correct for motion artifacts.
!wget -nc https://github.com/slinderman/ml4nd/raw/refs/heads/main/data/02_calcium_imaging/lab02_data.pt
# Load the data
data = torch.load("lab02_data.pt")
height, width, num_frames = data.shape
# Set some constants
FPS = 30 # frames per second in the movie
NEURON_WIDTH = 10 # approximate width (in pixels) of a neuron
GCAMP_TIME_CONST_SEC = 0.300 # reasonable guess for calcium decay time const.
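The decay time constant above is given in seconds, while the movie is sampled in frames. The following is a minimal sketch of the conversion, assuming a transient that decays as exp(-t / tau). The helper names `decay_per_frame` and `half_life_frames` are hypothetical and are not defined elsewhere in the lab.

```python
import math

FPS = 30                      # frames per second (lab constant)
GCAMP_TIME_CONST_SEC = 0.300  # decay time constant in seconds (lab's guess)


def decay_per_frame(tau_sec, fps):
    """Factor by which an exp(-t / tau) transient shrinks over one frame."""
    return math.exp(-1.0 / (tau_sec * fps))


def half_life_frames(tau_sec, fps):
    """Number of frames for the transient to fall to half of its peak."""
    return tau_sec * fps * math.log(2.0)


gamma = decay_per_frame(GCAMP_TIME_CONST_SEC, FPS)
half_life = half_life_frames(GCAMP_TIME_CONST_SEC, FPS)
print(f"per-frame decay factor: {gamma:.3f}")
print(f"half-life: {half_life:.1f} frames")
```

With these constants the time constant spans 9 frames, so a transient loses roughly a tenth of its height per frame.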
Helper functions for plotting#
#@title Helper functions for movies and plotting { display-mode: "form" }
from matplotlib import animation
from IPython.display import HTML
from tempfile import NamedTemporaryFile
import base64
# Set some plotting defaults
sns.set_context("talk")
# initialize a color palette for plotting
palette = sns.xkcd_palette(["windows blue",
                            "red",
                            "medium green",
                            "dusty purple",
                            "orange",
                            "amber",
                            "clay",
                            "pink",
                            "greyish"])
_VIDEO_TAG = """<video controls>
<source src="data:video/x-m4v;base64,{0}" type="video/mp4">
Your browser does not support the video tag.
</video>"""
def _anim_to_html(anim: animation.FuncAnimation,
                  fps: int = 20) -> str:
    """
    Convert a Matplotlib animation object to an HTML video snippet
    with an embedded base64-encoded video.

    Parameters:
    -----------
    anim : matplotlib.animation.FuncAnimation
        The Matplotlib animation object to encode.
    fps : int, optional
        Frames per second to use when encoding the animation. Default is 20.

    Returns:
    --------
    str
        An HTML string containing the video tag with the encoded video.
    """
    # Encode the video once and cache it on the animation object
    if not hasattr(anim, '_encoded_video'):
        with NamedTemporaryFile(suffix='.mp4') as f:
            anim.save(f.name, fps=fps, extra_args=['-vcodec', 'libx264'])
            video = open(f.name, "rb").read()
        anim._encoded_video = base64.b64encode(video)
    return _VIDEO_TAG.format(anim._encoded_video.decode('ascii'))
def _display_animation(anim: animation.FuncAnimation,
                       fps: int = 30,
                       start: int = 0,
                       stop: Optional[int] = None) -> HTML:
    """
    Display a Matplotlib animation by converting it to an HTML video snippet
    and closing its figure.

    Parameters:
    -----------
    anim : matplotlib.animation.FuncAnimation
        The animation object to display.
    fps : int, optional
        Frames per second for the displayed animation. Default is 30.
    start : int, optional
        Starting frame index (currently not used #TODO). Default is 0.
    stop : Optional[int], optional
        Ending frame index (currently not used #TODO). Default is None.

    Returns:
    --------
    IPython.display.HTML
        An HTML snippet containing the video for display.
    """
    plt.close(anim._fig)
    return HTML(_anim_to_html(anim, fps=fps))
def play(movie: Int[Tensor, "height width num_frames"],
         fps: int = FPS,
         speedup: int = 1,
         fig_height: int = 6):
    """
    Create an animation from a movie tensor and return
    an HTML snippet for embedding.

    Parameters:
    -----------
    movie : torch.Tensor
        A 3D tensor with shape (height, width, num_frames) for the movie frames.
    fps : int, optional
        Frames per second of the movie. Default is the global FPS constant.
    speedup : int, optional
        Factor to speed up the animation. Default is 1 (real-time).
    fig_height : int, optional
        Height of the figure (in inches) for the plot. Default is 6.

    Returns:
    --------
    IPython.display.HTML
        An HTML video snippet displaying the animation.

    Notes:
    ------
    This function uses Matplotlib's FuncAnimation to create the animation
    and embeds it as an HTML video.
    """
    # First set up the figure, the axis, and the plot element we want to animate
    Py, Px, T = movie.shape
    fig, ax = plt.subplots(1, 1, figsize=(fig_height * Px/Py, fig_height))
    im = plt.imshow(movie[..., 0], interpolation='None', cmap=plt.cm.gray)
    tx = plt.text(0.75, 0.05, 't={:.3f}s'.format(0),
                  color='white',
                  fontdict=dict(size=12),
                  horizontalalignment='left',
                  verticalalignment='center',
                  transform=ax.transAxes)
    plt.axis('off')

    def animate(i: int):
        """
        Update function for FuncAnimation.

        Parameters:
        -----------
        i : int
            The frame index.

        Returns:
        --------
        tuple
            A tuple containing the modified image;
            used for efficient animation updates.
        """
        # Show every `speedup`-th frame and label it with its time in the movie
        im.set_data(movie[..., i * speedup])
        tx.set_text("t={:.3f}s".format(i * speedup / fps))
        return im,

    # call the animator. blit=True means only re-draw the parts that changed.
    anim = animation.FuncAnimation(fig, animate,
                                   frames=T // speedup,
                                   interval=1,
                                   blit=True)
    plt.close(anim._fig)

    # return an HTML video snippet, encoded at the movie's frame rate
    # (previously hard-coded to 30, which ignored the `fps` argument)
    print("Preparing animation. This may take a minute...")
    return HTML(_anim_to_html(anim, fps=fps))
def plot_problem_1d(local_correlations: Float[torch.Tensor, "height width"],
                    filtered_correlations: Float[torch.Tensor, "height width"],
                    peaks: Int[torch.Tensor, "num_peaks 2"]) -> None:
    """
    Plot local correlations, filtered correlations,
    and candidate neurons side-by-side.

    Parameters
    ----------
    local_correlations : torch.Tensor
        2D tensor representing local correlation values,
        with shape (height, width).
    filtered_correlations : torch.Tensor
        2D tensor representing filtered correlation values,
        with shape (height, width).
    peaks : torch.Tensor
        Tensor of shape (num_peaks, 2) where each row is [y, x] coordinates
        for peak locations.

    Returns
    -------
    None

    Notes
    -----
    The function displays three panels and overlays circles and labels
    on the candidate neuron panel.
    """
    def _plot_panel(ax: plt.Axes,
                    im: Float[torch.Tensor, "height width"],
                    title: str) -> None:
        """
        Plot an individual panel with an image and a colorbar.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            The axis on which to plot.
        im : torch.Tensor
            The image data to display, with shape (height, width).
        title : str
            Title for the panel.
        """
        h = ax.imshow(im, cmap="Greys_r")
        ax.set_title(title)
        ax.set_xlim(0, width)
        ax.set_ylim(height, 0)
        ax.set_axis_off()

        # add a colorbar of the same height
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad="2%")
        plt.colorbar(h, cax=cax)

    fig, axs = plt.subplots(1, 3, figsize=(15, 6))
    _plot_panel(axs[0], local_correlations, "local correlations")
    _plot_panel(axs[1], filtered_correlations, "filtered correlations")
    _plot_panel(axs[2], local_correlations, "candidate neurons")

    # Draw circles around the peaks
    for n, yx in enumerate(peaks):
        y, x = yx
        axs[2].add_patch(Circle((x, y),
                                radius=NEURON_WIDTH/2,
                                facecolor='none',
                                edgecolor='red',
                                linewidth=1))
        axs[2].text(x, y, "{}".format(n),
                    horizontalalignment="center",
                    verticalalignment="center",
                    fontdict=dict(size=10, weight="bold"),
                    color='r')
def plot_problem_2(traces: Float[torch.Tensor, "num_neurons num_frames"],
                   denoised_traces: Float[torch.Tensor, "num_neurons num_frames"],
                   amplitudes: Float[torch.Tensor, "num_neurons num_frames"]) -> None:
    """
    Plot raw and denoised fluorescence traces for neurons
    along with amplitude markers.

    Parameters
    ----------
    traces : torch.Tensor
        2D tensor of raw fluorescence traces with shape (num_neurons, num_frames).
    denoised_traces : torch.Tensor
        2D tensor of denoised fluorescence traces, same shape as 'traces'.
    amplitudes : torch.Tensor
        2D tensor of estimated amplitudes, same shape as 'traces'.

    Returns
    -------
    None

    Notes
    -----
    Normalizes the traces by their 99.5th percentile,
    offsets the traces for clarity, and highlights amplitude markers.
    """
    num_neurons, num_frames = traces.shape

    # Plot the traces and our denoised estimates
    scale = torch.quantile(traces, .995, dim=1, keepdims=True)
    offset = -torch.arange(num_neurons)

    # Plot points at the time frames where the (normalized) amplitudes are > 0.05
    sparse_amplitudes = amplitudes / scale
    # sparse_amplitudes = torch.isclose(sparse_amplitudes, 0, atol=0.05)
    sparse_amplitudes[sparse_amplitudes < 0.05] = torch.nan
    sparse_amplitudes[sparse_amplitudes > 0.05] = 0.0

    plt.figure(figsize=(12, 8))
    plt.plot((traces / scale).T + offset , color=palette[0], lw=1, alpha=0.5)
    plt.plot((denoised_traces / scale).T + offset, color=palette[0], lw=2)
    plt.plot((sparse_amplitudes).T + offset, color=palette[1], marker='o', markersize=2)
    plt.xlabel("time (frames)")
    plt.xlim(0, num_frames)
    plt.ylabel("neuron")
    plt.yticks(-torch.arange(0, num_neurons, step=5),
               labels=torch.arange(0, num_neurons, step=5).numpy())
    plt.ylim(-num_neurons, 2)
    plt.title("raw and denoised fluorescence traces")
def plot_problem_3(flat_data: Any,
                   params: Dict[str, torch.Tensor],
                   hypers: Any,
                   plot_bkgd: bool = True,
                   indices: Optional[Int[torch.Tensor, "num_indices"]] = None) -> None:
    """
    Plot neuron footprints and their corresponding fluorescence traces.

    Parameters
    ----------
    flat_data : Any
        Data used for plotting (currently not used within the function).
    params : dict[str, torch.Tensor]
        Dictionary containing model parameters with the following keys:
        - 'footprints': Neuron footprints, expected to be reshaped to (-1, height, width).
        - 'bkgd_footprint': Background footprint reshaped to (height, width).
        - 'traces': Fluorescence traces for each neuron.
        - 'bkgd_trace': Background trace.
    hypers : Any
        Hyperparameters related to the model (not directly used in plotting).
    plot_bkgd : bool, optional
        Whether to include background plots. Default is True.
    indices : Optional[torch.Tensor], optional
        Indices of neurons to plot; if None, all neurons are plotted.

    Returns
    -------
    None

    Notes
    -----
    Generates a separate plot for each neuron (and background, if requested) showing both the spatial footprint
    and its corresponding fluorescence trace.
    """
    U = params["footprints"].reshape(-1, height, width)
    u0 = params["bkgd_footprint"].reshape(height, width)
    C = params["traces"]
    c0 = params["bkgd_trace"]
    N, T = C.shape
    if indices is None:
        indices = torch.arange(N)

    def _plot_factor(footprint: Float[torch.Tensor, "height width"],
                     trace: Float[torch.Tensor, "num_frames"],
                     title: str) -> None:
        """
        Plot a neuron's footprint and its corresponding fluorescence trace.

        Parameters
        ----------
        footprint : torch.Tensor
            2D tensor representing the spatial footprint of a neuron,
            with shape (height, width).
        trace : torch.Tensor
            1D tensor containing the fluorescence trace over time,
            with shape (num_frames,).
        title : str
            Title for the subplot.
        """
        fig, ax1 = plt.subplots(1, 1, figsize=(12, 6))
        vlim = abs(footprint).max()
        h = ax1.imshow(footprint, vmin=-vlim, vmax=vlim, cmap="RdBu")
        ax1.set_title(title)
        ax1.set_axis_off()

        # add a colorbar of the same height
        divider = make_axes_locatable(ax1)
        cax = divider.append_axes("right", size="5%", pad="2%")
        plt.colorbar(h, cax=cax)

        ax2 = divider.append_axes("right", size="150%", pad="75%")
        ts = torch.arange(T) / FPS
        ax2.plot(ts, trace, color=palette[0], lw=2)
        ax2.set_xlabel("time (sec)")
        ax2.set_ylabel("fluorescence trace")

    if plot_bkgd:
        _plot_factor(u0, c0, "background")
    for k in indices:
        _plot_factor(U[k], C[k], "neuron {}".format(k))
Movie of the data#
It takes a minute to render the animation…
# Play the motion corrected movie.
play(data, speedup=5)
Preparing animation. This may take a minute...
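The `speedup` argument controls how many frames are actually rendered. The sketch below mirrors the indexing used inside `play`, where every `speedup`-th frame is shown and labeled with its time in the original movie. The helper name `rendered_frames` and the movie length of 2000 frames are hypothetical; the real length comes from `data.shape`.

```python
def rendered_frames(num_frames, fps, speedup):
    """Return (source frame index, timestamp in seconds) for each rendered frame."""
    return [(i * speedup, i * speedup / fps)
            for i in range(num_frames // speedup)]


# Hypothetical movie length, chosen only for illustration.
frames = rendered_frames(num_frames=2000, fps=30, speedup=5)
print(len(frames), "frames rendered; last one is", frames[-1])
```

A larger `speedup` shortens the rendering time roughly in proportion, at the cost of skipping the frames in between.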