- import torch
- import csv
- from tokenizer import CustomTokenizer
- from utils import tensor_from_sentence
class Lang:
    """Vocabulary mapping between words and integer indices.

    The mapping starts pre-populated with the special tokens <sos>, <eos>,
    <pad> and common punctuation so those indices (0-6) stay stable
    across runs and datasets.
    """

    def __init__(self, name):
        # Human-readable language identifier (e.g. "rus").
        self.name = name
        # Reserved entries: special tokens + punctuation get fixed ids 0-6.
        self.word2index = {"<sos>": 0, "<eos>": 1, "<pad>": 2, ".": 3, ",": 4, "!": 5, "?": 6}
        self.index2word = {0: "<sos>", 1: "<eos>", 2: "<pad>", 3: ".", 4: ",", 5: "!", 6: "?"}
        # Number of entries currently in the vocabulary (7 reserved ones so far).
        self.n_words = 7

    def add_sentence(self, sentence):
        """Add every token of *sentence* to the vocabulary.

        *sentence* is an iterable of tokens exposing ``.lemma_``
        (spaCy-style); the lemma, not the surface form, is registered.
        """
        for word in sentence:
            self.add_word(word.lemma_)

    def add_word(self, word):
        """Register *word* under the next free index; no-op if already known."""
        # Direct membership test on the dict — no need for .keys().
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.index2word[self.n_words] = word
            self.n_words += 1
class TextDataset(torch.utils.data.Dataset):
    """Stance-classification dataset of (tokenized Russian text, bool label).

    Texts are read from *file_path1* and stance labels from *file_path2*
    (both CSVs with a header row, aligned row-for-row). Pairs whose
    tokenized text does not fit in *inp_seq_len* are dropped, and a
    vocabulary is built over the surviving texts.
    """

    def __init__(self, file_path1, file_path2, inp_seq_len):
        self.tokenizer = CustomTokenizer()
        self.pairs = self.read_langs(file_path1, file_path2)
        self.pairs = TextDataset.filter_pairs(self.pairs, inp_seq_len)
        self.rus_lang = Lang("rus")
        self.fill_langs()
        print(self.rus_lang.n_words)
        self.inp_seq_len = inp_seq_len

    def read_langs(self, file_path1, file_path2):
        """Return a list of (tokenized_text, is_agree) pairs from the two CSVs.

        Raises ValueError if the two files have a different number of
        data rows (they must be aligned row-for-row).
        """
        print("Reading lines...")
        with open(file_path1, newline='', encoding="UTF-8") as csvfile_text:
            reader = csv.reader(csvfile_text, delimiter=',', quotechar='"')
            text_lines = []
            for idx, row in enumerate(reader):
                if idx == 0:  # skip header row
                    continue
                # Column 2 holds the raw text; wrap it with sequence markers
                # before tokenizing.
                text_lines.append(self.tokenizer.tokenize_rus("<sos> " + row[2] + " <eos>"))
        with open(file_path2, newline='', encoding="UTF-8") as csvfile_labels:
            reader = csv.reader(csvfile_labels, delimiter=',', quotechar='"')
            stances = []
            for idx, row in enumerate(reader):
                if idx == 0:  # skip header row
                    continue
                # Column 3 holds the stance string; binarize to agree/not-agree.
                stances.append(row[3] == "agree")
        # Explicit check instead of assert: assert is stripped under -O.
        if len(stances) != len(text_lines):
            raise ValueError(
                f"text/label row count mismatch: {len(text_lines)} texts "
                f"vs {len(stances)} labels"
            )
        return list(zip(text_lines, stances))

    @staticmethod
    def filter_pairs(pairs, inp_seq_len):
        """Keep only pairs whose tokenized text is shorter than *inp_seq_len*."""
        return [pair for pair in pairs if len(pair[0]) < inp_seq_len]

    def fill_langs(self):
        """Add every kept sentence to the Russian vocabulary."""
        for pair in self.pairs:
            self.rus_lang.add_sentence(pair[0])

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, item):
        # NOTE: the leftover per-item debug print of the sentence was removed;
        # it fired on every sample fetched by a DataLoader.
        sentence, stance = self.pairs[item]
        enc_inp = tensor_from_sentence(self.rus_lang, sentence, self.inp_seq_len)
        dec_tgt = torch.tensor(stance).float().unsqueeze(-1)
        return {
            "encoder_input": enc_inp,  # (seq_len)
            # True where the token is not padding; shaped (1, 1, seq_len) so it
            # broadcasts against (1, seq_len, seq_len) attention scores.
            "encoder_mask": (enc_inp != self.rus_lang.word2index["<pad>"]).unsqueeze(0).unsqueeze(0),
            "target": dec_tgt,  # (1,) float label: 1.0 = agree, 0.0 = not agree
        }