12345678910111213141516171819202122232425 |
- import torch
def indexes_from_sentence(lang, sentence, *, unknown_token="<pad>"):
    """Return the vocabulary index of every token in *sentence*.

    Each token is looked up by its ``lemma_`` attribute in
    ``lang.word2index`` (tokens are presumably spaCy ``Token`` objects --
    confirm against callers).

    Args:
        lang: Vocabulary object exposing a ``word2index`` mapping.
        sentence: Iterable of tokens that carry a ``lemma_`` attribute.
        unknown_token: Vocabulary entry used for lemmas missing from
            ``lang.word2index``. Defaults to ``"<pad>"`` for backward
            compatibility; pass e.g. ``"<unk>"`` when the vocabulary has one,
            so unknown words are not conflated with padding.

    Returns:
        A new list of indices, one per token, in input order.

    Raises:
        KeyError: If an out-of-vocabulary lemma is encountered and
            *unknown_token* is not in ``lang.word2index``.
    """
    word2index = lang.word2index
    # Membership is tested on the mapping itself (no ``.keys()`` needed); the
    # fallback entry is only looked up when a lemma is actually missing.
    return [
        word2index[word.lemma_]
        if word.lemma_ in word2index
        else word2index[unknown_token]
        for word in sentence
    ]
def tensor_from_sentence(lang, sentence, seq_len):
    """Encode *sentence* as a 1-D ``torch.long`` tensor of exactly *seq_len* indices.

    Token indices come from ``indexes_from_sentence``. Sentences shorter than
    *seq_len* are right-padded with ``lang.word2index["<pad>"]``; longer ones
    are truncated to their first *seq_len* tokens. Every returned tensor
    therefore has the same length (callers presumably stack these into
    batches -- confirm).

    Args:
        lang: Vocabulary object exposing a ``word2index`` mapping that
            contains ``"<pad>"``.
        sentence: Iterable of tokens carrying a ``lemma_`` attribute.
        seq_len: Required output length; must be non-negative.

    Returns:
        ``torch.Tensor`` of dtype ``torch.long`` and shape ``(seq_len,)``.

    Raises:
        ValueError: If *seq_len* is negative.
        KeyError: If padding is needed and ``"<pad>"`` is not in
            ``lang.word2index``.
    """
    if seq_len < 0:
        raise ValueError(f"seq_len must be non-negative, got {seq_len}")
    # Truncate first: without this, a negative pad count meant over-long
    # sentences produced tensors longer than seq_len.
    indexes = indexes_from_sentence(lang, sentence)[:seq_len]
    missing = seq_len - len(indexes)
    if missing:
        indexes.extend([lang.word2index["<pad>"]] * missing)
    return torch.tensor(indexes, dtype=torch.long)
def print_result(output, lang):
    """Print the highest-scoring word at each position of ``output[0]``.

    ``output[0]`` is reduced with ``argmax`` over its last dimension. Each
    resulting index is converted to a Python int and mapped through
    ``lang.index2word``. The words are written on a single line, each
    followed by one space, and the line is terminated with a newline.
    """
    best_ids = output[0].argmax(dim=-1)
    words = [lang.index2word[token_id.item()] for token_id in best_ids]
    # Every word keeps its trailing space (including the last one); an empty
    # prediction prints just the newline.
    print("".join(str(word) + " " for word in words))
|