import torch

# Vocabulary keys used for out-of-vocabulary words and for padding.
# NOTE(review): both are the empty string, so unknown words and padding map to
# the same index. This looks like special-token names (e.g. "<UNK>"/"<PAD>")
# may have been lost at some point -- confirm against the Lang vocabulary
# before changing either value.
_UNK_TOKEN = ""
_PAD_TOKEN = ""


def indexes_from_sentence(lang, sentence):
    """Map each word of *sentence* to its vocabulary index.

    Args:
        lang: Object exposing a ``word2index`` mapping from lemma string to
            integer index.
        sentence: Iterable of word objects exposing a ``lemma_`` attribute
            (presumably spaCy tokens -- confirm with callers).

    Returns:
        A list with one index per word. Lemmas missing from
        ``lang.word2index`` map to the index stored under ``_UNK_TOKEN``.

    Raises:
        KeyError: if an out-of-vocabulary lemma is encountered and
            ``_UNK_TOKEN`` itself is not in ``lang.word2index``.
    """
    word2index = lang.word2index
    # The unknown-token lookup only happens for out-of-vocabulary lemmas,
    # so a vocabulary without that key still works for fully-known input.
    return [
        word2index[word.lemma_] if word.lemma_ in word2index else word2index[_UNK_TOKEN]
        for word in sentence
    ]


def tensor_from_sentence(lang, sentence, seq_len: int) -> torch.Tensor:
    """Convert *sentence* to a 1-D ``torch.long`` tensor padded to *seq_len*.

    Args:
        lang: Object exposing a ``word2index`` mapping (see
            ``indexes_from_sentence``).
        sentence: Iterable of word objects exposing ``lemma_``.
        seq_len: Target length. Shorter sentences are right-padded with the
            index stored under ``_PAD_TOKEN``.

    Returns:
        A 1-D tensor of dtype ``torch.long``. Sentences already longer than
        *seq_len* are returned as-is (neither padded nor truncated), so the
        result can be longer than *seq_len*.
    """
    indexes = indexes_from_sentence(lang, sentence)
    pad_count = seq_len - len(indexes)
    if pad_count > 0:
        # Look up the pad index only when padding is actually required.
        indexes.extend([lang.word2index[_PAD_TOKEN]] * pad_count)
    return torch.tensor(indexes, dtype=torch.long)


def print_result(output, lang) -> None:
    """Print the highest-scoring word at each position of ``output[0]``.

    Takes ``argmax`` over the last dimension of ``output[0]`` and maps each
    resulting index through ``lang.index2word``. Each word is followed by a
    single space, and the line ends with a newline.

    Assumes ``output[0]`` is 2-D, i.e. (positions, vocab scores) -- TODO
    confirm with callers.
    """
    best = output[0].argmax(-1).tolist()
    # Build the line once instead of issuing one print call per token.
    print("".join(f"{lang.index2word[index]} " for index in best))