Source code for ramutils.powers

""" Helper functions for computing powers from a set of EEG signals """

import numpy as np
import pandas as pd
from ptsa.data.filters import (
    MonopolarToBipolarMapper,
    MorletWaveletFilter
)
from ptsa.data.readers import EEGReader
from scipy.stats import zscore, ttest_ind
from statsmodels.sandbox.stats.multicomp import multipletests
import io
try:
    from typing import List
except ImportError:
    pass

from ramutils.log import get_logger
from ramutils.utils import timer
from ramutils.events import get_recall_events_mask, extract_sessions, \
    partition_events, concatenate_events_for_single_experiment, \
    get_partition_masks
from ramutils.montage import generate_pairs_for_ptsa, extract_monopolar_from_bipolar


logger = get_logger()


def compute_single_session_powers(session, all_events, start_time, end_time,
                                  buffer_time, freqs, log_powers,
                                  filt_order, width, normalize, bipolar_pairs):
    """Compute wavelet powers for the events of a single session.

    Parameters
    ----------
    session:
        Session identifier; only events whose ``session`` field equals this
        value are loaded
    all_events: np.recarray
        Events from all sessions
    start_time: float
        Start of the EEG period to read for each event
    end_time: float
        End of the EEG period to read for each event
    buffer_time: float
        Length of the mirror buffer added before filtering and removed again
        after the wavelet decomposition
    freqs: array_like
        Frequencies passed to the Morlet wavelet filter
    log_powers: bool
        Whether to take log10 of the powers
    filt_order: int
        Order of the 58-62 Hz band-stop filter
    width: int
        Wavelet width passed to the Morlet wavelet filter
    normalize: bool
        Whether to z-score every column of the resulting matrix (ddof=1)
    bipolar_pairs:
        Forwarded unchanged to ``load_single_session_eeg``

    Returns
    -------
    np.ndarray
        2D power matrix with one row per event of ``session`` (after any
        events were dropped while reading the EEG) and the remaining
        dimensions flattened into the columns
    np.recarray
        Events for all sessions, updated if the reader removed events

    """
    # PTSA will sometimes modify events when reading the eeg, so we ultimately
    # need to return the updated events. In case no events are removed, return
    # the original set of events
    eeg, updated_events = load_single_session_eeg(session,
                                                  all_events,
                                                  start_time,
                                                  end_time,
                                                  bipolar_pairs)

    eeg = eeg.add_mirror_buffer(buffer_time)

    # Butterworth filter to remove line noise
    eeg = eeg.filtered(freq_range=[58., 62.],
                       filt_type='stop',
                       order=filt_order)
    with timer("Total wavelet decomposition time: %f s"):
        eeg.data = np.ascontiguousarray(eeg.data)
        wavelet_filter = MorletWaveletFilter(timeseries=eeg,
                                             freqs=freqs,
                                             output='power',
                                             width=width,
                                             cpus=25)  # FIXME: why 25?
        # At this point, pow mat has dimensions: frequency, bipolar_pairs,
        # events, time
        sess_pow_mat = wavelet_filter.filter()

    # Strip the buffer and add a tiny offset so log10 never sees an exact
    # zero. ``np.float`` was merely an alias of the builtin ``float`` and has
    # been removed from NumPy, so the builtin is used directly.
    sess_pow_mat = (sess_pow_mat.remove_buffer(buffer_time).data
                    + np.finfo(float).eps / 2.)

    if log_powers:
        # In-place to avoid allocating a second full-size array
        np.log10(sess_pow_mat, out=sess_pow_mat)

    # Re-ordering dimensions to be events, electrodes, frequencies with the
    # mean calculated over the (last) time dimension
    updated_session_events = updated_events[updated_events.session == session]
    sess_pow_mat = np.nanmean(sess_pow_mat.transpose(2, 1, 0, 3), -1)
    sess_pow_mat = sess_pow_mat.reshape((len(updated_session_events), -1))

    if normalize:
        sess_pow_mat = zscore(sess_pow_mat, axis=0, ddof=1)

    return sess_pow_mat, updated_events


def _replace_session_events(all_events, session, eeg):
    """Return ``all_events`` with the rows of ``session`` replaced by the
    events attached to ``eeg``, sorted by whichever of session, list and
    mstime are present."""
    # TODO: Use the event utility functions here
    merged = np.rec.array(np.concatenate(
        (all_events[all_events.session != session],
         np.rec.array(eeg['events'].data))))
    order = tuple(f for f in ['session', 'list', 'mstime']
                  if f in merged.dtype.names)
    merged = merged[np.argsort(merged, order=order)]
    return np.rec.array(merged)


