import os

import torch
from torchinfo import summary
from tqdm import tqdm

from config import *
from dataset import TextDataset
from model import build_transformer


def train_model(num_epochs=100):
    dataset = TextDataset(DATA_PATH1, DATA_PATH2, INP_SEQ_LEN)
    model = build_transformer(
        src_vocab=dataset.rus_lang.n_words,
        src_len=INP_SEQ_LEN,
        lang=dataset.rus_lang,
    )
    model = model.to(DEVICE)
    summary(model)

    loss_fn = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR_RATE, weight_decay=WEIGHT_DECAY)

    # Resume from a checkpoint if one exists; the custom loader also restores
    # the optimizer state, so the model is moved back to DEVICE afterwards.
    if os.path.isfile(CHECKPOINT_PATH):
        model.load_checkpoint_train(CHECKPOINT_PATH, optimizer)
        model = model.to(DEVICE)

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        shuffle=True,
    )
    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(num_epochs):
        train_loop(model, dataloader, optimizer, loss_fn, dataset.rus_lang, scaler)
        model.save_train(CHECKPOINT_PATH, optimizer)


def train_loop(model, dataloader, optimizer, loss_fn, lang, scaler):
    total_loss = 0
    loop = tqdm(dataloader)
    for idx, batch in enumerate(loop):
        # Move the batch to the target device once; the tensors were being
        # sent to DEVICE twice in places, which is a harmless but redundant copy.
        encoder_input = batch["encoder_input"].to(DEVICE)
        encoder_mask = batch["encoder_mask"].to(DEVICE)
        target = batch["target"].to(DEVICE)

        # Mixed-precision forward pass: encode the source sequence, then run
        # the classification head on the encoder output.
        with torch.cuda.amp.autocast():
            encoder_out = model.encode(encoder_input, encoder_mask)
            out = model.classify(encoder_out)
            loss = loss_fn(out, target)

        # Debug probe: predicted probability for the first sample in the batch.
        print(torch.sigmoid(out[0].detach()).item())

        # Scaled backward pass and optimizer step, outside the autocast
        # region as recommended for torch.cuda.amp.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        loop.set_postfix(loss=total_loss / (idx + 1))
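

# Minimal entry-point sketch (an assumption: the original script's invocation
# is not shown). The __main__ guard matters here because the DataLoader uses
# num_workers, and platforms that spawn worker subprocesses (Windows, macOS)
# re-import this module in each worker; without the guard, training would
# be re-launched recursively.
if __name__ == "__main__":
    train_model()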