alexandra_ai_eval.model_adjustment
Adjusting a model's configuration, to make it suitable for a task.
1"""Adjusting a model's configuration, to make it suitable for a task.""" 2 3from copy import deepcopy 4 5import torch 6import torch.nn as nn 7from torch.nn.parameter import Parameter 8from transformers.modeling_utils import PreTrainedModel 9 10from .config import ModelConfig, TaskConfig 11from .enums import Framework 12from .exceptions import InvalidEvaluation 13 14 15def adjust_model_to_task( 16 model: nn.Module, 17 model_config: ModelConfig, 18 task_config: TaskConfig, 19) -> None: 20 """Adjust the model to the task. 21 22 This ensures that the label IDs in the model are consistent with the label IDs in 23 the dataset. 24 25 If the model is a Hugging Face model and there are labels in the dataset which the 26 model has not been trained on, then the model's classification layer is extended to 27 include these labels. 28 29 Args: 30 model: 31 The model to adjust the label ids of. 32 model_config: 33 The model configuration. 34 task_config: 35 The task configuration. 36 37 Raises: 38 InvalidEvaluation: 39 If there is a gap in the indexing dictionary of the model. 40 """ 41 # Define the model's label conversion 42 model_id2label: dict | list | None 43 44 # If the model does not have label conversions, then use the defaults 45 if model_config.id2label is None: 46 model_id2label = task_config.id2label 47 48 # If the model *does* have conversions, then ensure that it can deal with all the 49 # labels in the default conversions. This ensures that we can smoothly deal with 50 # labels that the model have not been trained on (it will just always get those 51 # labels wrong) 52 else: 53 model_id2label = deepcopy(model_config.id2label) 54 55 # Collect the dataset labels and model labels in the `model_id2label` 56 # conversion list 57 for label in task_config.id2label: 58 syns = [ 59 syn 60 for lst in task_config.label_synonyms 61 for syn in lst 62 if label.upper() in lst 63 ] 64 if all([syn not in model_id2label for syn in syns]): 65 model_id2label.append(label) 66 67 # Ensure that the model_id2label does not contain duplicates modulo synonyms 68 for idx, label in enumerate(model_id2label): 69 try: 70 canonical_syn = [ 71 syn_lst 72 for syn_lst in task_config.label_synonyms 73 if label.upper() in syn_lst 74 ][0][0] 75 model_id2label[idx] = canonical_syn 76 77 # IndexError appears when the label does not appear within the 78 # label_synonyms (i.e. that we added it in the previous step). In this 79 # case, we just skip the label. 80 except IndexError: 81 continue 82 83 # Get the synonyms of all the labels, new ones included 84 new_synonyms = list(task_config.label_synonyms) 85 flat_dataset_synonyms = [ 86 syn for lst in task_config.label_synonyms for syn in lst 87 ] 88 new_synonyms += [ 89 [label.upper()] 90 for label in model_id2label 91 if label.upper() not in flat_dataset_synonyms 92 ] 93 94 # Add all the synonyms of the labels into the label2id conversion dictionary 95 model_label2id = { 96 label.upper(): id 97 for id, lbl in enumerate(model_id2label) 98 for label_syns in new_synonyms 99 for label in label_syns 100 if lbl.upper() in label_syns 101 } 102 103 # Get the old model id2label conversion 104 old_model_id2label = model_config.id2label 105 106 # Alter the model's classification layer to match the dataset if the model is 107 # missing labels. 
This only works if the model is a Hugging Face PyTorch model 108 if ( 109 len(model_id2label) > len(old_model_id2label) 110 and model_config.framework == Framework.PYTORCH 111 and isinstance(model, PreTrainedModel) 112 ): 113 alter_classification_layer( 114 model=model, 115 model_id2label=model_id2label, 116 old_model_id2label=old_model_id2label, 117 flat_dataset_synonyms=flat_dataset_synonyms, 118 dataset_num_labels=task_config.num_labels, 119 ) 120 121 # Update the label conversion in the model config 122 model_config.id2label = model_id2label 123 model_config.label2id = model_label2id 124 125 # If the model is a Hugging Face model then update the label conversions that 126 # the model thinks it has, as well as the number of labels it thinks that it 127 # has. This helps prevent errors when the model is used for evaluation. 128 if isinstance(model, PreTrainedModel): 129 model.config.num_labels = len(model_id2label) 130 model.num_labels = len(model_id2label) 131 model.config.id2label = model_id2label 132 model.config.label2id = model_label2id 133 134 135def alter_classification_layer( 136 model: PreTrainedModel, 137 model_id2label: list, 138 old_model_id2label: list, 139 flat_dataset_synonyms: list, 140 dataset_num_labels: int, 141) -> None: 142 """Alter the classification layer of the model to match the dataset. 143 144 This changes the classification layer in the finetuned model to be consistent with 145 all the labels in the dataset. If the model was previously finetuned on a dataset 146 which left out a label, say, then that label will be inserted in the model 147 architecture here, but without the model ever predicting it. This will allow the 148 model to be benchmarked on such datasets, however. 149 150 Note that this only works on classification tasks and only for transformer models. 151 This code needs to be rewritten when we add other types of tasks and model types. 152 153 Args: 154 model: 155 The model to alter the classification layer of. 156 model_id2label: 157 The model's label conversion. 158 old_model_id2label: 159 The model's old label conversion. 160 flat_dataset_synonyms: 161 The synonyms of the dataset labels. 162 dataset_num_labels: 163 The number of labels in the dataset. 164 165 Raises: 166 InvalidEvaluation: 167 If the model has not been trained on any of the labels, or synonyms 168 thereof, of if it is not a classification model. 169 """ 170 # Count the number of new labels to add to the model 171 num_new_labels = len(model_id2label) - len(old_model_id2label) 172 if num_new_labels == 0: 173 return 174 175 # If *all* the new labels are new and aren't even synonyms of the model's labels, 176 # then raise an exception 177 if num_new_labels == dataset_num_labels: 178 if len(set(flat_dataset_synonyms).intersection(old_model_id2label)) == 0: 179 raise InvalidEvaluation( 180 "The model has not been trained on any of the labels in the dataset, " 181 "or synonyms thereof." 182 ) 183 184 # Load the weights from the model's current classification layer. This handles both 185 # the token classification case and the sequence classification case. 186 # NOTE: This might need additional cases (or a general solution) when we start 187 # dealing with other tasks. 
188 try: 189 clf_weight = model.classifier.weight.data 190 use_out_proj = False 191 except AttributeError: 192 try: 193 clf_weight = model.classifier.out_proj.weight.data 194 use_out_proj = True 195 except AttributeError: 196 raise InvalidEvaluation("Model does not seem to be a classification model.") 197 198 # Create the new weights, which have zeros at all the new entries 199 zeros = torch.zeros(num_new_labels, model.config.hidden_size) 200 new_clf_weight = torch.cat((clf_weight, zeros), dim=0) 201 new_clf_weight = Parameter(new_clf_weight) 202 203 # Create the new classification layer 204 new_clf = nn.Linear(model.config.hidden_size, len(model_id2label)) 205 206 # Assign the new weights to the new classification layer, and replace the old 207 # classification layer with this one 208 new_clf.weight = new_clf_weight 209 if use_out_proj: 210 model.classifier.out_proj = new_clf 211 else: 212 model.classifier = new_clf
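The synonym handling above is easiest to see on a toy example. The sketch below is not taken from the package's documentation: the `ModelConfig` and `TaskConfig` arguments are replaced by `SimpleNamespace` stand-ins that carry only the attributes this module reads (`id2label`, `label2id`, `framework`, `label_synonyms`, `num_labels`), and the label names are made up for illustration.

```python
from types import SimpleNamespace

import torch.nn as nn

from alexandra_ai_eval.enums import Framework
from alexandra_ai_eval.model_adjustment import adjust_model_to_task

# Hypothetical model that was finetuned with abbreviated label names
model_config = SimpleNamespace(
    id2label=["NEG", "POS"],
    label2id={"NEG": 0, "POS": 1},
    framework=Framework.PYTORCH,
)

# Hypothetical task whose canonical labels list the abbreviations as synonyms
task_config = SimpleNamespace(
    id2label=["NEGATIVE", "POSITIVE"],
    label_synonyms=[["NEGATIVE", "NEG"], ["POSITIVE", "POS"]],
    num_labels=2,
)

# Any nn.Module works here; since it is not a `PreTrainedModel`, only the config
# objects are rewritten and no classification layer is touched.
model = nn.Linear(16, 2)

adjust_model_to_task(model=model, model_config=model_config, task_config=task_config)

print(model_config.id2label)  # ['NEGATIVE', 'POSITIVE']
print(model_config.label2id)  # {'NEGATIVE': 0, 'NEG': 0, 'POSITIVE': 1, 'POS': 1}
```

Both the canonical label and its synonym end up mapped to the same ID, so a model trained with a different label spelling can still be scored against the dataset's labels.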
adjust_model_to_task(model, model_config, task_config)
Adjust the model to the task.
This ensures that the label IDs in the model are consistent with the label IDs in the dataset.
If the model is a Hugging Face model and there are labels in the dataset which the model has not been trained on, then the model's classification layer is extended to include these labels.
Arguments:
- model: The model to adjust the label ids of.
- model_config: The model configuration.
- task_config: The task configuration.
Raises:
- InvalidEvaluation: If there is a gap in the indexing dictionary of the model.
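A minimal end-to-end sketch follows, assuming a Hugging Face sequence-classification checkpoint and a task containing one label the model has never seen. As in the earlier sketch, the `ModelConfig`/`TaskConfig` arguments are `SimpleNamespace` stand-ins with only the attributes the function reads; the checkpoint name and label set are illustrative choices, not part of the package.

```python
from types import SimpleNamespace

from transformers import AutoModelForSequenceClassification

from alexandra_ai_eval.enums import Framework
from alexandra_ai_eval.model_adjustment import adjust_model_to_task

# An example binary sentiment checkpoint (any sequence-classification model works)
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english"
)

# Labels the model was finetuned with (stand-in for ModelConfig)
model_config = SimpleNamespace(
    id2label=["NEGATIVE", "POSITIVE"],
    label2id={"NEGATIVE": 0, "POSITIVE": 1},
    framework=Framework.PYTORCH,
)

# A hypothetical three-label sentiment task (stand-in for TaskConfig)
task_config = SimpleNamespace(
    id2label=["NEGATIVE", "NEUTRAL", "POSITIVE"],
    label_synonyms=[["NEGATIVE"], ["NEUTRAL"], ["POSITIVE"]],
    num_labels=3,
)

adjust_model_to_task(model=model, model_config=model_config, task_config=task_config)

# The classification head now has a third row of zero weights for NEUTRAL, and the
# model's own config agrees with the dataset's label set.
print(model.classifier.out_features)  # 3
print(model.config.id2label)          # ['NEGATIVE', 'POSITIVE', 'NEUTRAL']
```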
alter_classification_layer(model, model_id2label, old_model_id2label, flat_dataset_synonyms, dataset_num_labels)
Alter the classification layer of the model to match the dataset.
This changes the classification layer in the finetuned model to be consistent with all the labels in the dataset. If the model was previously finetuned on a dataset which left out a label, say, then that label will be inserted in the model architecture here, but without the model ever predicting it. This will allow the model to be benchmarked on such datasets, however.
Note that this only works on classification tasks and only for transformer models. This code needs to be rewritten when we add other types of tasks and model types.
Arguments:
- model: The model to alter the classification layer of.
- model_id2label: The model's label conversion.
- old_model_id2label: The model's old label conversion.
- flat_dataset_synonyms: The synonyms of the dataset labels.
- dataset_num_labels: The number of labels in the dataset.
Raises:
- InvalidEvaluation: If the model has not been trained on any of the labels, or synonyms thereof, or if it is not a classification model.
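The weight surgery itself can be illustrated in isolation. The snippet below is a self-contained sketch in plain PyTorch that mirrors the steps above on a bare `nn.Linear` head rather than calling the library; the hidden size and label names are invented for the example.

```python
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

hidden_size = 768
old_labels = ["NEGATIVE", "POSITIVE"]             # labels the head was trained on
new_labels = ["NEGATIVE", "POSITIVE", "NEUTRAL"]  # labels required by the dataset
num_new_labels = len(new_labels) - len(old_labels)

old_head = nn.Linear(hidden_size, len(old_labels))

# Append one all-zero row per missing label, so each new label gets a constant logit
zeros = torch.zeros(num_new_labels, hidden_size)
new_weight = Parameter(torch.cat((old_head.weight.data, zeros), dim=0))

# Build a fresh head of the right size and drop the extended weights into it
# (as in the function above, the bias keeps the fresh layer's default initialisation)
new_head = nn.Linear(hidden_size, len(new_labels))
new_head.weight = new_weight

features = torch.randn(1, hidden_size)
print(new_head(features).shape)  # torch.Size([1, 3])
```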