def load_single_session_eeg(session, all_events, start_time, end_time, bipolar_pairs):
    """Load the EEG for the events of one session.

    Parameters
    ----------
    session:
        Session identifier; only events whose ``session`` field equals this
        value are read
    all_events: np.recarray
        Events from all sessions
    start_time: float
        Passed to the EEG reader as ``start_time``
    end_time: float
        Passed to the EEG reader as ``end_time``
    bipolar_pairs:
        Pairs used to select monopolar channels and to map the recording to
        bipolar when the loaded channels have no named fields

    Returns
    -------
    eeg
        The loaded timeseries, mapped to bipolar pairs when the extra
        monopolar read was needed
    np.recarray
        ``all_events``, or a re-sorted copy in which this session's events
        were replaced by those kept by the reader

    """
    updated_events = all_events
    session_events = all_events[all_events.session == session]
    logger.info("Loading EEG data for session %d", session)
    eeg_reader = EEGReader(events=session_events,
                           start_time=start_time,
                           end_time=end_time,
                           )
    try:
        eeg = eeg_reader.read()
    # recording was done in bipolar mode, and the channels are different than
    # what we expect
    except IndexError:
        eeg_reader.channels = np.array([])
        eeg = eeg_reader.read()
    if eeg_reader.removed_bad_data():
        logger.warning('PTSA EEG reader elected to remove some bad events')
        updated_events = _replace_session_events(all_events, session, eeg)

    # Use bipolar pairs if they exist and recording is not already bipolar
    # NOTE(review): bipolar_pairs is not checked for None here; whether
    # extract_monopolar_from_bipolar tolerates None is not visible from this
    # module - confirm against callers.
    if eeg.channels.dtype.names is None:
        monopolar_channels = extract_monopolar_from_bipolar(bipolar_pairs)
        eeg_reader = EEGReader(events=session_events,
                               start_time=start_time,
                               end_time=end_time,
                               channels=monopolar_channels)
        eeg = eeg_reader.read()
        # Check for removal of bad data again and update events
        if eeg_reader.removed_bad_data():
            logger.warning('PTSA EEG reader elected to remove some bad events')
            updated_events = _replace_session_events(all_events, session, eeg)

        eeg = MonopolarToBipolarMapper(timeseries=eeg,
                                       bipolar_pairs=bipolar_pairs).filter()
    return eeg, updated_events


def compute_powers(events, start_time, end_time, buffer_time, freqs,
                   log_powers, filt_order=4, width=5,
                   normalize=True, bipolar_pairs=None):
    """
        Compute powers (or log powers) using a Morlet wavelet filter and
        Butterworth Filter to get rid of line noise


    Parameters
    ----------
    events: np.recarray
        Events to consider when computing powers
    start_time: float
        Start of the period in the EEG to consider for each event
    end_time: float
        End of the period to consider
    buffer_time: float
        Buffer time
    freqs: array_like
        List of frequencies to use when applying Wavelet Filter
    log_powers: bool
        Whether to take the logarithm of the powers
    filt_order: Int
        Filter order to use in Butterworth filter
    width: Int
        Wavelet width to use in Wavelet Filter
    normalize: bool
        Whether the power matrix of each session should be z-scored, feature
        by feature (i.e. per column, across that session's events)
    bipolar_pairs: OrderedDict
        OrderedDict of bipolar pairs to use if converting a monopolar EEG
        recording to bipolar recording. Anything that is not already an
        np.recarray is first converted with generate_pairs_for_ptsa

    Returns
    -------
    np.ndarray
        Calculated powers of shape n_events X (freqs * n_channels) where
        n_channels is determined when loading the EEG. None if ``events``
        contains no sessions.
    np.recarray
        Set of events after 'bad' events were removed while loading the EEG.
        Currently, removal of these events is a side effect of loading the
        EEG, so the 'cleaned' events must be caught. In an ideal world,
        this side effect would not exist and bad events would be removed prior
        to computing powers.

    """
    if (bipolar_pairs is not None) and \
            (not isinstance(bipolar_pairs, np.recarray)):
        bipolar_pairs = generate_pairs_for_ptsa(bipolar_pairs)

    sessions = np.unique(events.session)
    session_powers = []

    with timer("Total time for computing powers: %f"):
        updated_events = events.copy()
        for sess in sessions:
            # The events are threaded through every call because loading a
            # session may drop some of that session's events
            powers, updated_events = compute_single_session_powers(
                sess, updated_events, start_time, end_time, buffer_time,
                freqs, log_powers, filt_order, width, normalize,
                bipolar_pairs)
            session_powers.append(powers)

        # Stack once at the end rather than re-copying the growing matrix
        # after every session
        pow_mat = np.concatenate(session_powers) if session_powers else None

    return pow_mat, updated_events


