Source code for ramutils.tasks.classifier

from __future__ import division

import numpy as np
from classiflib import ClassifierContainer
from sklearn.metrics import roc_auc_score

from ramutils.classifier.cross_validation import permuted_loso_cross_validation, \
    permuted_lolo_cross_validation, perform_cross_validation
from ramutils.classifier.utils import reload_classifier
from ramutils.classifier.utils import train_classifier as train_classifier_core
from ramutils.classifier.weighting import \
    get_sample_weights as get_sample_weights_core
from ramutils.events import extract_sessions, get_nonstim_events_mask, \
    get_encoding_mask, extract_event_metadata
from ramutils.log import get_logger
from ramutils.montage import compare_recorded_with_all_pairs
from ramutils.powers import reduce_powers
from ramutils.reports.summary import ClassifierSummary
from ramutils.tasks import task


logger = get_logger()

__all__ = [
    'get_sample_weights',
    'train_classifier',
    'summarize_classifier',
    'serialize_classifier',
    'post_hoc_classifier_evaluation',
    'reload_used_classifiers'
]


@task()
def get_sample_weights(events, **kwargs):
    """Compute per-event sample weights for classifier training.

    Thin task wrapper around
    :func:`ramutils.classifier.weighting.get_sample_weights`; ``kwargs`` are
    passed through unchanged.
    """
    sample_weights = get_sample_weights_core(events, **kwargs)
    return sample_weights


@task()
def train_classifier(pow_mat, events, sample_weights, penalty_param,
                     penalty_type, solver):
    """Train a classifier from a power matrix, events, and sample weights.

    Thin task wrapper around
    :func:`ramutils.classifier.utils.train_classifier`; ``penalty_param``,
    ``penalty_type``, and ``solver`` are passed through to the underlying
    estimator.
    """
    classifier = train_classifier_core(pow_mat, events, sample_weights,
                                       penalty_param, penalty_type, solver)
    return classifier
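

# Illustrative sketch (assumption: not part of the pipeline). It shows the
# general shape of the inputs consumed by get_sample_weights and
# train_classifier, using an sklearn logistic regression as a stand-in for
# train_classifier_core; the actual estimator, penalty value, and weighting
# scheme are defined in ramutils.classifier.utils and
# ramutils.classifier.weighting.
def _example_train_classifier_sketch():
    from sklearn.linear_model import LogisticRegression

    n_events, n_pairs, n_freqs = 200, 16, 8
    rng = np.random.RandomState(0)

    # Normalized power features: one row per event, one column per
    # (bipolar pair, frequency) combination
    pow_mat = rng.normal(size=(n_events, n_pairs * n_freqs))
    # Binary recall outcome for each event
    recalls = rng.randint(0, 2, n_events)
    # Per-event sample weights (uniform here for simplicity)
    sample_weights = np.ones(n_events)

    classifier = LogisticRegression(C=1e-4,  # arbitrary penalty for the demo
                                    penalty='l2', solver='liblinear')
    classifier.fit(pow_mat, recalls, sample_weight=sample_weights)
    return classifier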


@task(cache=False)
def serialize_classifier(classifier, pairs, features, events, sample_weights,
                         classifier_summary, subject):
    """Serialize a classifier into a container object.

    Parameters
    ----------
    classifier: sklearn Estimator
        Model used during training
    pairs: array_like
        Bipolar pairs used for training
    features: np.ndarray
        Normalized power matrix used as features to the classifier
    events: np.recarray
        Set of events used for training
    sample_weights: array_like
        Weights used for each of the events
    classifier_summary: ClassifierSummary
        Object used for calculating and storing cross-validation-related
        metrics
    subject: str
        Subject identifier

    Returns
    -------
    ClassifierContainer
        Object representing all metadata associated with training a
        classifier

    """
    container = ClassifierContainer(
        classifier=classifier,
        pairs=pairs,
        features=features,
        events=events,
        sample_weight=sample_weights,
        classifier_info={
            'auc': classifier_summary.auc,
            'subject': subject
        })
    return container


@task()
def summarize_classifier(classifier, pow_mat, events, n_permutations,
                         tag='classifier', **kwargs):
    """Perform LOSO or LOLO cross validation on a classifier.

    Parameters
    ----------
    classifier : sklearn model object
        Classifier to evaluate
    pow_mat : np.ndarray
        Normalized power matrix used as the feature matrix
    events : np.recarray
        Task events corresponding to the rows of ``pow_mat``
    n_permutations: int
        Number of permutations to use in the permutation test
    tag: str
        Tag to assign the resulting classifier summary (default:
        ``'classifier'``)
    kwargs: dict
        Extra keyword arguments that are passed to get_sample_weights. See
        that function for more details

    Returns
    -------
    classifier_summary : ClassifierSummary
        Results of cross validation as a summary object

    """
    recalls = events.recalled
    encoding_event_mask = get_encoding_mask(events)
    encoding_recalls = recalls[encoding_event_mask]

    # Run leave-one-session-out cross validation when we have > 1 session,
    # otherwise leave-one-list-out
    subject, experiment, sessions = extract_event_metadata(events)
    permuted_auc_values, probs = perform_cross_validation(
        classifier, events, n_permutations, pow_mat, recalls, sessions,
        **kwargs)

    classifier_summary = ClassifierSummary()
    classifier_summary.populate(subject, experiment, sessions,
                                encoding_recalls, probs, permuted_auc_values,
                                frequencies=kwargs.get('freqs'),
                                pairs=kwargs.get('pairs'),
                                tag=tag,
                                features=pow_mat,
                                coefficients=classifier.coef_)

    logger.info("Permutation test p-value = %f", classifier_summary.pvalue)

    recall_prob = classifier.predict_proba(pow_mat)[:, 1]
    insample_auc = roc_auc_score(recalls, recall_prob)
    logger.info("in-sample AUC = %f", insample_auc)

    return classifier_summary
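

# Illustrative sketch (assumption: this is not ramutils' implementation). It
# shows the idea behind the permutation test reported above: the observed AUC
# is compared against AUCs obtained after shuffling the recall labels, and
# the p-value is the fraction of permuted AUCs that meet or exceed the
# observed one. The actual procedure lives in
# ramutils.classifier.cross_validation and additionally performs
# leave-one-session-out / leave-one-list-out splits.
def _example_permutation_test_sketch(probs, recalls, n_permutations=200):
    rng = np.random.RandomState(0)

    # AUC of the classifier output against the true recall labels
    true_auc = roc_auc_score(recalls, probs)

    # Null distribution: AUCs against randomly permuted labels
    permuted_aucs = np.array([
        roc_auc_score(rng.permutation(recalls), probs)
        for _ in range(n_permutations)
    ])

    # One-sided p-value: how often chance does at least as well
    p_value = np.mean(permuted_aucs >= true_auc)
    return true_auc, permuted_aucs, p_value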


@task()
def reload_used_classifiers(subject, experiment, events, root):
    """Reload the actual classifiers used in each session of an experiment.

    Parameters
    ----------
    subject: str
        Subject identifier
    experiment: str
        Name of the experiment
    events: np.recarray
        Task events; the set of sessions to reload classifiers for is
        extracted from these events
    root: str
        Base path of where to find RHINO files

    Returns
    -------
    list
        List of ClassifierContainer objects of length n_sessions

    Notes
    -----
    If a classifier is not found or cannot be reloaded (legacy storage
    format, or other issues), then the list of ClassifierContainer objects
    will have None as the entry for that session.

    """
    used_classifiers = []
    sessions = extract_sessions(events)
    for session in sessions:
        classifier = reload_classifier(subject, experiment, session, root)
        used_classifiers.append(classifier)

    return used_classifiers
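

# Illustrative sketch (assumption: the helper name is hypothetical). It shows
# how the output of reload_used_classifiers is typically paired with a
# retrained classifier, mirroring the fallback logic inside
# post_hoc_classifier_evaluation: sessions whose classifier could not be
# reloaded (None entries) fall back to the retrained container.
def _example_reload_with_fallback_sketch(used_classifiers,
                                         retrained_classifier):
    resolved = []
    for container in used_classifiers:
        if container is None:
            # Reload failed (missing file, legacy format, ...); fall back
            resolved.append(retrained_classifier)
        else:
            resolved.append(container)
    return resolved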


@task()
def post_hoc_classifier_evaluation(events, powers, all_pairs, classifiers,
                                   n_permutations, retrained_classifier,
                                   use_retrained=False,
                                   post_stim_events=None,
                                   post_stim_powers=None,
                                   **kwargs):
    """Evaluate a trained classifier.

    Parameters
    ----------
    events: np.recarray
        Task events associated with the stim session to be evaluated
    powers: np.ndarray
        Normalized mean powers
    all_pairs: OrderedDict
        All pairs based on recorded electrodes combined from the config file
    classifiers: list
        List of classifiers corresponding to each session
    n_permutations: int
        Number of permutations to use for cross validation
    retrained_classifier: classiflib.container.ClassifierContainer
        Classifier container object based on a retrained classifier
    use_retrained: bool (default False)
        Indicates if the retrained classifier should be used over the actual
        classifier for the purpose of evaluation
    post_stim_events: np.recarray or None
        Post-stimulation events associated with the stim session to be
        evaluated (e.g. for FR2); may be None if no post-stimulation events
        are available
    post_stim_powers: np.ndarray or None
        Normalized mean powers for post_stim period events

    Returns
    -------
    dict
        A dictionary of summary objects that are needed in subsequent parts
        of the processing pipeline. The dictionary will be in the following
        format::

            {
                'cross_session_summary': MultiSessionClassifierSummary,
                'classifier_summaries': list of ClassifierSummary objects,
                'encoding_classifier_summaries': list of ClassifierSummary
                    objects built using all encoding events,
                'post_stim_predicted_probs': classifier output during the
                    post stim period
            }

    Notes
    -----
    Different channels could be excluded based on results of artifact
    detection and stim parameters. Extract the used pairs from the
    serialized classifier that was used/retrained in order to correctly
    assess the classifier.

    The default behavior is to use the retrained classifier for any sessions
    where the actual classifier was not found or was unable to be loaded.
    Legacy-formatted classifiers are not supported for re-loading.

    In cases where a stim session was restarted, the default behavior is to
    use the original classifier (i.e. the classifier before artifact
    detection) rather than trying to guess which classifier to load.
""" sessions = extract_sessions(events) if len(sessions) != len(classifiers): raise RuntimeError('The number of sessions for evaluation must match ' 'the number of classifiers') if (any([classifier is None for classifier in classifiers]) and retrained_classifier is None): raise RuntimeError('A retrained classifier must be passed if any ' 'sessions have missing classifiers') recalls = events.recalled if post_stim_events is not None: post_stim_recalls = post_stim_events.recalled # Masks for encoding events encoding_mask = get_encoding_mask(events) # This takes care of sub-setting events to encoding non-stim events non_stim_mask = get_nonstim_events_mask(events) non_stim_recalls = recalls[non_stim_mask] classifier_summaries = [] encoding_classifier_summaries = [] predicted_probs = [] post_stim_predicted_probs = [] for i, session in enumerate(sessions): classifier_summary = ClassifierSummary() reloaded = True # Be sure to work with a copy of the classifier object because it will # be re-fit as part of the lolo cross validation and if you pass # a reference, the AUCs will be wacky if (classifiers[i] is None) or (use_retrained): classifier_container = retrained_classifier reloaded = False logger.info( "Using the retrained classifier for session {}".format(session)) else: classifier_container = classifiers[i] logger.info( "Using actual classifier for session {}".format(session)) classifier = classifier_container.classifier recorded_pairs = classifier_container.pairs used_mask = compare_recorded_with_all_pairs(all_pairs, recorded_pairs) session_mask = (events.session == session) session_events = events[(session_mask & non_stim_mask)] session_recalls = recalls[session_mask & non_stim_mask] session_powers = powers[(session_mask & non_stim_mask)] reduced_session_powers = reduce_powers(session_powers, used_mask, len(kwargs['freqs'])) # Manually pass in the weighting scheme here, otherwise the cross # validation procedures will try to determine it for you permuted_auc_values = permuted_lolo_cross_validation(classifier, reduced_session_powers, session_events, n_permutations, scheme='EQUAL', **kwargs) session_probs = classifier.predict_proba(reduced_session_powers)[:, 1] predicted_probs.append(session_probs) # Calculate classifier outputs during the post stim period. This is # used downstream in the reports to see if stimulation affected the # biomarker if post_stim_events is not None: post_stim_session_mask = (post_stim_events.session == session) post_stim_session_powers = post_stim_powers[post_stim_session_mask] post_stim_reduced_session_powers = reduce_powers( post_stim_session_powers, used_mask, len(kwargs['freqs'])) post_stim_probs = classifier.predict_proba( post_stim_reduced_session_powers)[:, 1] post_stim_predicted_probs.append(post_stim_probs) subject, experiment, sessions = extract_event_metadata(session_events) # This is the primary classifier used for evaluation. It is based on # assessing classifier output for non-stim encoding events classifier_summary.populate(subject, experiment, sessions, session_recalls, session_probs, permuted_auc_values, frequencies=classifier_container.frequencies, pairs=kwargs['pairs'][used_mask], tag='session_' + str(session), reloaded=reloaded, features=reduced_session_powers, coefficients=classifier.coef_) classifier_summaries.append(classifier_summary) logger.info('AUC for session {}: {}'.format(session, classifier_summary.auc)) # Get a classifier summary for all encoding events. 
This classifier # is needed in order to match all encoding events to stim information # in a later step session_encoding_powers = powers[(session_mask & encoding_mask)] reduced_session_encoding_powers = reduce_powers(session_encoding_powers, used_mask, len(kwargs['freqs'])) session_encoding_probs = classifier.predict_proba( reduced_session_encoding_powers)[:, 1] session_encoding_recalls = recalls[session_mask & encoding_mask] encoding_classifier_summary = ClassifierSummary() encoding_classifier_summary.populate(subject, experiment, sessions, session_encoding_recalls, session_encoding_probs, permuted_auc_values=None, frequencies=classifier_container.frequencies, pairs=kwargs['pairs'], tag='encoding_evaluation', features=reduced_session_encoding_powers, coefficients=classifier.coef_) encoding_classifier_summaries.append(encoding_classifier_summary) # Combine session-specific predicted probabilities into 1D array all_predicted_probs = np.array(predicted_probs).flatten() if len(sessions) > 1: permuted_auc_values = permuted_loso_cross_validation( retrained_classifier.classifier, powers, events, n_permutations, scheme='EQUAL', **kwargs) subject, experiment, sessions = extract_event_metadata(events) cross_session_summary = ClassifierSummary() classifier_ = retrained_classifier.classifier if retrained_classifier else classifier cross_session_summary.populate(subject, experiment, sessions, non_stim_recalls, all_predicted_probs, permuted_auc_values, coefficients=classifier_.coef_, frequencies=classifier_container.frequencies, pairs=kwargs['pairs'], tag='Combined Sessions', reloaded=False, features=classifier_container.features) # Leave commented out until we have a way to do multi-stim-session # evaluation, otherwise this classifier is just redundant. # classifier_summaries.append(cross_session_summary) logger.info("Combined AUC: {}".format(cross_session_summary.auc)) result_dict = { 'cross_session_summary': cross_session_summary, 'classifier_summaries': classifier_summaries, 'encoding_classifier_summaries': encoding_classifier_summaries, 'post_stim_predicted_probs': post_stim_predicted_probs } return result_dict
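

# Illustrative sketch (assumption: this is not the reduce_powers
# implementation, which lives in ramutils.powers, and it assumes a pair-major
# feature layout). It shows the reshaping idea used above: the feature matrix
# is organized as (events, pairs x frequencies), so restricting evaluation to
# the pairs the classifier was actually trained on amounts to masking along
# the pair axis and flattening back to 2D.
def _example_reduce_powers_sketch(powers, used_pair_mask, n_frequencies):
    n_events = powers.shape[0]

    # View the flattened feature matrix as (events, pairs, frequencies)
    by_pair = powers.reshape((n_events, -1, n_frequencies))

    # Keep only the pairs marked True in the boolean mask
    reduced = by_pair[:, used_pair_mask, :]

    # Flatten back to the 2D shape expected by the classifier
    return reduced.reshape((n_events, -1))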