Files
MasterarbeitCode/evaluation/helpers/prediction_model.py
2021-04-11 23:28:41 +02:00

54 lines
2.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# adapted from:
# Pellegrini, C., E. Özsoy, M. Wintergerst, and G. Groh (2021).
# “Exploiting Food Embeddings for Ingredient Substitution.”
# In: Proceedings of the 14th International Joint Conference on Biomedical
# Engineering Systems and Technologies - Volume 5: HEALTHINF, INSTICC.
# SciTePress, pp. 67–77. isbn: 978-989-758-490-9. doi: 10.5220/0010202000670077.
import json
import torch
from torch.utils.data import DataLoader
from transformers import BertModel, BertTokenizer
from evaluation.helpers.instructions_dataset import InstructionsDataset
class PredictionModel:
    """Wraps a fine-tuned BERT model that produces contextual embeddings for
    ingredient tokens appearing in recipe instruction sentences."""

    def __init__(self, model_path=''):
        """Load the BERT model and a cased tokenizer whose vocabulary keeps
        whole ingredient names as single tokens.

        :param model_path: path or identifier forwarded to
            ``BertModel.from_pretrained``.
        """
        self.model: BertModel = BertModel.from_pretrained(
            pretrained_model_name_or_path=model_path)
        # Pin the encoding: the vocab JSON may contain non-ASCII ingredient
        # names, and the platform default encoding is not guaranteed to be
        # UTF-8.
        with open('train_model/vocab/used_ingredients.json', 'r',
                  encoding='utf-8') as f:
            used_ingredients = json.load(f)
        # ``never_split`` keeps each ingredient name as one token so a single
        # embedding per ingredient occurrence can be extracted later.
        self.tokenizer = BertTokenizer(vocab_file='train_model/vocab/bert_vocab.txt', do_lower_case=False,
                                       max_len=512, never_split=used_ingredients, truncation=True)
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model.to(self.device)

    def predict_embeddings(self, sentences):
        """Compute token embeddings for all *sentences*.

        :param sentences: iterable of instruction sentences.
        :return: tuple ``(embeddings, ingredient_ids)`` — ``embeddings`` is
            the stacked last-hidden-state tensor over all sentences and
            ``ingredient_ids`` the corresponding list of input-id tensors.
        """
        dataset = InstructionsDataset(tokenizer=self.tokenizer, sentences=sentences)
        dataloader = DataLoader(dataset, batch_size=100, pin_memory=True)
        embeddings = []
        ingredient_ids = []
        for batch in dataloader:
            batch = batch.to(self.device)
            with torch.no_grad():  # inference only; no gradients needed
                embeddings_batch = self.model(batch)
            # embeddings_batch[0] is the last hidden state of the encoder.
            embeddings.extend(embeddings_batch[0])
            ingredient_ids.extend(batch)
        return torch.stack(embeddings), ingredient_ids

    def compute_embedding_for_ingredient(self, sentence, ingredient_name):
        """Return the embedding of the first occurrence of *ingredient_name*
        within *sentence* as a NumPy array of length 768.

        :param sentence: instruction sentence containing the ingredient.
        :param ingredient_name: ingredient token to look up.
        :raises ValueError: if the ingredient token does not occur in the
            tokenized sentence (e.g. it was truncated away or is not in the
            vocabulary).
        """
        embeddings, ingredient_ids = self.predict_embeddings([sentence])
        embeddings_flat = embeddings.view((-1, 768))
        ingredient_ids_flat = torch.stack(ingredient_ids).flatten()
        food_id = self.tokenizer.convert_tokens_to_ids(ingredient_name)
        matches = embeddings_flat[ingredient_ids_flat == food_id]
        # Previously a missing token surfaced as an opaque IndexError from
        # ``matches[0]``; fail with an explicit, actionable message instead.
        if matches.shape[0] == 0:
            raise ValueError(
                f"ingredient '{ingredient_name}' (token id {food_id}) "
                f"not found in the tokenized sentence")
        return matches[0].cpu().numpy()