"""
Multi-subject inverse problem.
This module fits multi-subject inverse problem models on real or simulated
M-EEG data.
"""
import numpy as np
import mne
from mutar import DirtyModel, IndLasso, ReMTW, MTW, GroupLasso, IndRewLasso
from . import utils
from .solvers import _gl_wrapper
from .utils import _compute_ground_metric
def _coefs_to_stcs(coefs, group_info, tmin, tstep):
    """Build one source estimate per subject from a coefficient array.

    Parameters
    ----------
    coefs : array
        Coefficients whose last axis runs over subjects; the slice
        ``coefs[:, :, idx]`` is transposed and handed to
        ``utils._make_stc`` for subject ``idx``.
    group_info : dict
        Must contain ``"vertno_lh"``, ``"vertno_rh"`` and ``"subjects"``,
        each indexable by subject position.
    tmin, tstep : float
        Passed unchanged to ``utils._make_stc``.

    Returns
    -------
    stcs : list
        The objects returned by ``utils._make_stc``, in subject order.
    """
    lh_vertices = group_info["vertno_lh"]
    rh_vertices = group_info["vertno_rh"]
    subject_names = group_info["subjects"]
    return [
        utils._make_stc(coefs[:, :, idx].T,
                        [lh_vertices[idx], rh_vertices[idx]],
                        tmin=tmin, tstep=tstep,
                        subject=subject_names[idx])
        for idx in range(coefs.shape[-1])
    ]
def _method_to_solver(method):
    """Map a method name to the ``mutar`` estimator class implementing it.

    Raises
    ------
    ValueError
        If ``method`` is not one of 'lasso', 'relasso', 'mtw', 'remtw',
        'dirty' or 'multitasklasso'.
    """
    if method == "lasso":
        estimator_cls = IndLasso
    elif method == "relasso":
        estimator_cls = IndRewLasso
    elif method == "mtw":
        estimator_cls = MTW
    elif method == "remtw":
        estimator_cls = ReMTW
    elif method == "dirty":
        estimator_cls = DirtyModel
    elif method == "multitasklasso":
        estimator_cls = GroupLasso
    else:
        raise ValueError("Method %s not recognized." % method)
    return estimator_cls
def _method_to_str(method):
    """Return the dotted ``mutar`` class name associated with ``method``.

    Raises
    ------
    ValueError
        If ``method`` is not a recognized method name.
    """
    known_methods = (
        ("lasso", "mutar.IndLasso"),
        ("relasso", "mutar.IndRewLasso"),
        ("mtw", "mutar.MTW"),
        ("remtw", "mutar.ReMTW"),
        ("dirty", "mutar.DirtyModel"),
        ("multitasklasso", "mutar.GroupLasso"),
    )
    for name, qualified_name in known_methods:
        if method == name:
            return qualified_name
    raise ValueError("Method %s not recognized." % method)
def _check_evokeds(evokeds):
    """Ensure every evoked carries the same ``times`` array.

    Each evoked is compared with the one preceding it and the first
    mismatch raises.

    Raises
    ------
    ValueError
        If an evoked's ``times`` differs in length or in values from the
        previous one. The reported subject number is the 0-based position
        of the offending evoked in ``evokeds``.
    IndexError
        If ``evokeds`` is empty, since ``evokeds[0]`` is read first.
    """
    reference = evokeds[0].times
    for position, evoked in enumerate(evokeds[1:], start=1):
        candidate = evoked.times
        if reference.shape != candidate.shape:
            raise ValueError("Subject number %d has a times array with a "
                             "different length. Please provide evokeds data "
                             "with the same shape." % position)
        # Elementwise "!=" is the exact negation of "==" (NaNs included).
        if (reference != candidate).any():
            raise ValueError("Subject number %d has a times array with "
                             "different time coordinates. Please provide "
                             "evokeds data with the same times "
                             "array." % position)
        reference = candidate.copy()
