Files
MasterarbeitCode/evaluation/helpers/instructions_dataset.py
2021-04-11 23:28:41 +02:00

39 lines
1.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# adapted from:
# Pellegrini, C., E. Özsoy, M. Wintergerst, and G. Groh. (2021).
# “Exploiting Food Embeddings for Ingredient Substitution.”
# In: Proceedings of the 14th International Joint Conference on Biomedical
# Engineering Systems and Technologies - Volume 5: HEALTHINF, INSTICC.
# SciTePress, pp. 67-77. isbn: 978-989-758-490-9. doi: 10.5220/0010202000670077.
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
class InstructionsDataset(Dataset):
    """Torch ``Dataset`` that tokenizes instruction sentences once, up front.

    All sentences are encoded in a single batch (special tokens added,
    truncated to 512 tokens) and padded/stacked into one 2-D long tensor,
    so ``__getitem__`` does no per-item work.
    """

    def __init__(self, tokenizer, sentences):
        """Encode and pad ``sentences`` with ``tokenizer``.

        Args:
            tokenizer: HuggingFace-style tokenizer exposing
                ``batch_encode_plus`` and ``pad_token_id``.
            sentences: iterable of strings to encode.
        """
        self.tokenizer = tokenizer
        batch_encoding = tokenizer.batch_encode_plus(
            sentences, add_special_tokens=True, max_length=512, truncation=True
        )
        # Tensorize directly; no need to keep the raw id lists around.
        self.examples = self._tensorize_batch(
            [torch.tensor(ids) for ids in batch_encoding["input_ids"]]
        )

    def _tensorize_batch(self, examples) -> torch.Tensor:
        """Stack equal-length id tensors, or right-pad them to a common length.

        Args:
            examples: list of 1-D ``torch.Tensor`` of token ids.

        Returns:
            A 2-D tensor of shape (len(examples), max_len).

        Raises:
            ValueError: if padding is required but the tokenizer has no
                pad token configured.
        """
        # Guard against an empty sentence list (original code raised
        # IndexError on examples[0]).
        if not examples:
            return torch.empty((0, 0), dtype=torch.long)
        first_len = examples[0].size(0)
        if all(x.size(0) == first_len for x in examples):
            return torch.stack(examples, dim=0)
        # Use the public ``pad_token_id`` (None when no pad token is set)
        # instead of the private ``_pad_token`` attribute, which is an
        # implementation detail of the tokenizer class.
        if self.tokenizer.pad_token_id is None:
            raise ValueError(
                "You are attempting to pad samples but the tokenizer you are using"
                f" ({self.tokenizer.__class__.__name__}) does not have one."
            )
        return pad_sequence(
            examples, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return self.examples[i]