#!/usr/bin/env python
"""
mask.py -- apply binary masks to an SPT movie
"""
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.path import Path
from scipy import interpolate
from scipy.spatial import distance_matrix
from tqdm import tqdm
class MaskInterpolator(object):
    """
    Given a set of 2D masks, interpolate the masks between frames to
    generate an approximation to the mask for intermediate frames.

    In more detail:

    We have a set of 2D masks defined in frames F0, F1, ..., Fn. Masks
    are defined as a ringlike, connected set of points. We wish to use
    the masks defined in frames F(i) and F(i+1) to estimate the shape
    of the mask in any intermediate frame, assuming the mask varies
    smoothly/linearly between frame F(i) and F(i+1).

    Instantiation of the MaskInterpolator accomplishes this, using
    either linear or spline interpolation. The resulting object can
    then be passed a frame index and will generate the corresponding
    2D mask. Frames before F0 reuse the mask at F0; frames after Fn
    reuse the mask at Fn.

    init
    ----
        mask_edges      :   list of 2D ndarray of shape (n_points, 2),
                            the Y and X coordinates for each mask at
                            each frame. The list is copied, so the
                            caller's list is left untouched.
        mask_frames     :   list of int, the frame indices corresponding
                            to each mask edge. Assumed to be in
                            ascending order (the first and last entries
                            are treated as the range limits).
        n_vertices      :   int, the number of vertices to use per
                            interpolated mask (unused when only a
                            single mask is given)
        interp_kind     :   str, "linear" or "cubic", the type of
                            interpolation to use. Note that at least 4
                            masks are required for cubic spline interpolation.
        plot            :   bool, show a plot of the vertex matching between
                            interpolated masks during initialization, for QC

    methods
    -------
        __call__        :   determine whether each of a set of points lies
                            inside or outside the mask
        interpolate     :   generate the mask edge for an arbitrary frame,
                            interpolating between the minimum and maximum
                            frames and clamping outside that range
    """
    def __init__(self, mask_edges, mask_frames, n_vertices=101,
        interp_kind="linear", plot=True):
        assert len(mask_edges) == len(mask_frames)
        assert interp_kind in ["linear", "cubic"]

        # Copy the container: _generate_mask_matches() replaces its
        # elements with upsampled paths, which must not leak back into
        # the list the caller passed in.
        self.mask_edges = list(mask_edges)
        self.mask_frames = mask_frames
        self.interp_kind = interp_kind
        self.n_frames = len(self.mask_frames)
        self.n_vertices = n_vertices
        self.plot = plot

        # If passed only a single frame, then the interpolator
        # always returns a simple static 2D shape
        self.static = (self.n_frames == 1)

        # Otherwise, generate interpolator objects to reconstruct
        # the mask for any other frame
        if not self.static:
            self._generate_mask_matches()
            self._generate_interpolators()

    def __call__(self, points, frame_indices, progress_bar=False):
        """
        Given a set of points, determine whether each point lies inside
        or outside the mask for its frame.

        args
        ----
            points          :   2D array-like of shape (n_points, 2), the
                                YX coordinates for each point (ndarray,
                                pandas.DataFrame, or nested sequence)
            frame_indices   :   1D array-like of shape (n_points,), the
                                frame indices corresponding to each point
            progress_bar    :   bool, show a progress bar

        returns
        -------
            1D ndarray of shape (n_points,), dtype bool, True where the
                point lies inside the mask for its frame
        """
        # np.asarray handles ndarray, DataFrame, Series and plain sequences
        points = np.asarray(points)
        frame_indices = np.asarray(frame_indices)
        assert points.shape[0] == frame_indices.shape[0]

        # Build one mask per distinct frame rather than per point
        unique_frames = np.unique(frame_indices)
        if progress_bar:
            unique_frames = tqdm(unique_frames)

        assignments = np.zeros(points.shape[0], dtype=bool)
        for frame_index in unique_frames:
            mask = Path(self.interpolate(frame_index), closed=True)
            in_frame = frame_indices == frame_index
            assignments[in_frame] = mask.contains_points(points[in_frame, :])
        return assignments

    def interpolate(self, frame):
        """
        Interpolate the mask edges for a given frame index.

        args
        ----
            frame       :   int, the frame index

        returns
        -------
            2D ndarray of shape (self.n_vertices, 2), the YX
                coordinates for the points along the edge of
                the mask. Frames before the first mask frame get
                the first mask; frames after the last get the last
                mask. For a single-mask (static) interpolator, the
                original mask is returned unchanged.
        """
        if self.static:
            return self.mask_edges[0]

        # Frame lies outside of the interpolable range: clamp to the
        # nearest end (mask_frames assumed ascending)
        if not self._within_interpolation_range(frame):
            edge = 0 if frame < self.mask_frames[0] else -1
            return self.mask_edges[:, edge, :]

        # Otherwise interpolate each vertex independently
        result = np.empty((self.n_vertices, 2), dtype=np.float64)
        result[:, 0] = [self.y_interpolators[i](frame) for i in range(self.n_vertices)]
        result[:, 1] = [self.x_interpolators[i](frame) for i in range(self.n_vertices)]
        return result

    def _within_interpolation_range(self, frame):
        """
        Return True if the frame lies inside the interpolation
        range for this MaskInterpolator instance.
        """
        return (frame >= self.mask_frames[0]) and (frame <= self.mask_frames[-1])

    def _generate_mask_matches(self):
        """
        Given the set of points that defines the edge of each mask
        used to instantiate this object, upsample the masks to the
        same number of points and match each point with the points
        in other masks for subsequent interpolation.

        Resets the self.mask_edges attribute to the result, an ndarray
        of shape (n_vertices, n_masks, 2).
        """
        # The total number of sets of mask edges to match
        n_masks = len(self.mask_edges)

        # The final set of matched points, appropriate for
        # interpolation
        result = np.zeros((self.n_vertices, n_masks, 2), dtype=np.float64)

        # Upsample each of the masks to the same number of points
        # (self.mask_edges is our own copy, so this does not touch the
        # caller's list)
        for j in range(n_masks):
            self.mask_edges[j] = upsample_2d_path(self.mask_edges[j],
                kind="cubic", n_vertices=self.n_vertices)
        result[:, 0, :] = self.mask_edges[0]

        # For each sequential combination of masks, match the
        # interpolated mask points against the previous (already
        # aligned) mask
        for j in range(1, n_masks):
            result[:, j, :] = match_vertices(
                result[:, j-1, :],
                self.mask_edges[j],
                method="global",
                plot=self.plot
            )
        self.mask_edges = result

    def _generate_interpolators(self):
        """
        For each point that defines the edge of this mask, generate Y
        and X interpolators that enable a mask to be reconstructed
        for a given image frame.
        """
        self.y_interpolators = []
        self.x_interpolators = []
        for vertex in range(self.n_vertices):
            # Generate the interpolator object for the y-index
            I = interpolate.interp1d(self.mask_frames, self.mask_edges[vertex, :, 0],
                kind=self.interp_kind)
            self.y_interpolators.append(I)

            # Generate the interpolator object for the x-index
            I = interpolate.interp1d(self.mask_frames, self.mask_edges[vertex, :, 1],
                kind=self.interp_kind)
            self.x_interpolators.append(I)