def _get_common_sel(evokeds, noise_covs, fwds):
    """Return the forward-row selection shared by every subject.

    For each subject, ``utils._get_channels`` yields the channel names
    usable given its forward, noise covariance and evoked. Then
    ``utils._ch_names_to_sel`` maps those names to a selection into
    ``fwd["sol"]["row_names"]``. The selections are intersected across
    subjects.

    Returns
    -------
    sel : list
        Sorted common selection (presumably integer row indices -- confirm
        against ``utils._ch_names_to_sel``). Sorting makes the result
        deterministic: iterating a ``set`` has no guaranteed order. It also
        keeps the channels in forward-row order, so every array indexed
        with ``sel`` lines up with the others.
    """
    selections = []
    for ev, cov, fwd in zip(evokeds, noise_covs, fwds):
        all_channels = fwd["sol"]["row_names"]
        ch_names = utils._get_channels(fwd, noise_cov=cov, evoked=ev)
        selections.append(utils._ch_names_to_sel(all_channels, ch_names))
    common = set(selections[0]).intersection(*selections[1:])
    return sorted(common)
def _whiten_data(fwds, evokeds, noise_covs, depth):
    """Whiten the evokeds and gain matrices, then apply depth weighting.

    Only the channels common to all subjects (``_get_common_sel``) are
    kept. For each subject, data and gain are left-multiplied by the noise
    whitener and scaled by ``sqrt(nave)``.

    Parameters
    ----------
    fwds : list
        Prepared forwards; ``fwd["sol"]["row_names"]`` and
        ``fwd["sol_group"]["data"]`` are read.
    evokeds : list
        One evoked per subject; ``.info``, ``.nave`` and ``.data`` are read.
    noise_covs : list
        One noise covariance per subject.
    depth : float | None
        Exponent applied to the column norms of the whitened gains. ``None``
        is treated as 0, i.e. no depth weighting.

    Returns
    -------
    gains_w : array, shape (n_subjects, n_channels, n_features)
        Whitened gains with every column divided by the matching entry of
        ``weights``.
    meeg_data : array, shape (n_subjects, n_channels, n_times)
        Whitened data.
    weights : array, shape (n_subjects, n_features)
        Norms of the whitened gain columns raised to ``depth``.
    """
    if depth is None:
        # None means "no depth weighting": norms ** 0 == 1 leaves the gains
        # unscaled, whereas ``norms ** None`` would raise a TypeError.
        depth = 0.
    n_subjects = len(fwds)
    sel = _get_common_sel(evokeds, noise_covs, fwds)
    meeg_w = []
    gains_w = []
    for ii in range(n_subjects):
        ev = evokeds[ii]
        all_channels = list(np.array(fwds[ii]["sol"]["row_names"])[sel])
        W, _ = mne.cov.compute_whitener(noise_covs[ii], ev.info, all_channels,
                                        pca=False, verbose=False)
        # NOTE(review): ``all_channels`` already restricts the whitener to the
        # selected channels. If ``W`` comes back as (len(sel), len(sel)), this
        # re-indexing with forward-row indices is only a no-op when ``sel`` is
        # ``range(len(sel))``; otherwise it selects wrong rows or raises.
        # Confirm against the installed MNE version.
        W = W[sel, :][:, sel]
        # NOTE(review): ``ev.data[sel]`` assumes the evoked channel order
        # matches the forward rows -- confirm.
        gain = fwds[ii]["sol_group"]["data"][sel]
        scale = ev.nave ** 0.5
        gains_w.append(scale * W.dot(gain))
        meeg_w.append(scale * W.dot(ev.data[sel]))
    meeg_data = np.stack(meeg_w, axis=0)
    gains_w = np.stack(gains_w, axis=0)
    weights = np.linalg.norm(gains_w, axis=1) ** depth
    gains_w = gains_w / weights[:, None, :]
    return gains_w, meeg_data, weights