def compute_normalized_powers(events, **kwargs):
    """ Compute powers by session, encoding/retrieval, and FR vs. PAL

    Parameters
    ----------
    events: np.recarray
        Events to partition and compute powers for
    kwargs:
        Must contain 'freqs', 'log_powers', 'filt_order', 'normalize_powers'
        and 'width', plus the start/end/buffer entries of every non-empty
        event subset. 'bipolar_pairs' is optional and defaults to None.

    Returns
    -------
    np.ndarray
        Power matrix whose rows follow the order of the returned events
    np.recarray
        Events remaining after the per-subset EEG loading

    Raises
    ------
    RuntimeError
        If an unknown non-empty event subset is encountered, or if every
        subset is empty so that no powers could be computed

    Notes
    -----
    There are different start times, end time, and buffer times for each
    subset type, so those are passed in as kwargs and looked up prior to
    calling the more general compute_powers function

    """
    # Names of the kwargs holding (start time, end time, buffer time) for
    # each known event subset
    timing_keys = {
        'fr_encoding': ('encoding_start_time',
                        'encoding_end_time',
                        'encoding_buf'),
        'fr_retrieval': ('retrieval_start_time',
                         'retrieval_end_time',
                         'retrieval_buf'),
        'pal_encoding': ('pal_start_time',
                         'pal_end_time',
                         'pal_buf_time'),
        'pal_retrieval': ('pal_retrieval_start_time',
                          'pal_retrieval_end_time',
                          'pal_retrieval_buf'),
        'post_stim': ('post_stim_start_time',
                      'post_stim_end_time',
                      'post_stim_buf'),
    }
    bipolar_pairs = kwargs.get('bipolar_pairs')

    event_partitions = partition_events(events)
    cleaned_event_partitions = []
    power_partitions = {}

    for subset_name, event_subset in event_partitions.items():
        if len(event_subset) == 0:
            continue

        if subset_name not in timing_keys:
            raise RuntimeError("Unexpected event subset was encountered")
        start_key, end_key, buffer_key = timing_keys[subset_name]

        powers, cleaned_events = compute_powers(
            event_subset,
            kwargs[start_key],
            kwargs[end_key],
            kwargs[buffer_key],
            kwargs['freqs'],
            kwargs['log_powers'],
            filt_order=kwargs['filt_order'],
            normalize=kwargs['normalize_powers'],
            width=kwargs['width'],
            bipolar_pairs=bipolar_pairs)
        cleaned_event_partitions.append(cleaned_events)
        power_partitions[subset_name] = powers

    if not power_partitions:
        # Previously this surfaced further down as an unbound local variable
        raise RuntimeError("No events were available for computing powers")

    cleaned_events = concatenate_events_for_single_experiment(
        cleaned_event_partitions)
    partition_masks = get_partition_masks(cleaned_events)

    # Ensure that the rows of the power matrix match the order of the events.
    # This works by creating masks for each of the event types from the
    # sorted events structure
    n_features = next(iter(power_partitions.values())).shape[1]
    normalized_powers = np.empty((len(cleaned_events), n_features))
    for subset_name, power_subset in power_partitions.items():
        normalized_powers[partition_masks[subset_name], :] = power_subset

    return normalized_powers, cleaned_events
def reduce_powers(powers, channel_mask, n_frequencies, frequency_mask=None):
    """ Create a subset of the full power matrix by excluding certain
    electrodes and, optionally, frequencies

    Parameters
    ----------
    powers: np.ndarray
        Original power matrix
    channel_mask: array_like
        Boolean array of size n_channels
    n_frequencies: int
        Number of frequencies used in calculating the power matrix. This is
        needed to be able to properly reshape the array
    frequency_mask: array_like
        Boolean array of size n_frequencies

    Returns
    -------
    np.ndarray
        Subsetted power matrix of shape
        n_events x (n_selected_channels * n_selected_frequencies)

    Raises
    ------
    RuntimeError
        If the length of frequency_mask differs from n_frequencies

    """
    if frequency_mask is not None and len(frequency_mask) != n_frequencies:
        raise RuntimeError("Size of frequency mask must match number of "
                           "frequencies")

    # Reshape into 3-dimensional array (n_events, n_electrodes, n_frequencies)
    reduced_powers = powers.reshape((len(powers), -1, n_frequencies))

    # Apply the masks one axis at a time. Passing both masks in a single
    # indexing expression makes NumPy pair them element-wise instead of
    # selecting every (channel, frequency) combination, which is only
    # correct when the frequency mask has exactly one True entry.
    reduced_powers = reduced_powers[:, channel_mask, :]
    if frequency_mask is not None:
        reduced_powers = reduced_powers[:, :, frequency_mask]

    # Reshape back to 2D representation so it can be used as a feature matrix
    return reduced_powers.reshape((len(reduced_powers), -1))
