# tests.py
import sys

import pytest
import torch

import config
from model import build_transformer
from dataset import TextDataset


  6. @pytest.fixture(scope="class", autouse=False)
  7. def get_data(request):
  8. dataset = TextDataset(config.DATA_PATH1, config.DATA_PATH2, config.INP_SEQ_LEN)
  9. dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1)
  10. model = build_transformer(dataset.rus_lang.n_words, config.INP_SEQ_LEN, dataset.rus_lang)
  11. model.load_checkpoint(config.CHECKPOINT_PATH)
  12. model = model.to(config.DEVICE)
  13. request.cls.model = model
  14. for batch in dataloader:
  15. request.cls.batch = batch
  16. break
  17. yield
  18. @pytest.mark.usefixtures("get_data")
  19. class TestBody:
  20. def test_encoder_seq(self):
  21. input_data = self.batch["encoder_input"].to(config.DEVICE)
  22. input_mask = self.batch["encoder_mask"].to(config.DEVICE)
  23. enc_out = self.model.encode(input_data, input_mask)
  24. assert enc_out.shape[1] == config.INP_SEQ_LEN
  25. def test_encoder_embed(self):
  26. input_data = self.batch["encoder_input"].to(config.DEVICE)
  27. input_mask = self.batch["encoder_mask"].to(config.DEVICE)
  28. enc_out = self.model.encode(input_data, input_mask)
  29. assert enc_out.shape[2] == 512
  30. def test_class_out_shape(self):
  31. input_data = self.batch["encoder_input"].to(config.DEVICE)
  32. input_mask = self.batch["encoder_mask"].to(config.DEVICE)
  33. enc_out = self.model.encode(input_data, input_mask)
  34. output = self.model.classify(enc_out)
  35. assert output.shape[-1] == 1
  36. def test_class_out_value(self):
  37. input_data = self.batch["encoder_input"].to(config.DEVICE)
  38. input_mask = self.batch["encoder_mask"].to(config.DEVICE)
  39. enc_out = self.model.encode(input_data, input_mask)
  40. output = self.model.classify(enc_out)
  41. probability = torch.sigmoid(output).item()
  42. assert 0 <= probability <= 1
  43. if __name__ == "__main__":
  44. pytest.main()