# Source code for ramutils.classifier.utils

""" Utility functions used during classifier training """

import os
import numpy as np

from glob import glob
from sklearn.linear_model.logistic import LogisticRegression

from classiflib.container import ClassifierContainer
from ramutils.exc import UnableToReloadClassifierException

def reload_classifier(subject, task, session, mount_point='/', base_path=None):
    """Load the classifier container that Ramulator saved for a session.

    The session's ``host_pc`` folder is searched for
    ``<run directory>/config_files`` directories. If any of them contains a
    ``retrained_classifier`` entry, the last one (in sorted path order) is
    searched for the classifier archive; otherwise the last ``config_files``
    directory (in sorted path order) is searched directly.

    Parameters
    ----------
    subject: str
        Subject ID
    task: str
        ex: FR5, FR6, PAL1, etc
    session: int
        Session number
    mount_point: str, default '/'
        Mount point for RHINO
    base_path: str
        Location of where the classifier files can be found. If None, default
        is to look in the expected location on RHINO

    Returns
    -------
    classifier_container: classiflib.container.ClassifierContainer or None
        The loaded container, or None when no ``config_files`` directory
        exists or when the selected directory holds no ``*classifier*.zip``
        file (likely because the classifier is the old .pkl version).

    """
    if base_path is None:
        base_path = os.path.join(
            mount_point, 'data', 'eeg', subject, 'behavioral', task,
            'session_{}'.format(str(session)), 'host_pc')

    config_paths = glob(os.path.join(base_path, '*', 'config_files'))
    if not config_paths:
        # Expected host_pc folder does not exist.
        return None

    # FIXME: this needs a data quality check to confirm that all classifiers
    # in a session are the same!
    # We take the final (sorted) directory because in principle retrained
    # classifiers can be different depending on artifact detection. In
    # reality, stim sessions should never be restarted (apart from issues
    # getting things started in the first place).
    retrained_paths = glob(os.path.join(
        base_path, '*', 'config_files', 'retrained_classifier'))
    if retrained_paths:
        search_dir = sorted(retrained_paths)[-1]
    else:
        search_dir = sorted(config_paths)[-1]

    archives = glob(os.path.join(search_dir, '*classifier*.zip'))
    if not archives:
        # No container was found. Returning None here (instead of indexing
        # into an empty list) keeps every "nothing to load" case consistent.
        return None

    return ClassifierContainer.load(archives[0])
def train_classifier(pow_mat, events, sample_weights, penalty_param,
                     penalty_type, solver):
    """Train a classifier.

    Parameters
    ----------
    pow_mat : np.ndarray
        Feature matrix passed to ``fit`` as the training data
        (presumably one row per event -- confirm against the caller).
    events : np.recarray
        Events whose ``recalled`` field supplies the class labels.
    sample_weights : np.ndarray
        Per-sample weights forwarded to ``fit``.
    penalty_param: Float
        Penalty parameter to use
    penalty_type: str
        Type of penalty to use for regularized model (ex: L2)
    solver: str
        Solver to use when fitting the model (ex: liblinear)

    Returns
    -------
    classifier : LogisticRegression
        Trained classifier

    """
    recalls = events.recalled
    classifier = LogisticRegression(C=penalty_param,
                                    penalty=penalty_type,
                                    solver=solver,
                                    class_weight='balanced')
    # Fit on the power matrix; without this call the returned model would be
    # untrained and pow_mat/sample_weights would be ignored.
    classifier.fit(pow_mat, recalls, sample_weight=sample_weights)
    return classifier
# TODO: group and average classifier weights by brain region
def save_classifier_weights_plot(weights, frequencies, pairs, file_):
    """Visualize the classifier weights as a function of frequency and
    location.

    The weights are reshaped to ``(len(frequencies), len(pairs))`` and drawn
    as an image with one row per frequency, using a diverging colormap whose
    limits are symmetric about zero.

    Parameters
    ----------
    weights : np.ndarray
        Array with ``len(pairs) * len(frequencies)`` elements.
        NOTE(review): the reshape assumes the flattened weights are ordered
        frequency-major -- confirm against how the features were built.
    frequencies : np.ndarray[float]
        Frequencies used to label the rows (rounded to the nearest integer).
    pairs : iterable
        Iterable describing the pairs; only its length is used here.
    file_ : str or file-like
        The destination of the classifier weight plot, which should be either
        a path or a file-like object.

    Returns
    -------
    file_ : string or file-like
        The file_ parameter

    """
    # pyplot is imported here rather than at the top of the module to avoid
    # any problems setting the matplotlib backend elsewhere.
    from matplotlib import pyplot as plt

    weights_by_channel = weights.reshape((len(frequencies), len(pairs)))

    # Symmetric color limits keep zero at the center of the diverging map.
    clim = max(abs(weights.min()), abs(weights.max()))
    new_labels = ['%d' % (np.rint(f).astype(int)) for f in frequencies]

    try:
        plt.imshow(weights_by_channel, aspect='auto', origin='lower',
                   cmap='RdBu')
        plt.clim(-clim, clim)
        plt.colorbar()

        # Row i of the image is frequency i, so place exactly one tick per
        # row instead of relying on how many automatic ticks matplotlib
        # happened to choose.
        plt.yticks(np.arange(len(frequencies)), new_labels)

        plt.savefig(file_, format="png", dpi=300, bbox_inches="tight",
                    pad_inches=0.1)
    finally:
        # Always release the figure so a failed save does not leak it into
        # the next plot drawn through pyplot.
        plt.close()

    return file_