def _check_solver(method, spatiotemporal):
    """Validate the requested method and its spatiotemporal flag.

    Raises
    ------
    ValueError
        If ``method`` is unknown, or if ``spatiotemporal`` is truthy for any
        method other than 'multitasklasso' (the only method accepted with
        ``spatiotemporal=True``).
    """
    valid_methods = ("multitasklasso", "dirty", "mtw", "remtw", "lasso",
                     "relasso")
    if method not in valid_methods:
        # ``(method,)`` keeps the formatting valid even for tuple inputs.
        raise ValueError("%s is not a valid method. `method` must be one "
                         "of 'multitasklasso', 'dirty', 'mtw', 'remtw',"
                         " 'lasso', 'relasso'" % (method,))
    if method != "multitasklasso" and spatiotemporal:
        # The remedy is to disable the spatiotemporal mode, hence `False`.
        raise ValueError("%s is not feasible as a time dependent method. "
                         "Use Group Lasso for an L2 over the time axis or"
                         " set `spatiotemporal` to `False`." % (method,))
def _check_solver_params(fwds, method, solver_kwargs, gains_scaled, meeg,
                         spatiotemporal):
    """Fill in default hyperparameters and rescale the penalties.

    ``solver_kwargs`` is updated in place and returned. The user-facing
    ``alpha`` / ``beta`` values are multiplied by a data-dependent scale
    computed from ``X^T y``, so they act as relative penalties. Array values
    supplied by the caller (``M``, ``alpha``, ``beta``) are never modified
    in place.

    Parameters
    ----------
    fwds : list
        Prepared forwards; only read to compute the OT ground metric when
        ``method`` is 'mtw' or 'remtw' and no ``M`` is given.
    method : str
        Solver name.
    solver_kwargs : dict
        User hyperparameters.
    gains_scaled : array, shape (n_subjects, n_channels, n_features)
    meeg : array, shape (n_subjects, n_channels, n_times)
    spatiotemporal : bool
        If True, the 'multitasklasso' / 'dirty' l12 ``alpha`` is not
        rescaled here.

    Returns
    -------
    solver_kwargs : dict

    Raises
    ------
    ValueError
        If a user-supplied ground metric ``M`` is not an
        (n_features, n_features) array or has negative entries.
    """
    n_subjects, n_channels, _ = meeg.shape
    n_features = gains_scaled.shape[-1]
    # alpha is necessary for all models
    solver_kwargs.setdefault("alpha", 0.2)
    # beta is necessary for dirty and ot models
    if method not in ["lasso", "multitasklasso", "relasso"]:
        solver_kwargs.setdefault("beta", 0.2)
    # ground metric and ot hyperparameters for ot models
    if method in ["mtw", "remtw"]:
        solver_kwargs.setdefault("concomitant", True)
        if "M" not in solver_kwargs:
            print("Computing OT ground metric ...")
            src_ref = fwds[0]["sol_group"]["src_ref"]
            _group_info = fwds[0]["sol_group"]["group_info"]
            solver_kwargs["M"] = _compute_ground_metric(src_ref, _group_info)
        else:
            # Float copy: never normalize the caller's array in place, and
            # in-place true division would fail on integer arrays.
            M = np.array(solver_kwargs["M"], dtype=float)
            # Compare the full shape so that 1-D inputs are rejected too.
            if M.shape != (n_features, n_features):
                raise ValueError("The ground metric M must be an array of "
                                 "shape (%s, %s); got %s"
                                 % (n_features, n_features, M.shape))
            if M.min() < 0.:
                raise ValueError("The ground metric M must be non-negative; "
                                 "got M.min() = %s" % M.min())
            solver_kwargs["M"] = M / np.median(M)
        if "gamma" not in solver_kwargs:
            solver_kwargs["gamma"] = solver_kwargs["M"].max()
        if "epsilon" not in solver_kwargs:
            solver_kwargs["epsilon"] = 100. / n_features
    xty = np.array([g.T.dot(m) for g, m in zip(gains_scaled, meeg)])
    # rescale l12 norm penalty (not rebound in place: alpha may be a user
    # array)
    if method in ["multitasklasso", "dirty"] and not spatiotemporal:
        alphamax = np.linalg.norm(xty, axis=0).max() / n_channels
        solver_kwargs["alpha"] = solver_kwargs["alpha"] * alphamax
    # rescale l1 norm penalty
    if method in ["dirty", "mtw", "remtw", "lasso", "relasso"]:
        betamax = abs(xty).max() / n_channels
        if method in ["lasso", "relasso"]:
            # independent models get one alpha per subject
            solver_kwargs["alpha"] = (solver_kwargs["alpha"] *
                                      (betamax * np.ones(n_subjects)))
        else:
            solver_kwargs["beta"] = solver_kwargs["beta"] * betamax
    return solver_kwargs