def get_trigger_frequency_mask(trigger_frequency, frequencies):
    """ Returns a boolean mask identifying a single frequency in a list of
    frequencies

    Parameters
    ----------
    trigger_frequency: int
        Frequency to look for
    frequencies: array_like
        Frequencies to compare against. Each value is truncated with int()
        before the comparison, so e.g. 110.7 matches a trigger of 110

    Returns
    -------
    list of bool
        One entry per element of frequencies

    """
    return [bool(int(freq) == trigger_frequency) for freq in frequencies]
def normalize_powers_by_session(pow_mat, events):
    """ z-score powers within session. Utility function used by legacy reports

    Parameters
    ----------
    pow_mat: np.ndarray
        Power matrix, i.e. the data matrix for the classifier (features).
        It is modified in place.
    events: pd.DataFrame
        Behavioral events data

    Returns
    -------
    pow_mat: np.ndarray
        Normalized power matrix (features); the same object that was passed in

    Notes
    -----
    This function can be removed once the legacy reporting pipeline is fully
    replaced since those are the only places where it is currently used

    """
    for sess in np.unique(events.session):
        sess_event_mask = (events.session == sess)
        pow_mat[sess_event_mask] = zscore(pow_mat[sess_event_mask],
                                          axis=0, ddof=1)

    return pow_mat


def reshape_powers_to_3d(powers, n_frequencies):
    """ Make power matrix a 3D structure: n_events x n_electrodes x
    n_frequencies """
    return powers.reshape((len(powers), -1, n_frequencies))


def reshape_powers_to_2d(powers):
    """ Make power matrix a 2D structure n_events x (n_electrodes x
    n_frequencies) """
    return powers.reshape((len(powers), -1))


def save_power_plot(powers, session, full_path):
    """ Plots the feature matrix to a file path or file-like object

    :param powers: power matrix; flattened to 2D before plotting
    :param session: session label used in the plot title
    :param full_path: destination handed to ``plt.savefig``
    :return: ``full_path``, unchanged
    """
    from matplotlib import pyplot as plt

    plt.imshow(reshape_powers_to_2d(powers),
               cmap='RdBu_r',
               aspect='auto',)
    # Symmetric color limits so that zero sits at the center of the colormap
    cmin, cmax = powers.min(), powers.max()
    clim = max(abs(cmin), abs(cmax))
    plt.clim(-clim, clim)
    cbar = plt.colorbar()
    cbar.ax.set_xlabel('Z-Score')
    cbar.ax.xaxis.set_label_position('top')
    plt.ylabel('Event Number')
    plt.xlabel('Feature Number')
    plt.title('Session %s' % session)
    plt.savefig(full_path,
                format="png",
                dpi=300,
                bbox_inches="tight",
                )
    plt.close()
    return full_path


def load_eeg(all_events, start_time, end_time, bipolar_pairs):
    """ Load the EEG of every session in ``all_events`` and join the raw
    data arrays along axis 1.

    Events that the reader drops while loading are not reported back.
    """
    if (bipolar_pairs is not None) and \
            (not isinstance(bipolar_pairs, np.recarray)):
        bipolar_pairs = generate_pairs_for_ptsa(bipolar_pairs)

    session_data = []
    for session in np.unique(all_events.session):
        eeg, _ = load_single_session_eeg(session, all_events, start_time,
                                         end_time, bipolar_pairs)
        session_data.append(eeg.data)

    # NOTE(review): axis 1 is presumably the events axis - confirm
    return np.concatenate(session_data, axis=1)


