import os

import torch
from tqdm import tqdm
from torchinfo import summary

from config import *
from dataset import TextDataset
from model import build_transformer
def train_model(num_epochs=100):
    # Build the dataset and the transformer classifier from the config constants.
    dataset = TextDataset(DATA_PATH1, DATA_PATH2, INP_SEQ_LEN)
    model = build_transformer(
        src_vocab=dataset.rus_lang.n_words,
        src_len=INP_SEQ_LEN,
        lang=dataset.rus_lang,
    )
    model = model.to(DEVICE)
    summary(model)

    # Binary classification on raw logits, so BCEWithLogitsLoss applies the sigmoid internally.
    loss_fn = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR_RATE, weight_decay=WEIGHT_DECAY)

    # Resume training from an existing checkpoint, if one is present.
    if os.path.isfile(CHECKPOINT_PATH):
        model.load_checkpoint_train(CHECKPOINT_PATH, optimizer)
        model = model.to(DEVICE)

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        shuffle=True,
    )
    scaler = torch.cuda.amp.GradScaler()  # gradient scaling for mixed-precision training

    for epoch in range(num_epochs):
        train_loop(model, dataloader, optimizer, loss_fn, dataset.rus_lang, scaler)
        model.save_train(CHECKPOINT_PATH, optimizer)  # checkpoint after every epoch
def train_loop(model, dataloader, optimizer, loss_fn, lang, scaler):
    total_loss = 0
    loop = tqdm(dataloader)
    for idx, batch in enumerate(loop):
        # Move the batch to the target device once; no need to repeat .to(DEVICE) later.
        encoder_input = batch["encoder_input"].to(DEVICE)
        encoder_mask = batch["encoder_mask"].to(DEVICE)
        target = batch["target"].to(DEVICE)

        # Mixed-precision forward pass and loss computation.
        with torch.cuda.amp.autocast():
            encoder_out = model.encode(encoder_input, encoder_mask)
            out = model.classify(encoder_out)
            loss = loss_fn(out, target)

        # Debug output: predicted probability for the first sample in the batch.
        print(torch.sigmoid(out[0].detach()).item())

        # Scaled backward pass and optimizer step (standard AMP recipe).
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        loop.set_postfix(loss=total_loss / (idx + 1))
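
# The file as given never invokes train_model(); the entry-point guard below is a
# minimal sketch added under that assumption so the script can be run directly.
# All constants (paths, DEVICE, hyperparameters) are assumed to come from config.py.
if __name__ == "__main__":
    train_model(num_epochs=100)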