initial commit of project
This commit is contained in:
53
evaluation/helpers/prediction_model.py
Normal file
53
evaluation/helpers/prediction_model.py
Normal file
@@ -0,0 +1,53 @@
# adapted from:
# Pellegrini, C., E. Özsoy, M. Wintergerst, and G. Groh. (2021).
# "Exploiting Food Embeddings for Ingredient Substitution."
# In: Proceedings of the 14th International Joint Conference on Biomedical
# Engineering Systems and Technologies - Volume 5: HEALTHINF, INSTICC.
# SciTePress, pp. 67-77. isbn: 978-989-758-490-9. doi: 10.5220/0010202000670077.
import json
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import BertModel, BertTokenizer
|
||||
|
||||
from evaluation.helpers.instructions_dataset import InstructionsDataset
|
||||
|
||||
|
||||
class PredictionModel:
    """Produces contextual token embeddings for recipe instructions with a
    fine-tuned BERT encoder.

    Loads the encoder from a local checkpoint and builds a tokenizer whose
    vocabulary keeps whole ingredient tokens intact (``never_split``), so each
    ingredient maps to exactly one token id.
    """

    def __init__(self, model_path=''):
        # Load the fine-tuned BERT weights from the checkpoint directory.
        self.model: BertModel = BertModel.from_pretrained(
            pretrained_model_name_or_path=model_path)

        # Ingredient names must survive tokenization as single tokens, so they
        # are passed to the tokenizer as never_split entries.
        with open('train_model/vocab/used_ingredients.json', 'r') as f:
            used_ingredients = json.load(f)
        self.tokenizer = BertTokenizer(vocab_file='train_model/vocab/bert_vocab.txt', do_lower_case=False,
                                       max_len=512, never_split=used_ingredients, truncation=True)

        # Prefer the GPU when one is present; move the encoder there once.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

    def predict_embeddings(self, sentences):
        """Encode *sentences* and return their token embeddings.

        Returns a tuple ``(embeddings, token_id_batches)`` where ``embeddings``
        is the stacked last-hidden-state tensor (one row of shape
        (seq_len, hidden) per sentence) and ``token_id_batches`` holds the
        corresponding input token-id tensors, in the same order.
        """
        loader = DataLoader(
            InstructionsDataset(tokenizer=self.tokenizer, sentences=sentences),
            batch_size=100, pin_memory=True)

        hidden_states = []
        token_id_rows = []
        # Inference only — gradients are never needed here.
        with torch.no_grad():
            for token_batch in loader:
                token_batch = token_batch.to(self.device)
                outputs = self.model(token_batch)
                # outputs[0] is the last hidden state, (batch, seq, hidden);
                # extend() splits it into one tensor per sentence.
                hidden_states.extend(outputs[0])
                token_id_rows.extend(token_batch)

        return torch.stack(hidden_states), token_id_rows

    def compute_embedding_for_ingredient(self, sentence, ingredient_name):
        """Return the embedding (numpy array) of *ingredient_name* within *sentence*.

        The sentence is encoded, the positions whose token id matches the
        ingredient's id are selected, and the first match is returned.
        Raises IndexError when the ingredient token does not occur.
        """
        states, id_batches = self.predict_embeddings([sentence])
        # Flatten (sentences, seq, hidden) -> (tokens, hidden) and align a
        # flat vector of token ids with it.
        flat_states = states.view((-1, 768))
        flat_ids = torch.stack(id_batches).flatten()

        target_id = self.tokenizer.convert_tokens_to_ids(ingredient_name)
        matches = flat_states[flat_ids == target_id].cpu().numpy()

        return matches[0]
|
||||
Reference in New Issue
Block a user