# dataset.py
  1. import torch
  2. import csv
  3. from tokenizer import CustomTokenizer
  4. from utils import tensor_from_sentence
  5. class Lang:
  6. def __init__(self, name):
  7. self.name = name
  8. self.word2index = {"<sos>": 0, "<eos>": 1, "<pad>": 2, ".": 3, ",": 4, "!": 5, "?": 6}
  9. self.index2word = {0: "<sos>", 1: "<eos>", 2: "<pad>", 3: ".", 4: ",", 5: "!", 6: "?"}
  10. self.n_words = 7 # Count SOS and EOS
  11. def add_sentence(self, sentence):
  12. for word in sentence:
  13. self.add_word(word.lemma_)
  14. def add_word(self, word):
  15. if word not in self.word2index.keys():
  16. self.word2index[word] = self.n_words
  17. self.index2word[self.n_words] = word
  18. self.n_words += 1
  19. else:
  20. pass
  21. class TextDataset(torch.utils.data.Dataset):
  22. def __init__(self, file_path1, file_path2, inp_seq_len):
  23. self.tokenizer = CustomTokenizer()
  24. self.pairs = self.read_langs(file_path1, file_path2)
  25. self.pairs = TextDataset.filter_pairs(self.pairs, inp_seq_len)
  26. self.rus_lang = Lang("rus")
  27. self.fill_langs()
  28. print(self.rus_lang.n_words)
  29. self.inp_seq_len = inp_seq_len
  30. def read_langs(self, file_path1, file_path2):
  31. print("Reading lines...")
  32. with open(file_path1, newline='', encoding="UTF-8") as csvfile_text:
  33. spamreader = csv.reader(csvfile_text, delimiter=',', quotechar='"')
  34. text_lines = []
  35. for idx, row in enumerate(spamreader):
  36. if idx == 0:
  37. continue
  38. text_lines.append(self.tokenizer.tokenize_rus("<sos> " + row[2] + " <eos>"))
  39. with open(file_path2, newline='', encoding="UTF-8") as csvfile_labels:
  40. spamreader = csv.reader(csvfile_labels, delimiter=',', quotechar='"')
  41. stances = []
  42. for idx, row in enumerate(spamreader):
  43. if idx == 0:
  44. continue
  45. stances += [True] if row[3] == "agree" else [False]
  46. assert len(stances) == len(text_lines)
  47. pairs = [(text_lines[i], stances[i]) for i in range(len(stances))]
  48. return pairs
  49. @staticmethod
  50. def filter_pairs(pairs, inp_seq_len):
  51. pairs = [pair for pair in pairs if len(pair[0]) < inp_seq_len]
  52. return pairs
  53. def fill_langs(self):
  54. for pair in self.pairs:
  55. self.rus_lang.add_sentence(pair[0])
  56. def __len__(self):
  57. return len(self.pairs)
  58. def __getitem__(self, item):
  59. pair = self.pairs[item]
  60. enc_inp = tensor_from_sentence(self.rus_lang, pair[0], self.inp_seq_len)
  61. dec_tgt = torch.tensor(pair[1]).float().unsqueeze(-1)
  62. print(pair[0])
  63. return {
  64. "encoder_input": enc_inp, # (seq_len)
  65. "encoder_mask": (enc_inp != self.rus_lang.word2index["<pad>"]).unsqueeze(0).unsqueeze(0),
  66. # (1, seq_len) & (1, seq_len, seq_len),
  67. "target": dec_tgt, # (seq_len)
  68. }