# adapted from:
# Pellegrini, C., E. Özsoy, M. Wintergerst, and G. Groh. (2021).
# “Exploiting Food Embeddings for Ingredient Substitution.”
# In: Proceedings of the 14th International Joint Conference on Biomedical
# Engineering Systems and Technologies - Volume 5: HEALTHINF, INSTICC.
# SciTePress, pp. 67–77. isbn: 978-989-758-490-9. doi: 10.5220/0010202000670077.
import json

import torch
from torch.utils.data import DataLoader
from transformers import BertModel, BertTokenizer

from evaluation.helpers.instructions_dataset import InstructionsDataset


class PredictionModel:
    """Produce contextual BERT embeddings for ingredients inside instruction text.

    Loads a fine-tuned BERT encoder from ``model_path`` and a custom tokenizer
    whose ``never_split`` list keeps whole-ingredient tokens intact (so an
    ingredient maps to exactly one token id rather than word pieces).
    """

    def __init__(self, model_path=''):
        # Fine-tuned BERT encoder used purely for feature extraction.
        self.model: BertModel = BertModel.from_pretrained(
            pretrained_model_name_or_path=model_path)

        # Ingredient tokens the word-piece tokenizer must never split.
        # Explicit encoding: JSON vocab files are UTF-8; the platform default
        # (e.g. cp1252 on Windows) would mis-decode non-ASCII ingredient names.
        with open('train_model/vocab/used_ingredients.json', 'r', encoding='utf-8') as f:
            used_ingredients = json.load(f)
        self.tokenizer = BertTokenizer(vocab_file='train_model/vocab/bert_vocab.txt', do_lower_case=False,
                                       max_len=512, never_split=used_ingredients, truncation=True)

        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        self.model.to(self.device)
        # Fix: the model is only ever used for inference (see predict_embeddings),
        # but was left in training mode, so dropout stayed active and embeddings
        # were non-deterministic.  eval() makes extraction deterministic.
        self.model.eval()

    def predict_embeddings(self, sentences):
        """Run the encoder over ``sentences`` and collect per-token embeddings.

        Parameters:
            sentences: iterable of raw instruction strings, tokenized by
                InstructionsDataset with this object's tokenizer.

        Returns:
            (embeddings, ingredient_ids) where ``embeddings`` is the stacked
            last-hidden-state tensor (one row of per-token vectors per
            sentence) and ``ingredient_ids`` is the matching list of token-id
            tensors, aligned position-for-position with the embeddings.
        """
        dataset = InstructionsDataset(tokenizer=self.tokenizer, sentences=sentences)
        dataloader = DataLoader(dataset, batch_size=100, pin_memory=True)

        embeddings = []
        ingredient_ids = []
        for batch in dataloader:
            batch = batch.to(self.device)
            # Inference only — no autograd graph needed.
            with torch.no_grad():
                embeddings_batch = self.model(batch)
            # embeddings_batch[0] is the last hidden state; extend() splits the
            # batch dimension so each list entry is one sentence's token matrix.
            embeddings.extend(embeddings_batch[0])
            # Keep the input token ids so callers can locate specific tokens.
            ingredient_ids.extend(batch)

        return torch.stack(embeddings), ingredient_ids

    def compute_embedding_for_ingredient(self, sentence, ingredient_name):
        """Return the embedding (numpy, length 768) of ``ingredient_name`` in ``sentence``.

        Looks up the ingredient's single token id (possible because the
        tokenizer never splits ingredient tokens) and selects the embedding at
        the first position where that id occurs.

        NOTE(review): if the ingredient does not occur in the tokenized
        sentence the boolean mask is empty and ``food_embedding[0]`` raises
        IndexError — callers appear to rely on passing a sentence that
        contains the ingredient; confirm before hardening.
        """
        embeddings, ingredient_ids = self.predict_embeddings([sentence])
        # Flatten (sentences, seq_len, 768) -> (tokens, 768) and the ids to match.
        embeddings_flat = embeddings.view((-1, 768))
        ingredient_ids_flat = torch.stack(ingredient_ids).flatten()
        food_id = self.tokenizer.convert_tokens_to_ids(ingredient_name)
        # Boolean-mask select every position holding the ingredient's token id.
        food_embedding = embeddings_flat[ingredient_ids_flat == food_id].cpu().numpy()

        # First occurrence wins when the ingredient appears more than once.
        return food_embedding[0]