def save_eeg_by_channel_plot(bipolar_pairs, full_eeg, used_pair_mask=None,
                             time=None, full_path=None):
    """ Plot the EEG traces of every pair in its own subplot and save as PNG

    :param bipolar_pairs: sequence of pairs; entry i labels subplot i
    :param full_eeg: array whose first axis indexes pairs and whose last
        axis indexes time
    :param used_pair_mask: optional sequence; pairs with a falsy entry are
        highlighted in magenta
    :param time: x values for the traces; defaults to the sample index
    :param full_path: destination for ``plt.savefig``; a new ``io.BytesIO``
        is used when omitted
    :return: ``full_path`` (or the newly created buffer)
    """
    from matplotlib import pyplot as plt

    if full_path is None:
        full_path = io.BytesIO()
    if time is None:
        time = np.arange(full_eeg.shape[-1])

    n_pairs = len(bipolar_pairs)
    # Roughly square grid. The first dimension is rounded up so the grid
    # always has room for every pair (flooring it made plt.subplot reject
    # the trailing pairs), and both dimensions are kept at 1 or more.
    ylen = max(1, int(np.sqrt(full_eeg.shape[0])))
    xlen = max(1, int(np.ceil(n_pairs / float(ylen))))

    # NOTE(review): figsize is (width, height) while subplot takes
    # (rows, cols), so width follows the row count here - kept as it was.
    plt.figure(figsize=(xlen * 2, ylen * 2))
    for i in range(n_pairs):
        plt.subplot(xlen, ylen, i + 1)
        txtcolor = 'black'
        if used_pair_mask is not None and not used_pair_mask[i]:
            ax = plt.gca()
            txtcolor = 'magenta'
            for spine in ax.spines.values():
                spine.set_linewidth(6 * spine.get_linewidth())
                spine.set_edgecolor('magenta')
        plt.plot(time, full_eeg[i].squeeze().T, color='grey', alpha=0.15)
        # Wrap in a 1-tuple so that a pair given as a tuple is formatted as
        # a whole instead of being unpacked as format arguments
        plt.xlabel('%s' % (bipolar_pairs[i],), color=txtcolor)

    plt.tight_layout()
    plt.savefig(full_path, format='png', dpi=200, bbox_inches='tight')
    plt.close()
    return full_path


def _recalled_vs_not_recalled_ttest(powers, recall_mask):
    """ Column-wise independent t-test of recalled versus non-recalled rows,
    with the p-values corrected using the 'fdr_bh' method """
    tstats, pvals = ttest_ind(powers[recall_mask, :],
                              powers[~recall_mask, :],
                              axis=0)
    _, pvals, _, _ = multipletests(pvals, method='fdr_bh')
    return tstats, pvals


def calculate_delta_hfa_table(pairs_metadata_table, normalized_powers, events,
                              frequencies, hfa_cutoff=65, trigger_freq=110):
    """ Calculate tstats and pvalues from a ttest comparing HFA activity of
    recalled versus non-recalled items

    Parameters
    ----------
    pairs_metadata_table: pd.DataFrame
        Table with one row per pair; the statistics columns are added to it
        in place. Must contain a 'label' column.
    normalized_powers: np.ndarray
        2D power matrix
    events:
        Events passed to get_recall_events_mask
    frequencies: array_like
        Frequencies that were used to compute the powers
    hfa_cutoff: float
        Frequencies strictly above this value count as high frequency
    trigger_freq: int
        Single frequency to test separately; compared against the integer
        part of each frequency. The resulting columns are always named
        '110_t_stat' and '110_p_value'.

    Returns
    -------
    pd.DataFrame
        The table with columns 'hfa_t_stat', 'hfa_p_value', '110_t_stat' and
        '110_p_value' added and rows without a label removed

    """
    powers_3d = reshape_powers_to_3d(normalized_powers, len(frequencies))
    recall_mask = get_recall_events_mask(events)

    # Average powers across the high frequencies. New shape is
    # n_events x n_electrodes
    hfa_mask = [bool(freq > hfa_cutoff) for freq in frequencies]
    hfa_powers = np.nanmean(powers_3d[:, :, hfa_mask], axis=-1)
    tstats, pvals = _recalled_vs_not_recalled_ttest(hfa_powers, recall_mask)
    pairs_metadata_table['hfa_t_stat'] = tstats
    pairs_metadata_table['hfa_p_value'] = pvals

    # Repeat for 110hz. Actual frequency is a decimal, so convert to int when
    # checking for equality
    trigger_freq_mask = [bool(int(freq) == trigger_freq)
                         for freq in frequencies]
    single_freq_powers = np.nanmean(powers_3d[:, :, trigger_freq_mask],
                                    axis=-1)
    tstats, pvals = _recalled_vs_not_recalled_ttest(single_freq_powers,
                                                    recall_mask)
    pairs_metadata_table['110_t_stat'] = tstats
    pairs_metadata_table['110_p_value'] = pvals

    # Pairs that do not have a label do not need to have the stats displayed
    pairs_metadata_table = pairs_metadata_table.dropna(subset=['label'])

    return pairs_metadata_table