def _apply_solver(gains_scaled, meeg, method, spatiotemporal, verbose,
                  **solver_kwargs):
    """Fit the inverse model and return the source coefficients.

    Parameters
    ----------
    gains_scaled : array, shape (n_subjects, n_channels, n_features)
    meeg : array, shape (n_subjects, n_channels, n_times)
    method : str
        Name understood by ``_method_to_solver``; unused when
        ``spatiotemporal`` is True.
    spatiotemporal : bool
        If True, every (subject, time point) pair becomes one task of a
        single problem solved by ``_gl_wrapper``, and ``alpha`` is rescaled
        here. Otherwise one estimator is fitted per time point.
    verbose : bool
        Print per-time-point progress (non-spatiotemporal case only).
    **solver_kwargs
        Passed to the solver.

    Returns
    -------
    coefs : array, shape (n_times, n_features, n_subjects)
    log : dict
        ``dualgap``, ``loss`` and ``residuals`` from ``_gl_wrapper`` in the
        spatiotemporal case, empty otherwise.
    """
    n_subjects, n_channels, n_times = meeg.shape
    if spatiotemporal:
        # Flatten (subject, time) into one task axis, subject-major:
        # task k <-> subject k // n_times, time k % n_times.
        meeg = np.swapaxes(meeg, 1, 2).reshape(-1, n_channels)
        # Each task must use its own subject's gain, so every gain is
        # repeated n_times times consecutively. ``np.tile`` would cycle over
        # subjects instead (task k -> subject k % n_subjects) and pair the
        # data with the wrong gains whenever n_subjects > 1 and n_times > 1.
        gains_scaled = np.repeat(gains_scaled, n_times, axis=0)
        gty = np.array([g.T.dot(m) for g, m in zip(gains_scaled, meeg)])
        alphamax = np.linalg.norm(gty, axis=0).max(axis=0)
        # Rebind rather than ``*=`` so a user-supplied array is not mutated.
        solver_kwargs["alpha"] = solver_kwargs["alpha"] * (alphamax /
                                                           n_channels)
        coefs, residuals, loss, dg = _gl_wrapper(gains_scaled, meeg,
                                                 **solver_kwargs)
        # Undo the subject-major flattening to get
        # (n_times, n_features, n_subjects). Assumes _gl_wrapper returns
        # coefs as (n_features, n_tasks) -- TODO confirm.
        coefs = coefs.reshape(-1, n_subjects, n_times).T
        coefs = np.swapaxes(coefs, 1, 2)
        log = dict(dualgap=dg, loss=loss, residuals=residuals)
    else:
        solver = _method_to_solver(method)
        n_features = gains_scaled.shape[-1]
        coefs = np.empty((n_times, n_features, n_subjects))
        for t in range(n_times):
            if verbose:
                print("Solving for time point {} / {}".format(t + 1, n_times))
            estim = solver(fit_intercept=False, normalize=False,
                           **solver_kwargs)
            estim.fit(gains_scaled, meeg[:, :, t])
            assert estim.coef_.shape == (n_features, n_subjects)
            coefs[t] = estim.coef_
        log = dict()
    return coefs, log
