initial commit of project

This commit is contained in:
2021-04-11 19:51:12 +02:00
commit a21a8186d9
110 changed files with 16326178 additions and 0 deletions

View File

@@ -0,0 +1,53 @@
# adapted from:
# Pellegrini., C., E. Özsoy., M. Wintergerst., and G. Groh. (2021).
# “Exploiting Food Embeddings for Ingredient Substitution.”
# In: Proceedings of the 14th International Joint Conference on Biomedical
# Engineering Systems and Technologies - Volume 5: HEALTHINF, INSTICC.
# SciTePress, pp. 6777. isbn: 978-989-758-490-9. doi: 10.5220/0010202000670077.
import json
import torch
from torch.utils.data import DataLoader
from transformers import BertModel, BertTokenizer
from evaluation.helpers.instructions_dataset import InstructionsDataset
class PredictionModel:
def __init__(self, model_path=''):
self.model: BertModel = BertModel.from_pretrained(
pretrained_model_name_or_path=model_path)
with open('train_model/vocab/used_ingredients.json', 'r') as f:
used_ingredients = json.load(f)
self.tokenizer = BertTokenizer(vocab_file='train_model/vocab/bert_vocab.txt', do_lower_case=False,
max_len=512, never_split=used_ingredients, truncation=True)
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self.model.to(self.device)
def predict_embeddings(self, sentences):
dataset = InstructionsDataset(tokenizer=self.tokenizer, sentences=sentences)
dataloader = DataLoader(dataset, batch_size=100, pin_memory=True)
embeddings = []
ingredient_ids = []
for batch in dataloader:
batch = batch.to(self.device)
with torch.no_grad():
embeddings_batch = self.model(batch)
embeddings.extend(embeddings_batch[0])
ingredient_ids.extend(batch)
return torch.stack(embeddings), ingredient_ids
def compute_embedding_for_ingredient(self, sentence, ingredient_name):
embeddings, ingredient_ids = self.predict_embeddings([sentence])
embeddings_flat = embeddings.view((-1, 768))
ingredient_ids_flat = torch.stack(ingredient_ids).flatten()
food_id = self.tokenizer.convert_tokens_to_ids(ingredient_name)
food_embedding = embeddings_flat[ingredient_ids_flat == food_id].cpu().numpy()
return food_embedding[0]