def upsample_2d_path(points, kind="cubic", n_vertices=101):
    """
    Resample a closed 2D path to a fixed number of points.

    The path is closed by appending its first vertex to the end. Each
    coordinate is then interpolated independently against the vertex
    index, which serves as the path parameter.

    args
    ----
        points      :   2D ndarray of shape (n_points, 2), the
                        Y and X coordinates of each point in the
                        path, in sequential order
        kind        :   str, the interpolation kind passed to
                        scipy.interpolate.interp1d
        n_vertices  :   int, how many evenly spaced (in parameter)
                        samples to return

    returns
    -------
        2D ndarray of shape (n_vertices, 2), dtype float64, the
            resampled path. The first and last rows both equal
            points[0], because the closing vertex is included.
    """
    # Repeat the first vertex so the path returns to its start
    closed = np.vstack([points, points[:1, :]])

    # Integer vertex index as the curve parameter, sampled evenly over
    # its whole span (including the closing vertex)
    params = np.arange(closed.shape[0])
    samples = np.linspace(0, params.max(), n_vertices)

    out = np.empty((n_vertices, 2), dtype=np.float64)
    for col in range(2):
        spline = interpolate.interp1d(params, closed[:, col], kind=kind)
        out[:, col] = spline(samples)
    return out
def shoelace(points):
    """
    Shoelace algorithm for computing the oriented area of a 2D
    polygon.

    The polygon is always treated as closed: the edge from the last
    vertex back to the first is included. It therefore makes no
    difference whether *points* repeats its first vertex at the end.

    Orientation convention: read column 1 as the horizontal axis and
    column 0 as the vertical axis (pointing up). This matches YX input
    drawn as x = column 1, y = column 0. The area is positive when the
    vertices run counterclockwise in that frame and negative when they
    run clockwise.

    args
    ----
        points      :   2D ndarray, shape (n_points, 2), the
                        vertices of the polygon

    returns
    -------
        float, the signed area of the polygon defined by *points*
    """
    # Promote to float so integer coordinates cannot overflow
    pts = np.asarray(points, dtype=np.float64)
    # Successor of each vertex, wrapping the last back to the first so
    # the closing edge is counted (previously it was dropped)
    nxt = np.roll(pts, -1, axis=0)
    # Trapezoid form of the shoelace sum equals twice the signed area
    return 0.5 * ((nxt[:, 0] - pts[:, 0]) * (nxt[:, 1] + pts[:, 1])).sum()
