Source code for ramutils.pipelines.ramulator_config

from ramutils.constants import EXPERIMENTS
from ramutils.exc import (
    MissingArgumentsError, MultistimNotAllowedException, ValidationError
)
from ramutils.montage import generate_pairs_from_electrode_config
from ramutils.tasks import *
from .hooks import PipelineCallback


@task(cache=False)
def validate_pairs(subject, ec_pairs, trigger_pairs=None):
    """Check that every requested trigger pair is present in pairs.json.

    Parameters
    ----------
    subject : str
        Subject ID, used as the top-level key into ``ec_pairs``.
    ec_pairs : OrderedDict
        Contents of pairs.json as generated from the electrode config file.
        Pair names there are written ``<anode label>-<cathode label>``.
    trigger_pairs : List or None
        Pairs to be used as triggers for PS5, written
        ``<anode label>_<cathode label>``. Nothing is checked when ``None``.

    Raises
    ------
    ValidationError
        If a trigger pair, once rewritten in hyphenated form, is not among
        the subject's pairs.

    Notes
    -----
    Misspelled anodes/cathodes already cause electrode config generation to
    fail, so trigger pairs for PS5 are the only thing verified here.

    """
    known_pairs = ec_pairs[subject]['pairs']

    if trigger_pairs is None:
        return

    for requested in trigger_pairs:
        # Trigger pairs are underscore-delimited; pairs.json uses hyphens.
        if requested.replace('_', '-') in known_pairs:
            continue
        raise ValidationError(
            "trigger pair " + requested +
            " not found in pairs.json (check for typos!)"
        )


