Classifier training, cross validation, and utilities

Reference

Collection of cross-validation functions

ramutils.classifier.cross_validation.permuted_lolo_cross_validation(classifier, powers, events, n_permutations, **kwargs)

Permuted leave-one-list-out cross validation

Parameters:
  • classifier – sklearn model object, usually a logistic regression classifier
  • powers (np.ndarray) – power matrix
  • events (np.recarray) – Set of events for the session
  • n_permutations (int) – number of permutation trials
  • kwargs (dict) – Optional keyword arguments. These are passed to get_sample_weights. See that function for more details.
Returns:

AUCs – List of AUCs from performing leave-one-list-out cross validation n_permutations times, where each AUC is computed on encoding events only

Return type:

list
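The permutation procedure can be sketched with scikit-learn as below. This is a minimal illustration, not the ramutils implementation: it assumes recall labels and list numbers are passed as plain arrays instead of an events recarray, that every row is an encoding event, and that labels are shuffled within each list (a design choice of this sketch). The function name permuted_lolo_aucs is hypothetical.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import LeaveOneGroupOut, cross_val_predict


def permuted_lolo_aucs(classifier, powers, recalls, lists, n_permutations, seed=0):
    """Return one leave-one-list-out AUC per permutation of the recall labels."""
    rng = np.random.default_rng(seed)
    recalls = np.asarray(recalls)
    aucs = []
    for _ in range(n_permutations):
        # Shuffle labels within each list so every list keeps its recall rate.
        shuffled = recalls.copy()
        for lst in np.unique(lists):
            idx = np.flatnonzero(lists == lst)
            shuffled[idx] = rng.permutation(shuffled[idx])
        # One fold per list; cross_val_predict fits a fresh clone per fold.
        probs = cross_val_predict(classifier, powers, shuffled, groups=lists,
                                  cv=LeaveOneGroupOut(), method="predict_proba")[:, 1]
        aucs.append(roc_auc_score(shuffled, probs))
    return aucs


# Toy data: 5 lists of 20 encoding events each, 4 features, no real signal.
rng = np.random.default_rng(1)
lists = np.repeat(np.arange(5), 20)
recalls = np.tile([0, 1], 50)
powers = rng.normal(size=(100, 4))
aucs = permuted_lolo_aucs(LogisticRegression(), powers, recalls, lists, n_permutations=3)
```

Because the labels carry no information after shuffling, the returned AUCs approximate a null distribution against which an unpermuted AUC can be compared.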

ramutils.classifier.cross_validation.perform_lolo_cross_validation(classifier, powers, events, recalls, **kwargs)

Perform a single iteration of leave-one-list-out cross validation

Parameters:
  • classifier (sklearn model object) – Classifier to fit and evaluate
  • powers (np.ndarray) – Mean powers to use as features
  • events (np.recarray) – Set of events for the session
  • recalls (array_like) – Vector of recall outcomes
  • kwargs (dict) – Optional keyword arguments. These are passed to get_sample_weights. See that function for more details.
Returns:

probs – Predicted probabilities for encoding events across all lists

Return type:

np.array

Notes

Be careful when passing a classifier object to this function, since its .fit() method will be called. If you use the classifier object after calling this function, its internal state may have changed. To avoid this problem, pass a copy of the classifier object to this function instead of the original.
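With scikit-learn, sklearn.base.clone returns an unfitted copy with the same hyperparameters, which is one way to follow the advice above (copy.deepcopy also works and additionally preserves any fitted state). The sketch below applies the same idea inside a single leave-one-list-out pass so that the caller's object is never fitted. It is an illustration, not the ramutils code: lolo_predict is a hypothetical name, and list numbers and recall labels are assumed to be plain arrays.

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression


def lolo_predict(classifier, powers, recalls, lists):
    """One pass of leave-one-list-out CV; returns out-of-fold recall probabilities."""
    probs = np.empty(len(recalls), dtype=float)
    for lst in np.unique(lists):
        held_out = lists == lst
        # Fit a fresh, unfitted copy so the caller's object is never modified.
        model = clone(classifier)
        model.fit(powers[~held_out], recalls[~held_out])
        probs[held_out] = model.predict_proba(powers[held_out])[:, 1]
    return probs


base = LogisticRegression(C=1.0)
rng = np.random.default_rng(0)
powers = rng.normal(size=(60, 3))
recalls = np.tile([0, 1], 30)
lists = np.repeat(np.arange(3), 20)
probs = lolo_predict(base, powers, recalls, lists)
```

After the call, base is still unfitted, so it can be reused with a clean state.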

ramutils.classifier.cross_validation.permuted_loso_cross_validation(classifier, powers, events, n_permutations, **kwargs)

Perform permuted leave-one-session-out cross validation

Parameters:
  • classifier – sklearn model object, usually a logistic regression classifier
  • powers (np.ndarray) – power matrix
  • events (np.recarray) – Set of events across all sessions
  • n_permutations (int) – number of permutation trials
  • kwargs (dict) – Optional keyword arguments. These are passed to get_sample_weights. See that function for more details.
Returns:

AUCs – List of AUCs from performing leave-one-session-out cross validation n_permutations times, where each AUC is computed on encoding events only

Return type:

list
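One common use of permuted AUCs is a permutation test: compare the unpermuted leave-one-session-out AUC against the null AUCs. The sketch below shows this with scikit-learn; it is not the ramutils implementation. It assumes session numbers and recall labels are plain arrays, shuffles labels within each session (a choice of this sketch), and uses the usual add-one p-value correction. loso_auc and loso_permutation_test are hypothetical names.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import LeaveOneGroupOut, cross_val_predict


def loso_auc(classifier, powers, recalls, sessions):
    """AUC of out-of-session predictions (one fold per held-out session)."""
    probs = cross_val_predict(classifier, powers, recalls, groups=sessions,
                              cv=LeaveOneGroupOut(), method="predict_proba")[:, 1]
    return roc_auc_score(recalls, probs)


def loso_permutation_test(classifier, powers, recalls, sessions, n_permutations, seed=0):
    rng = np.random.default_rng(seed)
    recalls = np.asarray(recalls)
    observed = loso_auc(classifier, powers, recalls, sessions)
    null_aucs = []
    for _ in range(n_permutations):
        # Shuffle labels within each session, then redo the full LOSO pass.
        shuffled = recalls.copy()
        for sess in np.unique(sessions):
            idx = np.flatnonzero(sessions == sess)
            shuffled[idx] = rng.permutation(shuffled[idx])
        null_aucs.append(loso_auc(classifier, powers, shuffled, sessions))
    # Add-one correction so the p-value is never exactly zero.
    p_value = (1 + sum(a >= observed for a in null_aucs)) / (1 + n_permutations)
    return observed, null_aucs, p_value


# Toy data: 3 sessions, one feature shifted for recalled events.
rng = np.random.default_rng(2)
sessions = np.repeat(np.arange(3), 40)
recalls = np.tile([0, 1], 60)
powers = rng.normal(size=(120, 4))
powers[:, 0] += 2.0 * recalls
observed, null_aucs, p_value = loso_permutation_test(
    LogisticRegression(), powers, recalls, sessions, n_permutations=5)
```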

ramutils.classifier.cross_validation.perform_loso_cross_validation(classifier, powers, events, recalls, **kwargs)

Perform a single iteration of leave-one-session-out cross validation

Parameters:
  • classifier – sklearn model object, usually logistic regression classifier
  • powers (np.ndarray) – power matrix
  • events (np.recarray) – Set of events across all sessions
  • recalls (array_like) – List of recall/not-recalled boolean values for each event
  • kwargs (dict) – Optional keyword arguments. These are passed to get_sample_weights. See that function for more details.
Returns:

probs – Predicted probabilities for encoding events across all sessions

Return type:

np.array
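A single leave-one-session-out pass with per-event sample weights might look like the sketch below. The real function forwards kwargs to get_sample_weights, whose behavior is not described here; as a hypothetical stand-in, this sketch uses scikit-learn's compute_sample_weight("balanced", ...), which weights each event inversely to the frequency of its class. loso_predict is a hypothetical name, and sessions and recalls are assumed to be plain arrays.

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.utils.class_weight import compute_sample_weight


def loso_predict(classifier, powers, recalls, sessions):
    """One pass of leave-one-session-out CV with class-balanced sample weights."""
    probs = np.empty(len(recalls), dtype=float)
    for train, test in LeaveOneGroupOut().split(powers, recalls, groups=sessions):
        # Hypothetical stand-in for get_sample_weights: weight each training
        # event inversely to the frequency of its class.
        weights = compute_sample_weight("balanced", recalls[train])
        model = clone(classifier)
        model.fit(powers[train], recalls[train], sample_weight=weights)
        probs[test] = model.predict_proba(powers[test])[:, 1]
    return probs


rng = np.random.default_rng(3)
sessions = np.repeat(np.arange(3), 30)
recalls = np.tile([0, 0, 1], 30)  # imbalanced: one third recalled
powers = rng.normal(size=(90, 5))
probs = loso_predict(LogisticRegression(), powers, recalls, sessions)
```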

Utility functions used during classifier training

ramutils.classifier.utils.reload_classifier(subject, task, session, mount_point='/', base_path=None)

Loads the actual classifier used by Ramulator for a particular session

Parameters:
  • subject (str) – Subject ID
  • task (str) – Task name, e.g., FR5, FR6, or PAL1
  • session (int) – Session number
  • mount_point (str, default '/') – Mount point for RHINO
  • base_path (str, optional) – Directory containing the classifier files. If None, the expected location on RHINO is used.
Returns:

classifier_container – Container holding the classifier used for the session

Return type:

classiflib.container.ClassifierContainer

ramutils.classifier.utils.save_classifier_weights_plot(weights, frequencies, pairs, file_)

Visualize the classifier weights as a function of frequency and location.

Parameters:
  • weights (np.ndarray) – Classifier weights, of length len(pairs) * len(frequencies)
  • frequencies (np.ndarray[float]) – Frequencies corresponding to the weights
  • pairs (iterable) – Iterable describing the pairs
  • file_ (str or file-like) – Output destination, which should be either a path or a file-like object
Returns:

file_ – The file_ parameter

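A weights plot of this kind can be sketched with matplotlib by reshaping the flat weight vector into a pairs by frequencies grid. This is an illustrative sketch, not the ramutils function: the feature ordering (pair-major: all frequencies for the first pair, then the next pair) is an assumption, and save_weights_heatmap is a hypothetical name. As in the documented function, file_ may be a path or a file-like object and is returned unchanged.

```python
import io

import matplotlib

matplotlib.use("Agg")  # render off-screen; no display required
import matplotlib.pyplot as plt
import numpy as np


def save_weights_heatmap(weights, frequencies, pairs, file_):
    """Draw classifier weights as a pairs x frequencies heatmap and save it."""
    # Assumption: weights are ordered pair-major (all frequencies for the
    # first pair, then all frequencies for the second pair, ...).
    grid = np.asarray(weights).reshape(len(pairs), len(frequencies))
    fig, ax = plt.subplots()
    limit = np.abs(grid).max() or 1.0  # symmetric color scale around zero
    im = ax.imshow(grid, aspect="auto", cmap="RdBu_r", vmin=-limit, vmax=limit)
    ax.set_xticks(range(len(frequencies)))
    ax.set_xticklabels([f"{f:g}" for f in frequencies])
    ax.set_yticks(range(len(pairs)))
    ax.set_yticklabels([str(p) for p in pairs])
    ax.set_xlabel("Frequency")
    ax.set_ylabel("Pair")
    fig.colorbar(im, ax=ax, label="Weight")
    fig.savefig(file_, format="png")
    plt.close(fig)
    return file_


buf = io.BytesIO()
result = save_weights_heatmap(np.linspace(-1, 1, 6), np.array([3.0, 10.0]),
                              ["A-B", "C-D", "E-F"], buf)
```

A diverging colormap centered on zero makes positive and negative weights easy to tell apart.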
ramutils.classifier.utils.train_classifier(pow_mat, events, sample_weights, penalty_param, penalty_type, solver)

Train a classifier.

Parameters:
  • pow_mat (np.ndarray) – Power matrix used as features
  • events (np.recarray) – Events used for training
  • sample_weights (np.ndarray) – Per-event sample weights
  • penalty_param (float) – Penalty parameter to use
  • penalty_type (str) – Type of penalty to use for the regularized model (e.g., L2)
  • solver (str) – Solver to use when fitting the model (e.g., liblinear)
Returns:

classifier – Trained classifier

Return type:

LogisticRegression
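A minimal sketch of such a training step with scikit-learn's LogisticRegression is shown below. It assumes penalty_param is passed through as scikit-learn's C, which is the inverse of the regularization strength, and it takes recall labels directly rather than deriving them from an events recarray. Note that scikit-learn spells the penalty in lowercase ("l2"). train_logreg is a hypothetical name.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression


def train_logreg(pow_mat, recalls, sample_weights, penalty_param, penalty_type, solver):
    """Fit a regularized logistic regression to a power matrix."""
    # Assumption: penalty_param maps to scikit-learn's C, the *inverse* of the
    # regularization strength (smaller C means a stronger penalty).
    classifier = LogisticRegression(C=penalty_param, penalty=penalty_type, solver=solver)
    classifier.fit(pow_mat, recalls, sample_weight=sample_weights)
    return classifier


rng = np.random.default_rng(4)
pow_mat = rng.normal(size=(40, 5))
recalls = np.tile([0, 1], 20)
clf = train_logreg(pow_mat, recalls, np.ones(40), penalty_param=1.0,
                   penalty_type="l2", solver="liblinear")
```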