alexandra_ai_eval.model_adjustment

Adjusting a model's configuration, to make it suitable for a task.

  1"""Adjusting a model's configuration, to make it suitable for a task."""
  2
  3from copy import deepcopy
  4
  5import torch
  6import torch.nn as nn
  7from torch.nn.parameter import Parameter
  8from transformers.modeling_utils import PreTrainedModel
  9
 10from .config import ModelConfig, TaskConfig
 11from .enums import Framework
 12from .exceptions import InvalidEvaluation
 13
 14
 15def adjust_model_to_task(
 16    model: nn.Module,
 17    model_config: ModelConfig,
 18    task_config: TaskConfig,
 19) -> None:
 20    """Adjust the model to the task.
 21
 22    This ensures that the label IDs in the model are consistent with the label IDs in
 23    the dataset.
 24
 25    If the model is a Hugging Face model and there are labels in the dataset which the
 26    model has not been trained on, then the model's classification layer is extended to
 27    include these labels.
 28
 29    Args:
 30        model:
 31            The model to adjust the label ids of.
 32        model_config:
 33            The model configuration.
 34        task_config:
 35            The task configuration.
 36
 37    Raises:
 38        InvalidEvaluation:
 39            If there is a gap in the indexing dictionary of the model.
 40    """
 41    # Define the model's label conversion
 42    model_id2label: dict | list | None
 43
 44    # If the model does not have label conversions, then use the defaults
 45    if model_config.id2label is None:
 46        model_id2label = task_config.id2label
 47
 48    # If the model *does* have conversions, then ensure that it can deal with all the
 49    # labels in the default conversions. This ensures that we can smoothly deal with
 50    # labels that the model have not been trained on (it will just always get those
 51    # labels wrong)
 52    else:
 53        model_id2label = deepcopy(model_config.id2label)
 54
 55        # Collect the dataset labels and model labels in the `model_id2label`
 56        # conversion list
 57        for label in task_config.id2label:
 58            syns = [
 59                syn
 60                for lst in task_config.label_synonyms
 61                for syn in lst
 62                if label.upper() in lst
 63            ]
 64            if all([syn not in model_id2label for syn in syns]):
 65                model_id2label.append(label)
 66
 67        # Ensure that the model_id2label does not contain duplicates modulo synonyms
 68        for idx, label in enumerate(model_id2label):
 69            try:
 70                canonical_syn = [
 71                    syn_lst
 72                    for syn_lst in task_config.label_synonyms
 73                    if label.upper() in syn_lst
 74                ][0][0]
 75                model_id2label[idx] = canonical_syn
 76
 77            # IndexError appears when the label does not appear within the
 78            # label_synonyms (i.e. that we added it in the previous step). In this
 79            # case, we just skip the label.
 80            except IndexError:
 81                continue
 82
 83        # Get the synonyms of all the labels, new ones included
 84        new_synonyms = list(task_config.label_synonyms)
 85        flat_dataset_synonyms = [
 86            syn for lst in task_config.label_synonyms for syn in lst
 87        ]
 88        new_synonyms += [
 89            [label.upper()]
 90            for label in model_id2label
 91            if label.upper() not in flat_dataset_synonyms
 92        ]
 93
 94        # Add all the synonyms of the labels into the label2id conversion dictionary
 95        model_label2id = {
 96            label.upper(): id
 97            for id, lbl in enumerate(model_id2label)
 98            for label_syns in new_synonyms
 99            for label in label_syns
100            if lbl.upper() in label_syns
101        }
102
103        # Get the old model id2label conversion
104        old_model_id2label = model_config.id2label
105
106        # Alter the model's classification layer to match the dataset if the model is
107        # missing labels. This only works if the model is a Hugging Face PyTorch model
108        if (
109            len(model_id2label) > len(old_model_id2label)
110            and model_config.framework == Framework.PYTORCH
111            and isinstance(model, PreTrainedModel)
112        ):
113            alter_classification_layer(
114                model=model,
115                model_id2label=model_id2label,
116                old_model_id2label=old_model_id2label,
117                flat_dataset_synonyms=flat_dataset_synonyms,
118                dataset_num_labels=task_config.num_labels,
119            )
120
121        # Update the label conversion in the model config
122        model_config.id2label = model_id2label
123        model_config.label2id = model_label2id
124
125        # If the model is a Hugging Face model then update the label conversions that
126        # the model thinks it has, as well as the number of labels it thinks that it
127        # has. This helps prevent errors when the model is used for evaluation.
128        if isinstance(model, PreTrainedModel):
129            model.config.num_labels = len(model_id2label)
130            model.num_labels = len(model_id2label)
131            model.config.id2label = model_id2label
132            model.config.label2id = model_label2id
133
134
135def alter_classification_layer(
136    model: PreTrainedModel,
137    model_id2label: list,
138    old_model_id2label: list,
139    flat_dataset_synonyms: list,
140    dataset_num_labels: int,
141) -> None:
142    """Alter the classification layer of the model to match the dataset.
143
144    This changes the classification layer in the finetuned model to be consistent with
145    all the labels in the dataset. If the model was previously finetuned on a dataset
146    which left out a label, say, then that label will be inserted in the model
147    architecture here, but without the model ever predicting it. This will allow the
148    model to be benchmarked on such datasets, however.
149
150    Note that this only works on classification tasks and only for transformer models.
151    This code needs to be rewritten when we add other types of tasks and model types.
152
153    Args:
154        model:
155            The model to alter the classification layer of.
156        model_id2label:
157            The model's label conversion.
158        old_model_id2label:
159            The model's old label conversion.
160        flat_dataset_synonyms:
161            The synonyms of the dataset labels.
162        dataset_num_labels:
163            The number of labels in the dataset.
164
165    Raises:
166        InvalidEvaluation:
167            If the model has not been trained on any of the labels, or synonyms
168            thereof, of if it is not a classification model.
169    """
170    # Count the number of new labels to add to the model
171    num_new_labels = len(model_id2label) - len(old_model_id2label)
172    if num_new_labels == 0:
173        return
174
175    # If *all* the new labels are new and aren't even synonyms of the model's labels,
176    # then raise an exception
177    if num_new_labels == dataset_num_labels:
178        if len(set(flat_dataset_synonyms).intersection(old_model_id2label)) == 0:
179            raise InvalidEvaluation(
180                "The model has not been trained on any of the labels in the dataset, "
181                "or synonyms thereof."
182            )
183
184    # Load the weights from the model's current classification layer. This handles both
185    # the token classification case and the sequence classification case.
186    # NOTE: This might need additional cases (or a general solution) when we start
187    # dealing with other tasks.
188    try:
189        clf_weight = model.classifier.weight.data
190        use_out_proj = False
191    except AttributeError:
192        try:
193            clf_weight = model.classifier.out_proj.weight.data
194            use_out_proj = True
195        except AttributeError:
196            raise InvalidEvaluation("Model does not seem to be a classification model.")
197
198    # Create the new weights, which have zeros at all the new entries
199    zeros = torch.zeros(num_new_labels, model.config.hidden_size)
200    new_clf_weight = torch.cat((clf_weight, zeros), dim=0)
201    new_clf_weight = Parameter(new_clf_weight)
202
203    # Create the new classification layer
204    new_clf = nn.Linear(model.config.hidden_size, len(model_id2label))
205
206    # Assign the new weights to the new classification layer, and replace the old
207    # classification layer with this one
208    new_clf.weight = new_clf_weight
209    if use_out_proj:
210        model.classifier.out_proj = new_clf
211    else:
212        model.classifier = new_clf
def adjust_model_to_task(model: torch.nn.modules.module.Module, model_config: alexandra_ai_eval.config.ModelConfig, task_config: alexandra_ai_eval.config.TaskConfig) -> None:

Adjust the model to the task.

This ensures that the label IDs in the model are consistent with the label IDs in the dataset.

If the model is a Hugging Face model and there are labels in the dataset which the model has not been trained on, then the model's classification layer is extended to include these labels.

Arguments:
  • model: The model to adjust the label IDs of.
  • model_config: The model configuration.
  • task_config: The task configuration.
Raises:
  • InvalidEvaluation: If there is a gap in the indexing dictionary of the model.
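
As a rough illustration of the label reconciliation described above, the sketch below mimics the first two steps of adjust_model_to_task on hand-written stand-ins. The SimpleNamespace objects only mimic the attributes the function reads; they are not the real ModelConfig and TaskConfig classes, whose constructors are not documented on this page.

    from copy import deepcopy
    from types import SimpleNamespace

    # Stand-ins (assumptions for illustration, not the real config classes).
    task_config = SimpleNamespace(
        id2label=["negative", "neutral", "positive"],
        label_synonyms=[["NEGATIVE", "NEG"], ["NEUTRAL", "NEU"], ["POSITIVE", "POS"]],
    )
    model_config = SimpleNamespace(id2label=["NEG", "POS"])  # finetuned without "neutral"

    model_id2label = deepcopy(model_config.id2label)

    # Step 1: append dataset labels for which the model has no synonym at all.
    for label in task_config.id2label:
        syns = [
            syn
            for lst in task_config.label_synonyms
            for syn in lst
            if label.upper() in lst
        ]
        if all(syn not in model_id2label for syn in syns):
            model_id2label.append(label)

    # Step 2: replace each label with the canonical (first) synonym in its group;
    # the module does the same thing via a try/except IndexError.
    for idx, label in enumerate(model_id2label):
        groups = [lst for lst in task_config.label_synonyms if label.upper() in lst]
        if groups:
            model_id2label[idx] = groups[0][0]

    print(model_id2label)  # ['NEGATIVE', 'POSITIVE', 'NEUTRAL']

The model keeps its original label order and simply gains a slot for the unseen "neutral" class; extending the classification layer for that slot is handled by alter_classification_layer below.
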
def alter_classification_layer(model: transformers.modeling_utils.PreTrainedModel, model_id2label: list, old_model_id2label: list, flat_dataset_synonyms: list, dataset_num_labels: int) -> None:

Alter the classification layer of the model to match the dataset.

This changes the classification layer in the finetuned model to be consistent with all the labels in the dataset. If the model was previously finetuned on a dataset which left out a label, say, then that label will be inserted in the model architecture here, but without the model ever predicting it. This will allow the model to be benchmarked on such datasets, however.

Note that this only works on classification tasks and only for transformer models. This code needs to be rewritten when we add other types of tasks and model types.

Arguments:
  • model: The model to alter the classification layer of.
  • model_id2label: The model's label conversion.
  • old_model_id2label: The model's old label conversion.
  • flat_dataset_synonyms: The synonyms of the dataset labels.
  • dataset_num_labels: The number of labels in the dataset.
Raises:
  • InvalidEvaluation: If the model has not been trained on any of the labels, or synonyms thereof, or if it is not a classification model.
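
The weight-extension trick can also be illustrated in isolation. The sketch below applies the same steps to a plain nn.Linear standing in for model.classifier (or model.classifier.out_proj); the hidden size and label counts are made up for illustration:

    import torch
    import torch.nn as nn
    from torch.nn.parameter import Parameter

    hidden_size = 8     # made-up hidden size
    old_num_labels = 2  # labels the model was finetuned on
    new_num_labels = 3  # labels in the evaluation dataset

    # Stand-in for the model's classification head.
    old_clf = nn.Linear(hidden_size, old_num_labels)

    # Append zero-filled weight rows for the labels the model has never seen.
    zeros = torch.zeros(new_num_labels - old_num_labels, hidden_size)
    new_weight = Parameter(torch.cat((old_clf.weight.data, zeros), dim=0))

    # Build a wider classification layer and copy the extended weights into it.
    new_clf = nn.Linear(hidden_size, new_num_labels)
    new_clf.weight = new_weight

    logits = new_clf(torch.randn(1, hidden_size))
    print(logits.shape)  # torch.Size([1, 3])

As in the module's code, only the weight matrix is carried over: the replacement layer's bias is freshly initialized, so the logits for the new labels are small but not exactly zero.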