"""Smoke tests for the encoder and classification head of the transformer.

A class-scoped fixture builds the dataset, restores the model from the
configured checkpoint and attaches the model plus a single batch to the test
class; the tests then check output shapes and the range of the predicted
probability.
"""

import sys

import pytest
import torch

import config
from dataset import TextDataset
from model import build_transformer

# Size of the last dimension the encoder output is expected to have.
EXPECTED_EMBED_DIM = 512


@pytest.fixture(scope="class")
def get_data(request):
    """Attach a ready-to-use model and one batch to the requesting test class.

    Sets ``request.cls.model`` (checkpoint loaded, moved to ``config.DEVICE``)
    and ``request.cls.batch`` (the first batch yielded by a ``batch_size=1``
    DataLoader). Fails with an explicit message when the dataset yields no
    batches instead of leaving ``batch`` unset.
    """
    dataset = TextDataset(config.DATA_PATH1, config.DATA_PATH2, config.INP_SEQ_LEN)
    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1)

    model = build_transformer(
        dataset.rus_lang.n_words, config.INP_SEQ_LEN, dataset.rus_lang
    )
    model.load_checkpoint(config.CHECKPOINT_PATH)
    model = model.to(config.DEVICE)
    # NOTE(review): the `.to()` call above implies a torch.nn.Module; switch it
    # to inference mode so layers such as dropout behave deterministically.
    model.eval()
    request.cls.model = model

    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        pytest.fail("TextDataset produced no batches; check the configured data paths")
    request.cls.batch = first_batch

    yield


@pytest.mark.usefixtures("get_data")
class TestBody:
    """Shape and value checks on ``model.encode`` and ``model.classify``."""

    def _encode(self):
        """Run the encoder on the fixture batch and return its output."""
        input_data = self.batch["encoder_input"].to(config.DEVICE)
        input_mask = self.batch["encoder_mask"].to(config.DEVICE)
        # Gradients are never needed in these tests; skip graph construction.
        with torch.no_grad():
            return self.model.encode(input_data, input_mask)

    def _classify(self):
        """Run encoder followed by the classification head on the fixture batch."""
        enc_out = self._encode()
        with torch.no_grad():
            return self.model.classify(enc_out)

    def test_encoder_seq(self):
        """Dimension 1 of the encoder output equals the configured sequence length."""
        enc_out = self._encode()
        assert enc_out.shape[1] == config.INP_SEQ_LEN

    def test_encoder_embed(self):
        """Dimension 2 of the encoder output equals the expected embedding size."""
        enc_out = self._encode()
        assert enc_out.shape[2] == EXPECTED_EMBED_DIM

    def test_class_out_shape(self):
        """The classification head emits a single value in its last dimension."""
        output = self._classify()
        assert output.shape[-1] == 1

    def test_class_out_value(self):
        """The sigmoid of the classifier output is a valid probability.

        ``.item()`` requires a single-element tensor, so this also relies on
        the batch size of 1 used by the fixture. A NaN output fails the range
        comparison, so the check is not vacuous.
        """
        output = self._classify()
        probability = torch.sigmoid(output).item()
        assert 0 <= probability <= 1


if __name__ == "__main__":
    # Restrict collection to this file, forward any CLI options, and propagate
    # pytest's exit status so failures are visible to the calling shell.
    sys.exit(pytest.main([__file__, *sys.argv[1:]]))