# Source code for palmari.image_tools.correct_drift

import logging
from scipy.interpolate import interp1d
import numpy as np
from tqdm import tqdm
import pandas as pd
from dask import delayed
import dask.array as da
from skimage.registration import phase_cross_correlation


def get_optimal_shift(pos1, pos2, L, step, upsample_factor: int = 1):
    """Estimate the translation of ``pos2`` relative to ``pos1``.

    Both point sets are rasterised onto one common 2D histogram grid of
    bin width ``step``. The translation between the two histograms is
    then found with ``skimage.registration.phase_cross_correlation``.

    Parameters
    ----------
    pos1 : array-like of shape (n1, 2)
        Reference positions (column 0 = x, column 1 = y).
    pos2 : array-like of shape (n2, 2)
        Positions whose displacement relative to ``pos1`` is estimated.
    L : float
        Margin added on each side of the data extent when building the
        histogram grid, in position units.
    step : float
        Histogram bin width, in position units. Unless
        ``upsample_factor > 1``, this is also the resolution of the result.
    upsample_factor : int, optional
        Forwarded to ``phase_cross_correlation``. Values > 1 refine the
        estimate to ``1 / upsample_factor`` of a bin. The default of 1 gives
        whole-bin precision, as before.

    Returns
    -------
    numpy.ndarray of shape (2,)
        Drift ``(dx, dy)`` such that ``pos2 - (dx, dy)`` best overlaps
        ``pos1``.
    """
    pos1 = np.asarray(pos1, dtype=float)
    pos2 = np.asarray(pos2, dtype=float)

    # One grid for both sets, sized per axis from the union of the points,
    # so the two histograms are comparable bin-by-bin.
    both = np.concatenate([pos1, pos2], axis=0)
    lower = both.min(axis=0) - L
    upper = both.max(axis=0) + L
    bins_x = np.arange(lower[0], upper[0], step)
    bins_y = np.arange(lower[1], upper[1], step)

    # ``density`` replaces ``normed``, which NumPy 1.24 removed.
    hist1, _, _ = np.histogram2d(
        pos1[:, 0], pos1[:, 1], bins=(bins_x, bins_y), density=True
    )
    hist2, _, _ = np.histogram2d(
        pos2[:, 0], pos2[:, 1], bins=(bins_x, bins_y), density=True
    )

    # phase_cross_correlation(reference, moving) returns the shift (in bins)
    # that registers ``moving`` onto ``reference``. That is minus the drift
    # of pos2 relative to pos1, hence the negation below.
    shift, _error, _phasediff = phase_cross_correlation(
        hist1, hist2, upsample_factor=upsample_factor
    )
    return -np.asarray(shift, dtype=float) * step
def correct_drift(
    pos,
    L=0.2,
    step_size=0.03,
    prog_bar_position=None,
    min_n_locs_per_bin: int = 10000,
    max_n_bins: int = 20,
):
    """Estimate and subtract lateral drift from a table of localizations.

    Rows are split, in their current order, into consecutive bins.

    - Each bin holds ``max(min_n_locs_per_bin, n_rows // max_n_bins)`` rows;
      the last bin may be smaller.
    - The shift between each pair of consecutive bins is estimated with
      ``get_optimal_shift``.
    - The pairwise shifts are accumulated, with the first bin as reference.
    - The result is interpolated over the row index.
    - The interpolated shift is subtracted from ``x``/``y``.

    If the table is too small to form at least two bins, no correction is
    applied.

    ``pos`` is modified in place. It gains an ``n_detection`` column (the
    row position), an ``n_bin`` column (only when a correction is computed)
    and ``shift_x``/``shift_y`` columns (created as zeros if missing, then
    incremented by the applied shift).

    Parameters
    ----------
    pos : pandas.DataFrame
        Must contain ``x`` and ``y`` columns. Rows are presumably in
        acquisition order — TODO confirm with callers.
    L : float
        Forwarded to ``get_optimal_shift`` as ``L`` (histogram margin).
    step_size : float
        Forwarded to ``get_optimal_shift`` as ``step`` (histogram bin width).
    prog_bar_position : int or None
        ``position`` of the tqdm bar; 0 when None.
    min_n_locs_per_bin : int
        Minimum number of rows per bin.
    max_n_bins : int
        Upper bound used to derive the bin size from the row count.

    Returns
    -------
    pos : pandas.DataFrame
        The same, mutated DataFrame.
    interp_df : pandas.DataFrame
        Columns ``frame``, ``dx`` and ``dy``.

        - ``frame`` holds ``n_detection`` values sampled every 1000 rows
          (the last row excluded).
        - ``dx``/``dy`` hold the interpolated drift at those samples, or
          zeros when no correction was possible.
    """
    n_locs = pos.shape[0]
    pos["n_detection"] = np.arange(n_locs)
    if "shift_x" not in pos.columns:
        pos[["shift_x", "shift_y"]] = np.zeros((n_locs, 2))

    bin_size = max(min_n_locs_per_bin, int(n_locs / max_n_bins))
    interp = None
    if bin_size > n_locs / 2:
        logging.info("Too few detections to correct drift")
    else:
        # Edges are derived from the row count; on an empty table
        # ``n_detection.max()`` would be NaN and ``np.arange`` would raise.
        bins = np.arange(0, n_locs, bin_size)
        pos["n_bin"] = np.digitize(pos["n_detection"].to_numpy(), bins)

        grouped = pos.groupby("n_bin", sort=True)
        # First row index of each bin: the abscissa of the drift curve.
        n_steps = grouped["n_detection"].min().to_numpy()
        xy_per_bin = [
            group[["x", "y"]].to_numpy(copy=True) for _, group in grouped
        ]

        shifts_del = []
        # NOTE(review): this bar only tracks construction of the dask
        # graph; the actual work runs in ``.compute()`` below. The original
        # comment said it should be disabled under process pooling, but it
        # is currently enabled.
        for p1, p2 in tqdm(
            zip(xy_per_bin[:-1], xy_per_bin[1:]),
            leave=False,
            unit="bins",
            total=len(xy_per_bin) - 1,
            position=0 if prog_bar_position is None else prog_bar_position,
            disable=False,
        ):
            shifts_del.append(
                da.from_delayed(
                    delayed(get_optimal_shift)(p1, p2, L=L, step=step_size),
                    dtype=float,
                    shape=(2,),
                )
            )

        pairwise = da.stack(shifts_del, axis=0)
        # The first bin is the reference (zero shift); accumulate the rest.
        cumulative = da.cumsum(
            da.concatenate([da.zeros((1, 2)), pairwise], axis=0), axis=0
        )
        shifts = cumulative.compute()

        interp = interp1d(
            n_steps,
            shifts,
            axis=0,
            bounds_error=False,
            # Quadratic interpolation needs at least three samples.
            kind="quadratic" if len(n_steps) > 2 else "linear",
            fill_value=(shifts[0], shifts[-1]),
        )
        interp_values = interp(pos["n_detection"].to_numpy())
        pos[["x", "y"]] -= interp_values
        pos[["shift_x", "shift_y"]] += interp_values

    # Same samples as ``arange(min, max, 1000)`` over n_detection, but still
    # well-defined (empty) when the table has no rows.
    frames = np.arange(0, max(n_locs - 1, 0), 1000)
    interp_df = pd.DataFrame(index=np.arange(len(frames)))
    interp_df["frame"] = frames
    interp_df["dx"] = 0.0
    interp_df["dy"] = 0.0
    if interp is not None:
        interp_df[["dx", "dy"]] = interp(interp_df["frame"].to_numpy())
    return pos, interp_df