train.py

import os

import torch
from tqdm import tqdm
from torchinfo import summary

from dataset import TextDataset
from config import *
from model import build_transformer
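
# NOTE: `from config import *` is expected to supply at least the
# following names used below (their values live in config.py, not shown):
#   DATA_PATH1, DATA_PATH2  - dataset file paths
#   INP_SEQ_LEN             - input sequence length
#   DEVICE                  - target device, e.g. torch.device("cuda")
#   LR_RATE, WEIGHT_DECAY   - Adam hyperparameters
#   CHECKPOINT_PATH         - checkpoint file location
#   BATCH_SIZE, NUM_WORKERS - DataLoader settings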


def train_model(num_epochs=100):
    dataset = TextDataset(DATA_PATH1, DATA_PATH2, INP_SEQ_LEN)
    model = build_transformer(
        src_vocab=dataset.rus_lang.n_words,
        src_len=INP_SEQ_LEN,
        lang=dataset.rus_lang,
    )
    model = model.to(DEVICE)
    summary(model)

    # Binary target, so the classifier head outputs raw logits scored with BCE.
    loss_fn = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR_RATE, weight_decay=WEIGHT_DECAY)

    # Resume training if a checkpoint exists (also restores optimizer state).
    if os.path.isfile(CHECKPOINT_PATH):
        model.load_checkpoint_train(CHECKPOINT_PATH, optimizer)
        model = model.to(DEVICE)

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        shuffle=True,
    )
    scaler = torch.cuda.amp.GradScaler()  # loss scaling for mixed precision

    for epoch in range(num_epochs):
        train_loop(model, dataloader, optimizer, loss_fn, dataset.rus_lang, scaler)
        model.save_train(CHECKPOINT_PATH, optimizer)  # checkpoint every epoch
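
# `load_checkpoint_train` and `save_train` are custom helpers defined on the
# model (in model.py, not shown). A plausible sketch of what they do --
# purely an assumption, kept commented out for reference:
#
#   def save_train(self, path, optimizer):
#       torch.save({"model": self.state_dict(),
#                   "optimizer": optimizer.state_dict()}, path)
#
#   def load_checkpoint_train(self, path, optimizer):
#       checkpoint = torch.load(path, map_location=DEVICE)
#       self.load_state_dict(checkpoint["model"])
#       optimizer.load_state_dict(checkpoint["optimizer"])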


def train_loop(model, dataloader, optimizer, loss_fn, lang, scaler):
    # `lang` is accepted by the signature but not used inside the loop.
    total_loss = 0
    loop = tqdm(dataloader)
    for idx, batch in enumerate(loop):
        encoder_input = batch["encoder_input"].to(DEVICE)
        encoder_mask = batch["encoder_mask"].to(DEVICE)
        target = batch["target"].to(DEVICE)

        # Forward pass under autocast for mixed precision.
        with torch.cuda.amp.autocast():
            encoder_out = model.encode(encoder_input, encoder_mask)
            out = model.classify(encoder_out)
            loss = loss_fn(out, target)

        # Debug probe: predicted probability for the first sample in the batch.
        print(torch.sigmoid(out[0].detach()).item())

        # Backward pass and optimizer step go through the scaler, outside autocast.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        loop.set_postfix(loss=total_loss / (idx + 1))
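

# Entry-point sketch (an assumption about how the script is invoked): run a
# full training session with the default epoch count when executed directly.
if __name__ == "__main__":
    train_model()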