I am writing a question answering system using pre-trained BERT with a linear layer and a softmax layer on top. In the templates available online, the labels for one example usually consist of only one `answer_start_index` and one `answer_end_index`. For example, from Huggingface, when instantiating a `SquadFeatures` object:
```
self.input_ids = input_ids
self.attention_mask = attention_mask
self.token_type_ids = token_type_ids
self.cls_index = cls_index
self.p_mask = p_mask
self.example_index = example_index
self.unique_id = unique_id
self.paragraph_len = paragraph_len
self.token_is_max_context = token_is_max_context
self.tokens = tokens
self.token_to_orig_map = token_to_orig_map
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
self.qas_id = qas_id
```
However, in my own dataset I have examples where the answer word occurs at several locations in the context, i.e. there may be several correct spans constituting the answer.
My problem is that I don't know how to handle such examples. In the templates available online, the labels are usually in flat lists, say:
- [start_example1, start_example2, start_example3]
- [end_example1, end_example2, end_example3]
In my case this may look like:
- [start_example1, [start_example2_1, start_example2_2], start_example3]
- and the same for the ends, of course
In other words, I do not have a list containing one label per example, but a list whose entries are either a single label or a list of labels for that example, i.e. a list that partly consists of lists.
When following other templates the next step in the process is:
```
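import torch

# Assumes each list holds one (1, seq_len) tensor per example.
input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
token_type_ids = torch.cat(token_type_ids, dim=0)
# Something like this; torch.tensor expects one integer label per example.
span_starts = torch.tensor(span_starts)
span_ends = torch.tensor(span_ends)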
```
However, this of course (?) raises an error, because my span_starts and span_ends lists do not contain only single items but sometimes a list within the list.
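One possible way around the conversion error, as a minimal sketch: pad every example's positions to the same length with a sentinel so the labels form a rectangular tensor. The helper name `pad_spans` and the sentinel `-1` are hypothetical choices for illustration, not something taken from the templates.

```python
import torch

def pad_spans(spans, pad_value=-1):
    """Turn a mix of ints and lists of ints into a (num_examples, max_answers) tensor."""
    # Wrap bare integers so every example is a list of positions.
    as_lists = [s if isinstance(s, list) else [s] for s in spans]
    max_answers = max(len(s) for s in as_lists)
    # Pad shorter lists with the sentinel so all rows have equal length.
    padded = [s + [pad_value] * (max_answers - len(s)) for s in as_lists]
    return torch.tensor(padded, dtype=torch.long)

# Example 2 has two correct spans; examples 1 and 3 have one each.
span_starts = pad_spans([3, [7, 15], 2])
span_ends = pad_spans([4, [8, 16], 2])
```

The sentinel then has to be masked out wherever the labels are used, since `-1` is not a valid token position.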
Does anyone have an idea how I can tackle this problem? Should I only use examples where exactly one span in the context constitutes the answer?
If I work around the torch error, will backpropagation, evaluation, and the computation of the loss still work?
Thank You! /B
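On whether the loss and backpropagation can still work with several gold spans: one option, sketched here under assumptions, is to minimise the negative log of the total probability assigned to all correct spans. The function name `multi_span_loss` is hypothetical, and the sketch assumes the labels have been padded with `-1` to shape `(batch, max_answers)` and that every example has at least one valid span. This is not what the standard templates do; they use a plain cross-entropy on a single start and end index.

```python
import torch
import torch.nn.functional as F

def multi_span_loss(start_logits, end_logits, starts, ends, pad_value=-1):
    """Negative log of the summed probability of all gold spans.

    start_logits, end_logits: (batch, seq_len) float tensors.
    starts, ends: (batch, max_answers) long tensors padded with pad_value.
    """
    valid = starts != pad_value
    start_logp = F.log_softmax(start_logits, dim=-1)
    end_logp = F.log_softmax(end_logits, dim=-1)
    # Clamp padded indices to 0 so gather is safe; they are masked out below.
    s_lp = start_logp.gather(1, starts.clamp(min=0))
    e_lp = end_logp.gather(1, ends.clamp(min=0))
    span_lp = (s_lp + e_lp).masked_fill(~valid, float("-inf"))
    # logsumexp over the gold spans is the log of their summed probabilities.
    return -torch.logsumexp(span_lp, dim=1).mean()

torch.manual_seed(0)
start_logits = torch.randn(3, 20, requires_grad=True)
end_logits = torch.randn(3, 20, requires_grad=True)
starts = torch.tensor([[3, -1], [7, 15], [2, -1]])
ends = torch.tensor([[4, -1], [8, 16], [2, -1]])

loss = multi_span_loss(start_logits, end_logits, starts, ends)
loss.backward()  # gradients flow to both logit tensors
```

For an example with a single gold span this reduces to the usual sum of the start and end cross-entropy terms, so single-answer examples are treated the same way as before.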
Have you checked the code
I am not sure this is the best way, but instead of `argmax` you could use `topk` and check whether any of the top candidates corresponds to a correct answer.
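A rough sketch of that idea for evaluation, assuming per-example 1-D start and end logits and a list of gold `(start, end)` pairs. The function name `hit_at_k` is made up for illustration, and pairing every top start with every top end is just one simple way to form candidates.

```python
import torch

def hit_at_k(start_logits, end_logits, gold_spans, k=3):
    """Return True if any top-k start/end pairing matches a gold (start, end) span."""
    k = min(k, start_logits.numel())
    top_starts = torch.topk(start_logits, k).indices.tolist()
    top_ends = torch.topk(end_logits, k).indices.tolist()
    # Keep only pairings where the span does not end before it starts.
    candidates = {(s, e) for s in top_starts for e in top_ends if s <= e}
    return any(span in candidates for span in gold_spans)

start_logits = torch.tensor([0.1, 2.0, 0.3, 1.5, 0.0])
end_logits = torch.tensor([0.0, 0.2, 2.5, 0.1, 1.8])
gold_spans = [(3, 4)]

found_top1 = hit_at_k(start_logits, end_logits, gold_spans, k=1)  # argmax only
found_top2 = hit_at_k(start_logits, end_logits, gold_spans, k=2)
```

Here the argmax prediction is the span `(1, 2)`, which misses the gold span, while the top-2 candidates include `(3, 4)`.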