import csv

import torch

from tokenizer import CustomTokenizer
from utils import tensor_from_sentence


class Lang:
    """Vocabulary that maps words to indices and back."""

    def __init__(self, name):
        self.name = name
        # NOTE: the special-token markup was lost in the original source; the
        # names <PAD>, <SOS> and <EOS> below are assumed. <PAD> must match the
        # padding index used by tensor_from_sentence and by the encoder mask.
        self.word2index = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, ".": 3, ",": 4, "!": 5, "?": 6}
        self.index2word = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: ".", 4: ",", 5: "!", 6: "?"}
        self.n_words = 7  # counts the special tokens and punctuation above

    def add_sentence(self, sentence):
        # `sentence` is a sequence of spaCy-style tokens; index their lemmas.
        for word in sentence:
            self.add_word(word.lemma_)

    def add_word(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.index2word[self.n_words] = word
            self.n_words += 1


class TextDataset(torch.utils.data.Dataset):
    """Stance-classification dataset: (tokenized Russian text, agree/disagree label)."""

    def __init__(self, file_path1, file_path2, inp_seq_len):
        self.tokenizer = CustomTokenizer()
        self.pairs = self.read_langs(file_path1, file_path2)
        self.pairs = TextDataset.filter_pairs(self.pairs, inp_seq_len)
        self.rus_lang = Lang("rus")
        self.fill_langs()
        print(self.rus_lang.n_words)  # report vocabulary size
        self.inp_seq_len = inp_seq_len

    def read_langs(self, file_path1, file_path2):
        """Read texts from file_path1 and stance labels from file_path2, skipping headers."""
        print("Reading lines...")
        with open(file_path1, newline='', encoding="UTF-8") as csvfile_text:
            reader = csv.reader(csvfile_text, delimiter=',', quotechar='"')
            text_lines = []
            for idx, row in enumerate(reader):
                if idx == 0:  # skip the header row
                    continue
                text_lines.append(self.tokenizer.tokenize_rus(" " + row[2] + " "))

        with open(file_path2, newline='', encoding="UTF-8") as csvfile_labels:
            reader = csv.reader(csvfile_labels, delimiter=',', quotechar='"')
            stances = []
            for idx, row in enumerate(reader):
                if idx == 0:  # skip the header row
                    continue
                stances.append(row[3] == "agree")

        assert len(stances) == len(text_lines)
        return list(zip(text_lines, stances))

    @staticmethod
    def filter_pairs(pairs, inp_seq_len):
        # Keep only texts that fit into the input sequence length.
        return [pair for pair in pairs if len(pair[0]) < inp_seq_len]

    def fill_langs(self):
        for pair in self.pairs:
            self.rus_lang.add_sentence(pair[0])

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, item):
        sentence, stance = self.pairs[item]
        enc_inp = tensor_from_sentence(self.rus_lang, sentence, self.inp_seq_len)
        dec_tgt = torch.tensor(stance).float().unsqueeze(-1)
        return {
            "encoder_input": enc_inp,  # (seq_len)
            # True for real tokens, False for padding positions.
            "encoder_mask": (enc_inp != self.rus_lang.word2index["<PAD>"]).unsqueeze(0).unsqueeze(0),  # (1, 1, seq_len)
            "target": dec_tgt,  # (1,)
        }
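

# --- Usage sketch (not part of the original module) ---
# A minimal example of constructing the dataset and batching it with a
# DataLoader. The CSV paths, the maximum input length of 128, and the batch
# size are assumptions for illustration; substitute the real files and
# settings. Batching works because tensor_from_sentence is expected to pad
# every item to inp_seq_len.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = TextDataset("texts.csv", "labels.csv", inp_seq_len=128)  # hypothetical paths
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    batch = next(iter(loader))
    print(batch["encoder_input"].shape)  # (32, seq_len)
    print(batch["encoder_mask"].shape)   # (32, 1, 1, seq_len)
    print(batch["target"].shape)         # (32, 1)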