def _check_fwds(fwds):
    """Verify that every forward carries the ``"sol_group"`` entry.

    The error message points the user to ``groupmne.prepare_fwds``,
    presumably the function that adds this entry.

    Raises
    ------
    ValueError
        If at least one forward lacks ``"sol_group"``.
    """
    has_unprepared = any("sol_group" not in fwd.keys() for fwd in fwds)
    if has_unprepared:
        raise ValueError("`groupmne.prepare_fwds` must be called before "
                         "to compute a group inverse.")
def compute_group_inverse(fwds, evokeds, noise_covs, method="multitasklasso",
                          depth=0.8, spatiotemporal=False, verbose=True,
                          **solver_kwargs):
    """Compute inverse solution for a group of subjects.

    Parameters
    ----------
    fwds : list of `mne.Forward`
        Forward solution of each subject, prepared beforehand with
        `groupmne.prepare_fwds`.
    evokeds : list of `mne.Evoked`
        Evoked object of each subject. All must share the same ``times``.
    noise_covs : list of `mne.Covariance`
        Noise covariance of each subject.
    method : str
        Model used for the joint prior. Must be one of ('lasso', 'relasso',
        'multitasklasso', 'dirty', 'mtw', 'remtw').
    depth : float | None
        Depth-weighting exponent: each whitened gain column is divided by
        its norm raised to ``depth`` (default 0.8; expected between 0 and
        1). None is equivalent to 0, meaning no depth weighting is
        performed.
    spatiotemporal : bool
        If True, apply a spatiotemporal prior on the source estimates.
        Only for method = `multitasklasso`.
    verbose : bool
        If True, print solver progress.
    **solver_kwargs
        Additional keyword arguments passed to the solver. ``alpha`` and
        ``beta`` are multiplied by a data-dependent scale before use.

    Returns
    -------
    stcs : list of `mne.SourceEstimate`
        Source estimates in nAm, one per subject.

    Raises
    ------
    ValueError
        If no subject is given, if the numbers of forwards, evokeds and
        noise covariances differ, for an invalid method / spatiotemporal
        combination, for mismatching evoked times or unprepared forwards.
    """
    if len(evokeds) != len(fwds):
        raise ValueError("The number of evokeds is not equal to the number "
                         "of forwards.")
    if len(noise_covs) != len(fwds):
        raise ValueError("The number of noise covariances is not equal to "
                         "the number of forwards.")
    if not fwds:
        raise ValueError("At least one subject (forward, evoked and noise "
                         "covariance) must be provided.")
    if depth is None:
        # Documented above: None means no depth weighting (norms ** 0 == 1).
        depth = 0.
    _check_solver(method, spatiotemporal)
    _check_evokeds(evokeds)
    _check_fwds(fwds)
    gains, meeg, weights = _whiten_data(fwds, evokeds, noise_covs, depth)
    # Fill in default hyperparameters and rescale alpha / beta by
    # data-dependent maxima so the user values act as relative penalties.
    solver_kwargs = _check_solver_params(fwds, method, solver_kwargs,
                                         gains, meeg, spatiotemporal)
    stc_data, _ = _apply_solver(gains, meeg, method, spatiotemporal,
                                verbose=verbose, **solver_kwargs)
    # The gains were divided by ``weights``, so dividing the coefficients by
    # the same weights maps them back to the unweighted problem; 1e9 changes
    # units to nAm (coefficients presumably in Am).
    stc_data = np.array(stc_data) * 1e9 / weights.T[None, :, :]
    times = evokeds[0].times
    tmin = times[0]
    # A single time sample has no step; fall back to 10 ms.
    tstep = times[1] - tmin if len(times) > 1 else 0.01
    return _coefs_to_stcs(stc_data, fwds[0]["sol_group"]["group_info"],
                          tmin=tmin, tstep=tstep)