import math

import torch
import torch.nn as nn

class InputEmbeddings(nn.Module):
    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # (batch, seq_len) --> (batch, seq_len, d_model), scaled by sqrt(d_model) as in the paper
        return self.embedding(x) * math.sqrt(self.d_model)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        # Precompute the sinusoidal encodings once: (seq_len, d_model)
        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        # 10000.0 is the constant from the paper's sin/cos formula
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, seq_len, d_model)
        # Register as a buffer: it moves with the module but is not a trainable parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add the fixed positional encoding for the first x.shape[1] positions
        x = x + self.pe[:, :x.shape[1], :]
        return self.dropout(x)

class LayerNormalization(nn.Module):
    def __init__(self, features: int = 512, eps: float = 1e-6) -> None:
        super().__init__()
        # The default of 512 assumes the d_model used by build_transformer;
        # pass the actual d_model here if a different size is used.
        self.layerNorm = nn.LayerNorm(normalized_shape=features, eps=eps)

    def forward(self, x):
        return self.layerNorm(x)

class FFBlock(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        # Position-wise feed-forward network
        self.layers = nn.Sequential(
            nn.Linear(d_model, d_ff),   # W1 + B1
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),   # W2 + B2
        )

    def forward(self, x):
        # (batch, seq_len, d_model) --> (batch, seq_len, d_ff) --> (batch, seq_len, d_model)
        return self.layers(x)

class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model  # Embedding vector size
        self.h = h  # Number of heads
        # Make sure d_model is divisible by h
        assert d_model % h == 0, "d_model is not divisible by h"
        self.d_k = d_model // h  # Dimension of the vector seen by each head
        self.w_q = nn.Linear(d_model, d_model, bias=False)  # Wq
        self.w_k = nn.Linear(d_model, d_model, bias=False)  # Wk
        self.w_v = nn.Linear(d_model, d_model, bias=False)  # Wv
        self.w_o = nn.Linear(d_model, d_model, bias=False)  # Wo
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        d_k = query.shape[-1]
        # Scaled dot-product attention from the paper
        # (batch, h, seq_len, d_k) --> (batch, h, seq_len, seq_len)
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Write a very low value (approximating -inf) to the positions where mask == 0;
            # -65504 is the most negative normal float16 value, so this also works under fp16.
            # The mask must broadcast against (batch, h, seq_len, seq_len).
            attention_scores.masked_fill_(mask == 0, -65504)
        attention_scores = attention_scores.softmax(dim=-1)  # (batch, h, seq_len, seq_len)
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        # (batch, h, seq_len, seq_len) --> (batch, h, seq_len, d_k)
        # Also return the attention scores, which can be used for visualization
        return (attention_scores @ value), attention_scores

    def forward(self, q, k, v, mask):
        query = self.w_q(q)  # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        key = self.w_k(k)    # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        value = self.w_v(v)  # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        # (batch, seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)
        # Calculate attention
        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
        # Combine all the heads together
        # (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
        # Multiply by Wo
        # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        return self.w_o(x)

class ResidualConnection(nn.Module):
    def __init__(self, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization()

    def forward(self, x, sublayer):
        # Pre-norm residual connection: x + dropout(sublayer(norm(x)))
        return x + self.dropout(sublayer(self.norm(x)))

class EncoderBlock(nn.Module):
    def __init__(
        self,
        self_attention_block: MultiHeadAttentionBlock,
        feed_forward_block: FFBlock,
        dropout: float,
    ) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(dropout) for _ in range(2)]
        )

    def forward(self, x, src_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x

class Encoder(nn.Module):
    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

class Classifier(nn.Module):
    def __init__(self, src_len: int, d_model: int) -> None:
        super().__init__()
        self.src_len = src_len
        self.d_model = d_model
        # Project each position to 128 features, flatten across the sequence,
        # then reduce to a single output
        self.layers = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.LayerNorm(128),
            nn.Flatten(),                   # (batch, src_len, 128) --> (batch, src_len * 128)
            nn.Linear(src_len * 128, 128),
            nn.ReLU(),
            nn.Linear(128, 16),
            nn.Linear(16, 1),
        )

    def forward(self, x):
        # (batch, src_len, d_model) --> (batch, 1)
        return self.layers(x)

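# Note: the classifier ends in a single raw output with no final activation, so it is
# presumably intended as a logit (e.g. for nn.BCEWithLogitsLoss in binary classification);
# the training code is not shown here.
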
class Transformer(nn.Module):
    def __init__(
        self,
        encoder: Encoder,
        src_embed: InputEmbeddings,
        src_pos: PositionalEncoding,
        classifier: Classifier,
        lang,
    ) -> None:
        super().__init__()
        self.encoder = encoder
        self.src_embed = src_embed
        self.src_pos = src_pos
        self.classifier = classifier
        self.lang = lang

    def encode(self, src, src_mask):
        # (batch, seq_len) --> (batch, seq_len, d_model)
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)

    def classify(self, enc_out):
        # (batch, seq_len, d_model) --> (batch, 1)
        return self.classifier(enc_out)

    def save_train(self, file_path, optimizer):
        # Save everything needed to resume training
        checkpoint = {
            "model_state_dict": self.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "lang": self.lang,
        }
        torch.save(checkpoint, file_path)

    def load_checkpoint_train(self, file_path, optimizer):
        # Note: tensors are mapped to CUDA here; use map_location="cpu" on a CPU-only machine
        checkpoint = torch.load(file_path, map_location="cuda")
        self.load_state_dict(checkpoint["model_state_dict"])
        self.lang = checkpoint["lang"]
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    def load_checkpoint(self, file_path):
        checkpoint = torch.load(file_path)
        self.load_state_dict(checkpoint["model_state_dict"])
        self.lang = checkpoint["lang"]

def build_transformer(
    src_vocab: int,
    src_len: int,
    lang,
    d_model: int = 512,
    N: int = 6,
    h: int = 8,
    dropout: float = 0.1,
    d_ff: int = 2048,
) -> Transformer:
    # Create the embedding layer
    src_embed = InputEmbeddings(d_model, src_vocab)
    # Create the positional encoding layer
    src_pos = PositionalEncoding(d_model, src_len, dropout)
    # Create the encoder blocks
    encoder_blocks = []
    for _ in range(N):
        self_attn_block = MultiHeadAttentionBlock(d_model, h, dropout)
        ff_block = FFBlock(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(self_attn_block, ff_block, dropout)
        encoder_blocks.append(encoder_block)
    # Create the encoder
    encoder = Encoder(nn.ModuleList(encoder_blocks))
    # Create the classifier head
    classifier = Classifier(src_len, d_model)
    # Assemble the transformer
    transformer = Transformer(encoder, src_embed, src_pos, classifier, lang)
    # Initialize the parameters with Xavier uniform
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return transformer
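
# Minimal smoke-test sketch (not part of the original training pipeline): the vocabulary
# size, sequence length, batch size, and all-ones padding mask below are illustrative
# assumptions, chosen only to exercise the forward pass end to end.
if __name__ == "__main__":
    vocab_size, seq_len, batch_size = 1000, 32, 4
    model = build_transformer(vocab_size, seq_len, lang=None)
    tokens = torch.randint(0, vocab_size, (batch_size, seq_len))        # (batch, seq_len)
    # Padding mask that broadcasts to (batch, h, seq_len, seq_len); here nothing is masked
    src_mask = torch.ones(batch_size, 1, 1, seq_len, dtype=torch.bool)
    enc_out = model.encode(tokens, src_mask)                            # (batch, seq_len, d_model)
    logits = model.classify(enc_out)                                    # (batch, 1)
    print(enc_out.shape, logits.shape)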