alexandra_ai_eval.question_answering

Class for question-answering tasks.

  1"""Class for question-answering tasks."""
  2
  3from collections import defaultdict
  4
  5import numpy as np
  6from datasets.arrow_dataset import Dataset
  7from transformers.data.data_collator import DataCollator, default_data_collator
  8from transformers.models.auto.processing_auto import AutoProcessor
  9from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
 10
 11from .config import ModelConfig, TaskConfig
 12from .exceptions import FrameworkCannotHandleTask
 13from .task import Task
 14
 15
 16class QuestionAnswering(Task):
 17    """Question answering task.
 18
 19    Args:
 20        task_config:
 21            The configuration of the task.
 22        evaluation_config:
 23            The configuration of the evaluation.
 24
 25    Attributes:
 26        task_config:
 27            The configuration of the task.
 28        evaluation_config:
 29            The configuration of the evaluation.
 30    """
 31
 32    def _pytorch_preprocess_fn(
 33        self,
 34        examples: BatchEncoding,
 35        tokenizer: PreTrainedTokenizerBase,
 36        model_config: ModelConfig,
 37        task_config: TaskConfig,
 38    ) -> BatchEncoding:
 39        return prepare_test_examples(
 40            examples=examples,
 41            tokenizer=tokenizer,
 42        )
 43
 44    def _load_data_collator(
 45        self, tokenizer_or_processor: PreTrainedTokenizerBase | AutoProcessor
 46    ) -> DataCollator:
 47        return default_data_collator
 48
 49    def _prepare_predictions_and_labels(
 50        self,
 51        predictions: list,
 52        dataset: Dataset,
 53        prepared_dataset: Dataset,
 54        **kwargs,
 55    ) -> list[tuple[list, list]]:
 56        predictions = postprocess_predictions(
 57            predictions=predictions,
 58            dataset=dataset,
 59            prepared_dataset=prepared_dataset,
 60            cls_token_index=kwargs["cls_token_index"],
 61        )
 62        labels = postprocess_labels(dataset=dataset)
 63
 64        return [(predictions, labels)]
 65
 66    def _check_if_model_is_trained_for_task(self, model_predictions: list) -> bool:
 67        sample_preds = model_predictions[0]
 68        elements_are_pairs = len(sample_preds[0]) == 2
 69        leaves_are_floats = sample_preds[0][0].dtype.kind == "f"
 70        elements_are_strings = isinstance(sample_preds[0], str)
 71        return (elements_are_pairs and leaves_are_floats) or elements_are_strings
 72
 73    def _spacy_preprocess_fn(self, examples: dict) -> dict:
 74        raise FrameworkCannotHandleTask(
 75            framework="spaCy", task=self.task_config.pretty_name
 76        )
 77
 78    def _extract_spacy_predictions(self, tokens_processed: tuple) -> list:
 79        raise FrameworkCannotHandleTask(
 80            framework="spaCy", task=self.task_config.pretty_name
 81        )
 82
 83
 84def prepare_test_examples(
 85    examples: BatchEncoding,
 86    tokenizer: PreTrainedTokenizerBase,
 87) -> BatchEncoding:
 88    """Prepare test examples.
 89
 90    Args:
 91        examples:
 92            Dictionary of test examples.
 93        tokenizer:
 94            The tokenizer used to preprocess the examples.
 95
 96    Returns:
 97        Dictionary of prepared test examples.
 98    """
 99    # Some of the questions have lots of whitespace on the left, which is not useful
100    # and will make the truncation of the context fail (the tokenized question will
101    # take up a lot of space). So we remove that leading whitespace
102    examples["question"] = [q.lstrip() for q in examples["question"]]
103
104    # Compute the stride, being a quarter of the model's maximum sequence length
105    stride = tokenizer.model_max_length // 4
106    max_length = tokenizer.model_max_length - stride
107
108    # Tokenize our examples with truncation and maybe padding, but keep the overflows
109    # using a stride. This results in one example possibly giving several features when
110    # a context is long, each of those features having a context that overlaps a bit
111    # with the context of the previous feature.
112    tokenized_examples = tokenizer(
113        examples["question"],
114        examples["context"],
115        truncation="only_second",
116        max_length=max_length,
117        stride=stride,
118        return_overflowing_tokens=True,
119        return_offsets_mapping=True,
120        padding="max_length",
121    )
122
123    # Since one example might give us several features if it has a long context, we
124    # need a map from a feature to its corresponding example. This key gives us just
125    # that.
126    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
127
128    # We keep the id that gave us this feature and we will store the offset mappings.
129    tokenized_examples["id"] = list()
130
131    for i in range(len(tokenized_examples["input_ids"])):
132        # Grab the sequence corresponding to that example (to know what is the context
133        # and what is the question).
134        sequence_ids = tokenized_examples.sequence_ids(i)
135        context_index = 1
136
137        # One example can give several spans; this is the index of the example
138        # containing this span of text.
139        sample_index = sample_mapping[i]
140        tokenized_examples["id"].append(examples["id"][sample_index])
141
142        # Set the offset_mapping entries that are not part of the context to (-1, -1),
143        # so it's easy to determine whether a token position is part of the context.
144        tokenized_examples["offset_mapping"][i] = [
145            (o if sequence_ids[k] == context_index else (-1, -1))
146            for k, o in enumerate(tokenized_examples["offset_mapping"][i])
147        ]
148
149    return tokenized_examples
150
151
152def postprocess_predictions(
153    predictions: list,
154    dataset: Dataset,
155    prepared_dataset: Dataset,
156    cls_token_index: int,
157) -> list[dict]:
158    """Postprocess the predictions, to allow easier metric computation.
159
160    Args:
161        predictions:
162            The predictions to postprocess.
163        dataset:
164            The dataset containing the examples.
165        prepared_dataset:
166            The dataset containing the prepared examples.
167        cls_token_index:
168            The index of the CLS token.
169
170    Returns:
171        The postprocessed predictions.
172    """
173    all_start_logits = np.asarray(predictions)[:, :, 0]
174    all_end_logits = np.asarray(predictions)[:, :, 1]
175
176    # Build a map from an example to its corresponding features, being the blocks of
177    # text from the context that we're feeding into the model. An example can have
178    # multiple features/blocks if it has a long context.
179    id_to_index = {k: i for i, k in enumerate(dataset["id"])}
180    features_per_example = defaultdict(list)
181    for i, feature in enumerate(prepared_dataset):
182        id = feature["id"]
183        example_index = id_to_index[id]
184        features_per_example[example_index].append(i)
185
186    # Loop over all the examples
187    predictions = list()
188    for example_index, example in enumerate(dataset):
189        best_answer = find_best_answer(
190            all_start_logits=all_start_logits,
191            all_end_logits=all_end_logits,
192            prepared_dataset=prepared_dataset,
193            feature_indices=features_per_example[example_index],
194            context=example["context"],
195            max_answer_length=30,
196            num_best_logits=20,
197            min_null_score=0.0,
198            cls_token_index=cls_token_index,
199        )
200
201        # Create the final prediction dictionary, to be added to the list of
202        # predictions
203        prediction = dict(
204            id=str(example["id"]),
205            prediction_text=best_answer,
206            no_answer_probability=0.0,
207        )
208        predictions.append(prediction)
209
210    return predictions
211
212
213def find_best_answer(
214    all_start_logits: np.ndarray,
215    all_end_logits: np.ndarray,
216    prepared_dataset: Dataset,
217    feature_indices: list[int],
218    context: str,
219    max_answer_length: int,
220    num_best_logits: int,
221    min_null_score: float,
222    cls_token_index: int,
223) -> str:
224    """Find the best answer for a given example.
225
226    Args:
227        all_start_logits:
228            The start logits for all the features.
229        all_end_logits:
230            The end logits for all the features.
231        prepared_dataset:
232            The dataset containing the prepared examples.
233        feature_indices:
234            The indices of the features associated with the current example.
235        context:
236            The context of the example.
237        max_answer_length:
238            The maximum length of the answer.
239        num_best_logits:
240            The number of best logits to consider.
241        min_null_score:
242            The minimum score an answer can have.
243        cls_token_index:
244            The index of the CLS token.
245
246    Returns:
247        The best answer for the example.
248    """
249    # Loop through all the features associated with the current example
250    valid_answers = list()
251    for feature_index in feature_indices:
252        features = prepared_dataset[feature_index]
253
254        # Get the predictions of the model for this feature
255        start_logits = all_start_logits[feature_index]
256        end_logits = all_end_logits[feature_index]
257
258        # Update minimum null prediction
259        cls_index = features["input_ids"].index(cls_token_index)
260        feature_null_score = (start_logits[cls_index] + end_logits[cls_index]).item()
261        if min_null_score < feature_null_score:
262            min_null_score = feature_null_score
263
264        # Find the valid answers for the feature
265        valid_answers_for_feature = find_valid_answers(
266            start_logits=start_logits,
267            end_logits=end_logits,
268            offset_mapping=features["offset_mapping"],
269            context=context,
270            max_answer_length=max_answer_length,
271            num_best_logits=num_best_logits,
272            min_null_score=min_null_score,
273        )
274        valid_answers.extend(valid_answers_for_feature)
275
276    # In the very rare edge case where there is not a single non-null prediction, we
277    # return an empty string to avoid failure
278    if not valid_answers:
279        return ""
280
281    # Otherwise, we select the answer with the largest score as the best answer, and
282    # return it
283    best_answer_dict = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0]
284    return best_answer_dict["text"]
285
286
287def find_valid_answers(
288    start_logits: np.ndarray,
289    end_logits: np.ndarray,
290    offset_mapping: list[tuple[int, int]],
291    context: str,
292    max_answer_length: int,
293    num_best_logits: int,
294    min_null_score: float,
295) -> list[dict]:
296    """Find the valid answers from the start and end indexes.
297
298    Args:
299        start_logits:
300            The logits for the start of the answer.
301        end_logits:
302            The logits for the end of the answer.
303        offset_mapping:
304            The offset mapping, being a list of pairs of integers for each token index,
305            containing the start and end character index in the original context.
306        max_answer_length:
307            The maximum length of the answer.
308        num_best_logits:
309            The number of best logits to consider. Note that this function will run
310            in O(num_best_logits^2) time.
311        min_null_score:
312            The minimum score an answer can have.
313
314    Returns:
315        A list of the valid answers, each being a dictionary with keys "text" and
316        "score", the score being the sum of the start and end logits.
317    """
318    # Fetch the top-k predictions for the start- and end token indices
319    start_indexes = np.argsort(start_logits)[-1 : -num_best_logits - 1 : -1].tolist()
320    end_indexes = np.argsort(end_logits)[-1 : -num_best_logits - 1 : -1].tolist()
321
322    # We loop over all combinations of starting and ending indexes for valid answers
323    valid_answers = list()
324    for start_index in start_indexes:
325        for end_index in end_indexes:
326            # If the starting or ending index is out of scope, meaning that it is either
327            # out of bounds or corresponds to a part of the input_ids that is not in the
328            # context, then we skip this index
329            if (
330                start_index >= len(offset_mapping)
331                or end_index >= len(offset_mapping)
332                or tuple(offset_mapping[start_index]) == (-1, -1)
333                or tuple(offset_mapping[end_index]) == (-1, -1)
334            ):
335                continue
336
337            # Do not consider answers with a length that is either negative or greater
338            # than the maximum answer length
339            max_val = max_answer_length + start_index - 1
340            if end_index < start_index or end_index > max_val:
341                continue
342
343            # If we got to this point then the answer is valid, so we store the
344            # corresponding start- and end character indices in the original context,
345            # and from these extract the answer
346            start_char = offset_mapping[start_index][0]
347            end_char = offset_mapping[end_index][1]
348            text = context[start_char:end_char]
349
350            # Compute the score of the answer, being the sum of the start and end
351            # logits. Intuitively, this indicates how likely the answer is to be
352            # correct, and allows us to pick the best valid answer.
353            score = start_logits[start_index] + end_logits[end_index]
354
355            # Add the answer to the list of valid answers, if the score is greater
356            # than the minimum null score
357            if score > min_null_score:
358                valid_answers.append(dict(score=score, text=text))
359
360    return valid_answers
361
362
363def postprocess_labels(dataset: Dataset) -> list[dict]:
364    """Postprocess the labels, to allow easier metric computation.
365
366    Args:
367        dataset:
368            The dataset containing the examples.
369
370    Returns:
371         The postprocessed labels.
372    """
373    labels = list()
374    for example in dataset:
375        # Create the associated reference dictionary, to be added to the list of
376        # references
377        label = dict(
378            id=str(example["id"]),
379            answers=dict(
380                text=[example["answer"]],
381                answer_start=[example["answer_start"]],
382            ),
383        )
384        labels.append(label)
385
386    return labels
class QuestionAnswering(alexandra_ai_eval.task.Task):

Question answering task.

Arguments:
  • task_config: The configuration of the task.
  • evaluation_config: The configuration of the evaluation.
Attributes:
  • task_config: The configuration of the task.
  • evaluation_config: The configuration of the evaluation.
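
The only question-answering-specific logic on the class itself is the prediction-format check in _check_if_model_is_trained_for_task; preprocessing and postprocessing are delegated to the module-level functions documented below. The following is a minimal sketch of that check with fabricated arrays; the sequence length of 384 and the random logits are illustration-only assumptions.

    # Extractive QA models are expected to emit a (start, end) pair of float logits
    # per token; generative models would instead emit strings. Values are fabricated.
    import numpy as np

    qa_style_predictions = [np.random.randn(384, 2)]  # one (seq_len, 2) array per feature
    sample_preds = qa_style_predictions[0]

    elements_are_pairs = len(sample_preds[0]) == 2
    leaves_are_floats = sample_preds[0][0].dtype.kind == "f"
    print(elements_are_pairs and leaves_are_floats)  # True
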
def prepare_test_examples( examples: transformers.tokenization_utils_base.BatchEncoding, tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase) -> transformers.tokenization_utils_base.BatchEncoding:

Prepare test examples.

Arguments:
  • examples: Dictionary of test examples.
  • tokenizer: The tokenizer used to preprocess the examples.
Returns:

Dictionary of prepared test examples.
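
A minimal usage sketch follows. The tokenizer checkpoint and the toy example are assumptions chosen purely for illustration; any fast tokenizer with a finite model_max_length behaves equivalently.

    # Hypothetical usage: a long context is split into several overlapping features,
    # each of which keeps the id of the example it came from.
    from transformers import AutoTokenizer

    from alexandra_ai_eval.question_answering import prepare_test_examples

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased-distilled-squad")

    examples = {
        "id": [0],
        "question": ["   Where do polar bears live?"],          # leading whitespace is stripped
        "context": ["Polar bears live in the Arctic. " * 200],  # long context -> several features
    }

    prepared = prepare_test_examples(examples=examples, tokenizer=tokenizer)

    print(len(prepared["input_ids"]))  # several features, due to the long context
    print(set(prepared["id"]))         # {0}: every feature points back to its example
    # Offsets outside the context have been replaced by (-1, -1), so downstream code
    # can tell context tokens apart from question and special tokens.
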

def postprocess_predictions( predictions: list, dataset: datasets.arrow_dataset.Dataset, prepared_dataset: datasets.arrow_dataset.Dataset, cls_token_index: int) -> list[dict]:

Postprocess the predictions, to allow easier metric computation.

Arguments:
  • predictions: The predictions to postprocess.
  • dataset: The dataset containing the examples.
  • prepared_dataset: The dataset containing the prepared examples.
  • cls_token_index: The index of the CLS token.
Returns:

The postprocessed predictions.
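
The sketch below runs the postprocessing on a fabricated single-example, single-feature toy. The token ids, offsets and logits are all made up, and the value 101 merely plays the role of the CLS token id.

    import numpy as np
    from datasets import Dataset

    from alexandra_ai_eval.question_answering import postprocess_predictions

    dataset = Dataset.from_dict(
        {"id": [0], "context": ["Polar bears live in the Arctic."]}
    )
    prepared_dataset = Dataset.from_dict(
        {
            "id": [0],
            "input_ids": [[101, 1, 2, 3, 102]],  # 101 stands in for the CLS token id
            "offset_mapping": [[[-1, -1], [0, 11], [12, 23], [24, 30], [-1, -1]]],
        }
    )

    # Shape (num_features, seq_len, 2): start logits in [:, :, 0], end logits in [:, :, 1]
    predictions = np.array(
        [[[0.1, 0.1], [0.2, 0.0], [0.5, 0.3], [5.0, 6.0], [0.0, 0.2]]]
    )

    print(
        postprocess_predictions(
            predictions=predictions,
            dataset=dataset,
            prepared_dataset=prepared_dataset,
            cls_token_index=101,
        )
    )
    # [{'id': '0', 'prediction_text': 'Arctic', 'no_answer_probability': 0.0}]
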

def find_best_answer( all_start_logits: numpy.ndarray, all_end_logits: numpy.ndarray, prepared_dataset: datasets.arrow_dataset.Dataset, feature_indices: list[int], context: str, max_answer_length: int, num_best_logits: int, min_null_score: float, cls_token_index: int) -> str:

Find the best answer for a given example.

Arguments:
  • all_start_logits: The start logits for all the features.
  • all_end_logits: The end logits for all the features.
  • prepared_dataset: The dataset containing the prepared examples.
  • feature_indices: The indices of the features associated with the current example.
  • context: The context of the example.
  • max_answer_length: The maximum length of the answer.
  • num_best_logits: The number of best logits to consider.
  • min_null_score: The minimum score an answer can have.
  • cls_token_index: The index of the CLS token.
Returns:

The best answer for the example.
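
As a complement, the sketch below exercises the fallback branch on a fabricated feature: with completely flat logits no span beats the null score, so the function returns an empty string. As above, 101 stands in for the CLS token id.

    import numpy as np
    from datasets import Dataset

    from alexandra_ai_eval.question_answering import find_best_answer

    prepared_dataset = Dataset.from_dict(
        {
            "id": [0],
            "input_ids": [[101, 1, 2, 3, 102]],
            "offset_mapping": [[[-1, -1], [0, 11], [12, 23], [24, 30], [-1, -1]]],
        }
    )

    answer = find_best_answer(
        all_start_logits=np.zeros((1, 5)),
        all_end_logits=np.zeros((1, 5)),
        prepared_dataset=prepared_dataset,
        feature_indices=[0],
        context="Polar bears live in the Arctic.",
        max_answer_length=30,
        num_best_logits=20,
        min_null_score=0.0,
        cls_token_index=101,
    )
    print(repr(answer))  # ''
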

def find_valid_answers( start_logits: numpy.ndarray, end_logits: numpy.ndarray, offset_mapping: list[tuple[int, int]], context: str, max_answer_length: int, num_best_logits: int, min_null_score: float) -> list[dict]:

Find the valid answers from the start and end indexes.

Arguments:
  • start_logits: The logits for the start of the answer.
  • end_logits: The logits for the end of the answer.
  • offset_mapping: The offset mapping, being a list of pairs of integers for each token index, containing the start and end character index in the original context.
  • context: The context of the example.
  • max_answer_length: The maximum length of the answer.
  • num_best_logits: The number of best logits to consider. Note that this function will run in O(num_best_logits^2) time.
  • min_null_score: The minimum score an answer can have.
Returns:

A list of the valid answers, each being a dictionary with keys "text" and "score", the score being the sum of the start and end logits.
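
A self-contained sketch with fabricated logits and offsets follows; token index 2 carries the strongest start and end logits, so the span it covers comes back as the highest-scoring answer.

    import numpy as np

    from alexandra_ai_eval.question_answering import find_valid_answers

    context = "Polar bears live in the Arctic."
    offset_mapping = [(-1, -1), (0, 11), (24, 30), (-1, -1)]  # question/special tokens are (-1, -1)
    start_logits = np.array([0.1, 0.3, 4.0, 0.0])
    end_logits = np.array([0.2, 0.1, 5.0, 0.0])

    answers = find_valid_answers(
        start_logits=start_logits,
        end_logits=end_logits,
        offset_mapping=offset_mapping,
        context=context,
        max_answer_length=30,
        num_best_logits=3,
        min_null_score=1.0,
    )
    best = max(answers, key=lambda answer: answer["score"])
    print(best["text"], float(best["score"]))  # Arctic 9.0
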

def postprocess_labels(dataset: datasets.arrow_dataset.Dataset) -> list[dict]:

Postprocess the labels, to allow easier metric computation.

Arguments:
  • dataset: The dataset containing the examples.
Returns:

The postprocessed labels.
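
A minimal sketch with a fabricated single-example dataset follows; the resulting dictionaries follow the SQuAD-v2-style reference format expected by extractive-QA metrics, with an id plus parallel lists of answer texts and start characters.

    from datasets import Dataset

    from alexandra_ai_eval.question_answering import postprocess_labels

    dataset = Dataset.from_dict(
        {
            "id": [0],
            "context": ["Polar bears live in the Arctic."],
            "answer": ["Arctic"],
            "answer_start": [24],
        }
    )

    print(postprocess_labels(dataset=dataset))
    # [{'id': '0', 'answers': {'text': ['Arctic'], 'answer_start': [24]}}]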