def make_ramulator_config(subject, experiment, paths, stim_params,
                          sessions=None, exp_params=None, vispath=None,
                          extended_blanking=True, localization=0, montage=0,
                          default_surface_area=0.001, trigger_pairs=None,
                          use_common_reference=False,
                          use_classifier_excluded_leads=False,
                          pipeline_name="ramulator-conf"):
    """Generate configuration files for a Ramulator experiment.

    Parameters
    ----------
    subject : str
        Subject ID
    experiment : str
        Experiment to generate configuration file for
    paths : FilePaths
    stim_params : List[StimParameters]
        Stimulation parameters for this experiment.
    sessions : List[int]
        Sessions to include when training classifier
    exp_params : ExperimentParameters
        Parameters for the experiment. Required for experiments that train a
        classifier; unused for the no-classifier experiments.
    vispath : str
        Path to save task graph visualization to if given.
    extended_blanking : bool
        Whether to enable extended blanking on the ENS (default: True).
    localization : int
        Localization number
    montage : int
        Montage number
    default_surface_area : float
        Default surface area to set all electrodes to in mm^2. Only used if no
        area file can be found.
    trigger_pairs : List[str] or None
        Pairs to use for triggering stim in PS5 experiments.
    use_common_reference : bool
        Use a common reference in the electrode configuration instead of
        bipolar referencing.
    use_classifier_excluded_leads : bool
        Use contents of classifier_excluded_leads.txt to exclude channels from
        classifier training
    pipeline_name : str
        Name to use for status updates.

    Returns
    -------
    The path to the generated configuration zip file.

    Raises
    ------
    MultistimNotAllowedException
        If more than one set of stim parameters is given for an experiment
        not listed in ``EXPERIMENTS['multistim']``.
    MissingArgumentsError
        If ``trigger_pairs`` is omitted for a PS5 experiment, or
        ``exp_params`` is omitted for an experiment that trains a classifier.
    ValidationError
        If a trigger pair is not found in the generated pairs (raised by
        ``validate_pairs``).
    RuntimeError
        If the experiment is neither a no-classifier experiment nor an
        FR/catFR/PAL experiment.

    """
    if len(stim_params) > 1 and experiment not in EXPERIMENTS['multistim']:
        raise MultistimNotAllowedException

    if trigger_pairs is None:
        if experiment.startswith('PS5'):
            raise MissingArgumentsError("PS5 requires trigger_pairs")

        # setting to empty list for validation
        trigger_pairs = []

    anodes = [c.anode_label for c in stim_params]
    cathodes = [c.cathode_label for c in stim_params]

    # If the electrode config path is defined, load it instead of creating a
    # new one. This is useful if we want to make comparisons with old
    # referencing schemes that are not currently implemented in bptools.
    if paths.electrode_config_file is None:
        paths = generate_electrode_config(subject, paths, anodes, cathodes,
                                          localization, montage,
                                          default_surface_area,
                                          use_common_reference).compute()

    # Note: All of these pairs variables are of type OrderedDict, which is
    # crucial for preserving the initial order of the electrodes in the
    # config file
    ec_pairs = make_task(generate_pairs_from_electrode_config, subject,
                         experiment, None, paths)

    # Ignore leads identified in classifier_excluded_leads.txt
    pairs_to_exclude = stim_params
    if use_classifier_excluded_leads:
        classifier_excluded_leads = get_classifier_excluded_leads(
            subject, ec_pairs, rootdir=paths.root).compute()
        # Concatenation builds a new list, so stim_params is not mutated.
        pairs_to_exclude = pairs_to_exclude + classifier_excluded_leads

    excluded_pairs = reduce_pairs(ec_pairs, pairs_to_exclude, True)
    used_pair_mask = get_used_pair_mask(ec_pairs, excluded_pairs)
    final_pairs = generate_pairs_for_classifier(ec_pairs, excluded_pairs)

    # Ensure specified pairs exist. We have to call .compute here since no
    # other tasks depend on the output of this task.
    validate_pairs(subject, ec_pairs, trigger_pairs).compute()

    # Special case handling of no-classifier tasks
    no_classifier_experiments = EXPERIMENTS["record_only"] + [
        "AmplitudeDetermination",
        "PS5_FR",
        "PS5_CatFR",
        "LocationSearch",
    ]

    if experiment in no_classifier_experiments:
        container = None
        config_path = generate_ramulator_config(
            subject=subject,
            experiment=experiment,
            container=container,
            stim_params=stim_params,
            paths=paths,
            pairs=ec_pairs,
            excluded_pairs=excluded_pairs,
            extended_blanking=extended_blanking,
            trigger_pairs=trigger_pairs)

        with PipelineCallback(pipeline_name):
            return config_path.compute()

    if "FR" not in experiment and "PAL" not in experiment:
        # The original adjacent literals were joined without a space
        # ("currentlyimplemented"); the separator is now explicit.
        raise RuntimeError("Only PAL, FR, and catFR experiments are currently "
                           "implemented")

    # Everything below trains a classifier and needs experiment parameters;
    # fail with a clear message instead of an AttributeError on None.
    if exp_params is None:
        raise MissingArgumentsError(
            "exp_params is required for experiments that train a classifier")

    kwargs = exp_params.to_dict()

    all_task_events = build_training_data(
        subject, experiment, paths, sessions=sessions, **kwargs)

    powers, final_task_events = compute_normalized_powers(
        all_task_events, bipolar_pairs=ec_pairs, **kwargs)
    reduced_powers = reduce_powers(
        powers, used_pair_mask, len(kwargs['freqs']))

    sample_weights = get_sample_weights(final_task_events, **kwargs)

    classifier = train_classifier(reduced_powers,
                                  final_task_events,
                                  sample_weights,
                                  kwargs['C'],
                                  kwargs['penalty_type'],
                                  kwargs['solver'])

    cross_validation_results = summarize_classifier(classifier,
                                                    reduced_powers,
                                                    final_task_events,
                                                    kwargs['n_perm'],
                                                    'Trained Classifier',
                                                    **kwargs)

    container = serialize_classifier(classifier,
                                     final_pairs,
                                     reduced_powers,
                                     final_task_events,
                                     sample_weights,
                                     cross_validation_results,
                                     subject)

    config_path = generate_ramulator_config(
        subject=subject,
        experiment=experiment,
        container=container,
        stim_params=stim_params,
        paths=paths,
        pairs=ec_pairs,
        excluded_pairs=excluded_pairs,
        exp_params=exp_params,
        extended_blanking=extended_blanking)

    if vispath is not None:
        config_path.visualize(filename=vispath)

    with PipelineCallback(pipeline_name):
        return config_path.compute()