def circshift(points, shift):
    """
    Circularly shift a set of points along the first axis.

    args
    ----
        points      :   ndarray of shape (n_points, ...), e.g.
                        (n_points, D) for the D-dimensional coordinates
                        of each point. Any number of trailing
                        dimensions is accepted.
        shift       :   int, the index of the new starting point.
                        Negative values and values >= n_points wrap
                        around.

    returns
    -------
        ndarray with the same shape and dtype as *points*: a new array
            holding the same points circularly shifted so that
            points[shift % n_points] comes first. Empty input yields
            an empty array.

    example
    -------
        points_before = np.array([
            [1, 2],
            [3, 4],
            [5, 6]
        ])
        points_after = circshift(points_before, 1)
        points_after -> np.array([
            [3, 4],
            [5, 6],
            [1, 2]
        ])
    """
    # np.roll moves element i to position i + k, so making index `shift`
    # the new start means rolling by -shift. Unlike a hand-written slice
    # copy into np.empty, this fills every element for any ndim and
    # handles n_points == 0 without dividing by zero.
    return np.roll(points, -int(shift), axis=0)
def match_vertices(vertices_0, vertices_1, method="closest", plot=False):
    """
    Circularly permute the vertices of one polygon so that they line
    up with the vertices of a reference polygon of the same size.

    Both polygons are first centered on their own mean position, so the
    match does not depend on a translation of either polygon as a whole.
    If the two centered polygons wind in opposite directions (their
    shoelace() values have opposite signs), the vertex order of the
    second polygon is reversed before the offset is chosen.

    args
    ----
        vertices_0  :   2D ndarray, shape (n_points, 2), the
                        YX coordinates for the vertices of the
                        reference polygon
        vertices_1  :   2D ndarray, shape (n_points, 2), the
                        YX coordinates for the vertices of the
                        polygon to be permuted
        method      :   str, how to choose the circular offset.
                        "closest": anchor on the single pair of
                            vertices (one from each centered polygon)
                            that are nearest to each other.
                        "global": try every offset and keep the first
                            one that minimizes the summed squared
                            distance between all matched vertex pairs.
        plot        :   bool, draw the matched vertex pairs with
                        matplotlib and call plt.show()

    returns
    -------
        2D ndarray, shape (n_points, 2), the rows of *vertices_1*
            reordered (possibly reversed, then circularly shifted) so
            that row i corresponds to row i of *vertices_0*
    """
    assert vertices_0.shape == vertices_1.shape
    assert method in ["closest", "global"], "method must be either 'closest' or 'global'"

    # Row order of vertices_1 that will eventually be returned
    order = np.arange(vertices_1.shape[0])

    # Coordinates relative to each polygon's own mean position
    rel_0 = vertices_0 - vertices_0.mean(axis=0)
    rel_1 = vertices_1 - vertices_1.mean(axis=0)

    # Opposite winding directions: walk the second polygon backwards
    if shoelace(rel_0) * shoelace(rel_1) < 0:
        rel_1, order = rel_1[::-1, :], order[::-1]

    n_rows = rel_0.shape[0]
    if method == "global":
        # Exhaustive search over offsets; strict '<' keeps the earliest
        # offset among ties
        offset, best_cost = 0, np.inf
        for candidate in range(rel_1.shape[0]):
            cost = ((rel_0 - circshift(rel_1, candidate)) ** 2).sum()
            if cost < best_cost:
                offset, best_cost = candidate, cost
    else:
        # "closest": locate the nearest pair (row_0 in rel_0, row_1 in
        # rel_1) and choose the offset that puts row_1 at position row_0
        flat = np.argmin(distance_matrix(rel_0, rel_1).ravel())
        row_0, row_1 = divmod(flat, n_rows)
        offset = (row_1 - row_0) % n_rows

    aligned = vertices_1[circshift(order, offset), :]

    # Optional QC figure: both vertex sets colored by index, with a
    # black segment joining each matched pair
    if plot:
        _, ax = plt.subplots(figsize=(3, 3))
        for verts in (vertices_0, aligned):
            ax.scatter(verts[:, 0], verts[:, 1], cmap="viridis",
                c=np.arange(verts.shape[0]))
        for (y0, x0), (y1, x1) in zip(vertices_0, aligned):
            ax.plot([y0, y1], [x0, x1], color="k", linestyle='-')
        ax.set_aspect('equal')
        ax.set_title("Mask alignment")
        plt.show()

    return aligned