alexandra_ai_eval.question_answering
Class for question-answering tasks.
1"""Class for question-answering tasks.""" 2 3from collections import defaultdict 4 5import numpy as np 6from datasets.arrow_dataset import Dataset 7from transformers.data.data_collator import DataCollator, default_data_collator 8from transformers.models.auto.processing_auto import AutoProcessor 9from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase 10 11from .config import ModelConfig, TaskConfig 12from .exceptions import FrameworkCannotHandleTask 13from .task import Task 14 15 16class QuestionAnswering(Task): 17 """Question answering task. 18 19 Args: 20 task_config: 21 The configuration of the task. 22 evaluation_config: 23 The configuration of the evaluation. 24 25 Attributes: 26 task_config: 27 The configuration of the task. 28 evaluation_config: 29 The configuration of the evaluation. 30 """ 31 32 def _pytorch_preprocess_fn( 33 self, 34 examples: BatchEncoding, 35 tokenizer: PreTrainedTokenizerBase, 36 model_config: ModelConfig, 37 task_config: TaskConfig, 38 ) -> BatchEncoding: 39 return prepare_test_examples( 40 examples=examples, 41 tokenizer=tokenizer, 42 ) 43 44 def _load_data_collator( 45 self, tokenizer_or_processor: PreTrainedTokenizerBase | AutoProcessor 46 ) -> DataCollator: 47 return default_data_collator 48 49 def _prepare_predictions_and_labels( 50 self, 51 predictions: list, 52 dataset: Dataset, 53 prepared_dataset: Dataset, 54 **kwargs, 55 ) -> list[tuple[list, list]]: 56 predictions = postprocess_predictions( 57 predictions=predictions, 58 dataset=dataset, 59 prepared_dataset=prepared_dataset, 60 cls_token_index=kwargs["cls_token_index"], 61 ) 62 labels = postprocess_labels(dataset=dataset) 63 64 return [(predictions, labels)] 65 66 def _check_if_model_is_trained_for_task(self, model_predictions: list) -> bool: 67 sample_preds = model_predictions[0] 68 elements_are_pairs = len(sample_preds[0]) == 2 69 leaves_are_floats = sample_preds[0][0].dtype.kind == "f" 70 elements_are_strings = isinstance(sample_preds[0], str) 71 return (elements_are_pairs and leaves_are_floats) or elements_are_strings 72 73 def _spacy_preprocess_fn(self, examples: dict) -> dict: 74 raise FrameworkCannotHandleTask( 75 framework="spaCy", task=self.task_config.pretty_name 76 ) 77 78 def _extract_spacy_predictions(self, tokens_processed: tuple) -> list: 79 raise FrameworkCannotHandleTask( 80 framework="spaCy", task=self.task_config.pretty_name 81 ) 82 83 84def prepare_test_examples( 85 examples: BatchEncoding, 86 tokenizer: PreTrainedTokenizerBase, 87) -> BatchEncoding: 88 """Prepare test examples. 89 90 Args: 91 examples: 92 Dictionary of test examples. 93 tokenizer: 94 The tokenizer used to preprocess the examples. 95 96 Returns: 97 Dictionary of prepared test examples. 98 """ 99 # Some of the questions have lots of whitespace on the left, which is not useful 100 # and will make the truncation of the context fail (the tokenized question will 101 # take a lots of space). So we remove that left whitespace 102 examples["question"] = [q.lstrip() for q in examples["question"]] 103 104 # Compute the stride, being a quarter of the context length 105 stride = tokenizer.model_max_length // 4 106 max_length = tokenizer.model_max_length - stride 107 108 # Tokenize our examples with truncation and maybe padding, but keep the overflows 109 # using a stride. This results in one example possible giving several features when 110 # a context is long, each of those features having a context that overlaps a bit 111 # the context of the previous feature. 
112 tokenized_examples = tokenizer( 113 examples["question"], 114 examples["context"], 115 truncation="only_second", 116 max_length=max_length, 117 stride=stride, 118 return_overflowing_tokens=True, 119 return_offsets_mapping=True, 120 padding="max_length", 121 ) 122 123 # Since one example might give us several features if it has a long context, we 124 # need a map from a feature to its corresponding example. This key gives us just 125 # that. 126 sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") 127 128 # We keep the id that gave us this feature and we will store the offset mappings. 129 tokenized_examples["id"] = list() 130 131 for i in range(len(tokenized_examples["input_ids"])): 132 # Grab the sequence corresponding to that example (to know what is the context 133 # and what is the question). 134 sequence_ids = tokenized_examples.sequence_ids(i) 135 context_index = 1 136 137 # One example can give several spans, this is the index of the example 138 # containing this span of text. 139 sample_index = sample_mapping[i] 140 tokenized_examples["id"].append(examples["id"][sample_index]) 141 142 # Set to (-1, -1) the offset_mapping that are not part of the context so it's 143 # easy to determine if a token position is part of the context or not. 144 tokenized_examples["offset_mapping"][i] = [ 145 (o if sequence_ids[k] == context_index else (-1, -1)) 146 for k, o in enumerate(tokenized_examples["offset_mapping"][i]) 147 ] 148 149 return tokenized_examples 150 151 152def postprocess_predictions( 153 predictions: list, 154 dataset: Dataset, 155 prepared_dataset: Dataset, 156 cls_token_index: int, 157) -> list[dict]: 158 """Postprocess the predictions, to allow easier metric computation. 159 160 Args: 161 predictions: 162 The predictions to postprocess. 163 dataset: 164 The dataset containing the examples. 165 prepared_dataset: 166 The dataset containing the prepared examples. 167 cls_token_index: 168 The index of the CLS token. 169 170 Returns: 171 The postprocessed predictions. 172 """ 173 all_start_logits = np.asarray(predictions)[:, :, 0] 174 all_end_logits = np.asarray(predictions)[:, :, 1] 175 176 # Build a map from an example to its corresponding features, being the blocks of 177 # text from the context that we're feeding into the model. An example can have 178 # multiple features/blocks if it has a long context. 
179 id_to_index = {k: i for i, k in enumerate(dataset["id"])} 180 features_per_example = defaultdict(list) 181 for i, feature in enumerate(prepared_dataset): 182 id = feature["id"] 183 example_index = id_to_index[id] 184 features_per_example[example_index].append(i) 185 186 # Loop over all the examples 187 predictions = list() 188 for example_index, example in enumerate(dataset): 189 best_answer = find_best_answer( 190 all_start_logits=all_start_logits, 191 all_end_logits=all_end_logits, 192 prepared_dataset=prepared_dataset, 193 feature_indices=features_per_example[example_index], 194 context=example["context"], 195 max_answer_length=30, 196 num_best_logits=20, 197 min_null_score=0.0, 198 cls_token_index=cls_token_index, 199 ) 200 201 # Create the final prediction dictionary, to be added to the list of 202 # predictions 203 prediction = dict( 204 id=str(example["id"]), 205 prediction_text=best_answer, 206 no_answer_probability=0.0, 207 ) 208 predictions.append(prediction) 209 210 return predictions 211 212 213def find_best_answer( 214 all_start_logits: np.ndarray, 215 all_end_logits: np.ndarray, 216 prepared_dataset: Dataset, 217 feature_indices: list[int], 218 context: str, 219 max_answer_length: int, 220 num_best_logits: int, 221 min_null_score: float, 222 cls_token_index: int, 223) -> str: 224 """Find the best answer for a given example. 225 226 Args: 227 all_start_logits: 228 The start logits for all the features. 229 all_end_logits: 230 The end logits for all the features. 231 prepared_dataset: 232 The dataset containing the prepared examples. 233 feature_indices: 234 The indices of the features associated with the current example. 235 context: 236 The context of the example. 237 max_answer_length: 238 The maximum length of the answer. 239 num_best_logits: 240 The number of best logits to consider. 241 min_null_score: 242 The minimum score an answer can have. 243 cls_token_index: 244 The index of the CLS token. 245 246 Returns: 247 The best answer for the example. 
248 """ 249 # Loop through all the features associated to the current example 250 valid_answers = list() 251 for feature_index in feature_indices: 252 features = prepared_dataset[feature_index] 253 254 # Get the predictions of the model for this feature 255 start_logits = all_start_logits[feature_index] 256 end_logits = all_end_logits[feature_index] 257 258 # Update minimum null prediction 259 cls_index = features["input_ids"].index(cls_token_index) 260 feature_null_score = (start_logits[cls_index] + end_logits[cls_index]).item() 261 if min_null_score < feature_null_score: 262 min_null_score = feature_null_score 263 264 # Find the valid answers for the feature 265 valid_answers_for_feature = find_valid_answers( 266 start_logits=start_logits, 267 end_logits=end_logits, 268 offset_mapping=features["offset_mapping"], 269 context=context, 270 max_answer_length=max_answer_length, 271 num_best_logits=num_best_logits, 272 min_null_score=min_null_score, 273 ) 274 valid_answers.extend(valid_answers_for_feature) 275 276 # In the very rare edge case we have not a single non-null prediction, we create a 277 # fake prediction to avoid failure 278 if not valid_answers: 279 return "" 280 281 # Otherwise, we select the answer with the largest score as the best answer, and 282 # return it 283 best_answer_dict = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0] 284 return best_answer_dict["text"] 285 286 287def find_valid_answers( 288 start_logits: np.ndarray, 289 end_logits: np.ndarray, 290 offset_mapping: list[tuple[int, int]], 291 context: str, 292 max_answer_length: int, 293 num_best_logits: int, 294 min_null_score: float, 295) -> list[dict]: 296 """Find the valid answers from the start and end indexes. 297 298 Args: 299 start_logits: 300 The logits for the start of the answer. 301 end_logits: 302 The logits for the end of the answer. 303 offset_mapping: 304 The offset mapping, being a list of pairs of integers for each token index, 305 containing the start and end character index in the original context. 306 max_answer_length: 307 The maximum length of the answer. 308 num_best_logits: 309 The number of best logits to consider. Note that this function will run in 310 time. 311 min_null_score: 312 The minimum score an answer can have. 313 314 Returns: 315 A list of the valid answers, each being a dictionary with keys "text" and 316 "score", the score being the sum of the start and end logits. 
317 """ 318 # Fetch the top-k predictions for the start- and end token indices 319 start_indexes = np.argsort(start_logits)[-1 : -num_best_logits - 1 : -1].tolist() 320 end_indexes = np.argsort(end_logits)[-1 : -num_best_logits - 1 : -1].tolist() 321 322 # We loop over all combinations of starting and ending indexes for valid answers 323 valid_answers = list() 324 for start_index in start_indexes: 325 for end_index in end_indexes: 326 # If the starting or ending index is out-of-scope, meaning that they are 327 # either out of bounds or correspond to part of the input_ids that are not 328 # in the context, then we skip this index 329 if ( 330 start_index >= len(offset_mapping) 331 or end_index >= len(offset_mapping) 332 or tuple(offset_mapping[start_index]) == (-1, -1) 333 or tuple(offset_mapping[end_index]) == (-1, -1) 334 ): 335 continue 336 337 # Do not consider answers with a length that is either negative or greater 338 # than the context length 339 max_val = max_answer_length + start_index - 1 340 if end_index < start_index or end_index > max_val: 341 continue 342 343 # If we got to this point then the answer is valid, so we store the 344 # corresponding start- and end character indices in the original context, 345 # and from these extract the answer 346 start_char = offset_mapping[start_index][0] 347 end_char = offset_mapping[end_index][1] 348 text = context[start_char:end_char] 349 350 # Compute the score of the answer, being the sum of the start and end 351 # logits. Intuitively, this indicates how likely the answer is to be 352 # correct, and allows us to pick the best valid answer. 353 score = start_logits[start_index] + end_logits[end_index] 354 355 # Add the answer to the list of valid answers, if the score is greater 356 # than the minimum null score 357 if score > min_null_score: 358 valid_answers.append(dict(score=score, text=text)) 359 360 return valid_answers 361 362 363def postprocess_labels(dataset: Dataset) -> list[dict]: 364 """Postprocess the labels, to allow easier metric computation. 365 366 Args: 367 dataset: 368 The dataset containing the examples. 369 370 Returns: 371 The postprocessed labels. 372 """ 373 labels = list() 374 for example in dataset: 375 # Create the associated reference dictionary, to be added to the list of 376 # references 377 label = dict( 378 id=str(example["id"]), 379 answers=dict( 380 text=[example["answer"]], 381 answer_start=[example["answer_start"]], 382 ), 383 ) 384 labels.append(label) 385 386 return labels
class QuestionAnswering(Task):
Question answering task.
Arguments:
- task_config: The configuration of the task.
- evaluation_config: The configuration of the evaluation.
Attributes:
- task_config: The configuration of the task.
- evaluation_config: The configuration of the evaluation.
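A model counts as trained for this task if its raw predictions are per-token (start, end) logit pairs, or already-decoded answer strings. Below is a minimal sketch of the shape check performed by _check_if_model_is_trained_for_task, using a dummy NumPy prediction (the 384-token feature length is an assumption for illustration):

```python
import numpy as np

# A QA model emits one (start_logit, end_logit) pair per token; build a dummy
# prediction for a single 384-token feature
sample_preds = np.zeros((384, 2), dtype=np.float32)

# The same checks the class performs: token entries are pairs whose leaves
# are floats
elements_are_pairs = len(sample_preds[0]) == 2
leaves_are_floats = sample_preds[0][0].dtype.kind == "f"
print(elements_are_pairs and leaves_are_floats)  # True
```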
def prepare_test_examples(
    examples: transformers.tokenization_utils_base.BatchEncoding,
    tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
) -> transformers.tokenization_utils_base.BatchEncoding:
Prepare test examples.
Arguments:
- examples: Dictionary of test examples.
- tokenizer: The tokenizer used to preprocess the examples.
Returns:
Dictionary of prepared test examples.
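To see the overflow mechanism in action, here is a minimal sketch of the same tokenizer call on a single long example (the model name and toy lengths are assumptions, not part of this module); one example becomes several overlapping features, each traceable back to its example via overflow_to_sample_mapping:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

encoded = tokenizer(
    ["What is mentioned last?"],
    ["word " * 300],  # a context long enough to overflow
    truncation="only_second",
    max_length=128,
    stride=32,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    padding="max_length",
)

# The single example has been split into several overlapping features, all
# mapping back to example 0
print(len(encoded["input_ids"]))              # e.g. 4
print(encoded["overflow_to_sample_mapping"])  # [0, 0, 0, 0]
```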
def postprocess_predictions(
    predictions: list,
    dataset: datasets.arrow_dataset.Dataset,
    prepared_dataset: datasets.arrow_dataset.Dataset,
    cls_token_index: int,
) -> list[dict]:
Postprocess the predictions, to allow easier metric computation.
Arguments:
- predictions: The predictions to postprocess.
- dataset: The dataset containing the examples.
- prepared_dataset: The dataset containing the prepared examples.
- cls_token_index: The index of the CLS token.
Returns:
The postprocessed predictions.
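The key bookkeeping in this function is the example-to-features map, which groups the (possibly several) prepared features back under the example they came from. A self-contained toy of that mapping, with made-up IDs:

```python
from collections import defaultdict

example_ids = ["ex-0", "ex-1"]          # dataset["id"]
feature_ids = ["ex-0", "ex-0", "ex-1"]  # one "id" per prepared feature

id_to_index = {k: i for i, k in enumerate(example_ids)}
features_per_example = defaultdict(list)
for i, feature_id in enumerate(feature_ids):
    features_per_example[id_to_index[feature_id]].append(i)

# Example 0 had a long context and was split into features 0 and 1
print(dict(features_per_example))  # {0: [0, 1], 1: [2]}
```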
def find_best_answer(
    all_start_logits: numpy.ndarray,
    all_end_logits: numpy.ndarray,
    prepared_dataset: datasets.arrow_dataset.Dataset,
    feature_indices: list[int],
    context: str,
    max_answer_length: int,
    num_best_logits: int,
    min_null_score: float,
    cls_token_index: int,
) -> str:
Find the best answer for a given example.
Arguments:
- all_start_logits: The start logits for all the features.
- all_end_logits: The end logits for all the features.
- prepared_dataset: The dataset containing the prepared examples.
- feature_indices: The indices of the features associated with the current example.
- context: The context of the example.
- max_answer_length: The maximum length of the answer.
- num_best_logits: The number of best logits to consider.
- min_null_score: The minimum score an answer can have.
- cls_token_index: The index of the CLS token.
Returns:
The best answer for the example.
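The null score is what guards against unanswerable questions: a candidate span is only kept if it outscores the model's "no answer" prediction, which is the logit mass placed on the CLS token. A toy illustration with made-up logits:

```python
import numpy as np

start_logits = np.array([1.0, 0.2, 3.1, 0.4])
end_logits = np.array([1.0, 0.1, 0.3, 2.9])
cls_index = 0  # position of the CLS token in this feature

# Score the model assigns to "no answer": both start and end on the CLS token
null_score = (start_logits[cls_index] + end_logits[cls_index]).item()

# A candidate span from token 2 to token 3 outscores the null prediction, so
# it would be kept as a valid answer
span_score = start_logits[2] + end_logits[3]
print(span_score > null_score)  # True: 6.0 > 2.0
```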
def find_valid_answers(
    start_logits: numpy.ndarray,
    end_logits: numpy.ndarray,
    offset_mapping: list[tuple[int, int]],
    context: str,
    max_answer_length: int,
    num_best_logits: int,
    min_null_score: float,
) -> list[dict]:
Find the valid answers from the start and end indexes.
Arguments:
- start_logits: The logits for the start of the answer.
- end_logits: The logits for the end of the answer.
- offset_mapping: The offset mapping, being a list of pairs of integers for each token index, containing the start and end character index in the original context.
- context: The context of the example.
- max_answer_length: The maximum length of the answer.
- num_best_logits: The number of best logits to consider. Note that this function runs in time quadratic in this number, since all combinations of start and end indices are considered.
- min_null_score: The minimum score an answer can have.
Returns:
A list of the valid answers, each being a dictionary with keys "text" and "score", the score being the sum of the start and end logits.
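The top-k selection is done with a reversed argsort slice. A quick sketch of what that slice yields:

```python
import numpy as np

logits = np.array([0.1, 2.3, -1.0, 4.2, 0.7])
num_best_logits = 3

# Indices of the three largest logits, best first: argsort is ascending, so
# we walk it backwards and stop after num_best_logits entries
top_indexes = np.argsort(logits)[-1 : -num_best_logits - 1 : -1].tolist()
print(top_indexes)  # [3, 1, 4]
```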
def postprocess_labels(dataset: datasets.arrow_dataset.Dataset) -> list[dict]:
Postprocess the labels, to allow easier metric computation.
Arguments:
- dataset: The dataset containing the examples.
Returns:
The postprocessed labels.
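The dictionaries produced by postprocess_predictions and postprocess_labels match the input format of the SQuAD v2 metric from Hugging Face's evaluate library (that this metric is the consumer is an assumption; the module itself only fixes the dictionary format):

```python
import evaluate

squad_v2 = evaluate.load("squad_v2")

# Shapes mirror the outputs of postprocess_predictions / postprocess_labels
predictions = [
    {"id": "0", "prediction_text": "Copenhagen", "no_answer_probability": 0.0}
]
references = [
    {"id": "0", "answers": {"text": ["Copenhagen"], "answer_start": [17]}}
]

results = squad_v2.compute(predictions=predictions, references=references)
print(results["exact"], results["f1"])  # 100.0 100.0
```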