utils.py 714 B

12345678910111213141516171819202122232425
  1. import torch
  2. def indexes_from_sentence(lang, sentence):
  3. indexes = []
  4. for word in sentence:
  5. if word.lemma_ not in lang.word2index.keys():
  6. indexes.append(lang.word2index["<pad>"])
  7. continue
  8. indexes.append(lang.word2index[word.lemma_])
  9. return indexes
  10. def tensor_from_sentence(lang, sentence, seq_len):
  11. indexes = indexes_from_sentence(lang, sentence)
  12. pad_count = seq_len - len(indexes)
  13. for _ in range(pad_count):
  14. indexes.append(lang.word2index["<pad>"])
  15. return torch.tensor(indexes, dtype=torch.long)
  16. def print_result(output, lang):
  17. for index in output[0].argmax(-1):
  18. print(lang.index2word[index.item()], end=" ")
  19. print()