- import pytest
- import torch
- import config
- from model import build_transformer
- from dataset import TextDataset
@pytest.fixture(scope="class")
def get_data(request):
    """Build the dataset, dataloader, and checkpointed model once per test class.

    Attaches ``model`` and one batch to the requesting test class so
    method-style tests can read them via ``self.model`` / ``self.batch``.

    Notes:
        - ``autouse=False`` was removed: it is the default for fixtures.
        - ``batch_size=1`` so a single batch is cheap to pull for smoke tests.
    """
    dataset = TextDataset(config.DATA_PATH1, config.DATA_PATH2, config.INP_SEQ_LEN)
    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1)
    model = build_transformer(dataset.rus_lang.n_words, config.INP_SEQ_LEN, dataset.rus_lang)
    model.load_checkpoint(config.CHECKPOINT_PATH)
    model = model.to(config.DEVICE)
    request.cls.model = model
    # Only the first batch is needed; next(iter(...)) avoids the loop-and-break idiom.
    request.cls.batch = next(iter(dataloader))
    yield
@pytest.mark.usefixtures("get_data")
class TestBody:
    """Smoke tests for the encoder output shape and the classification head.

    ``self.model`` and ``self.batch`` are injected by the ``get_data``
    class-scoped fixture.
    """

    # Embedding width the assertions expect.
    # NOTE(review): assumed to be build_transformer's d_model — confirm against model.py.
    D_MODEL = 512

    def _encode(self):
        """Run the shared batch through the encoder and return its output.

        Factored out because the original repeated these three lines in
        every test method.
        """
        input_data = self.batch["encoder_input"].to(config.DEVICE)
        input_mask = self.batch["encoder_mask"].to(config.DEVICE)
        return self.model.encode(input_data, input_mask)

    def test_encoder_seq(self):
        # Sequence axis must match the configured input length.
        assert self._encode().shape[1] == config.INP_SEQ_LEN

    def test_encoder_embed(self):
        # Last axis is the embedding dimension.
        assert self._encode().shape[2] == self.D_MODEL

    def test_class_out_shape(self):
        # Classification head emits exactly one logit.
        output = self.model.classify(self._encode())
        assert output.shape[-1] == 1

    def test_class_out_value(self):
        # Sigmoid of the logit must be a valid probability.
        output = self.model.classify(self._encode())
        probability = torch.sigmoid(output).item()
        assert 0 <= probability <= 1
- if __name__ == "__main__":
- pytest.main()