model.py

import torch
import torch.nn as nn
import math


class InputEmbeddings(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        super(InputEmbeddings, self).__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # Scale the embeddings by sqrt(d_model), as in "Attention Is All You Need"
        x = self.embedding(x) * math.sqrt(self.d_model)
        return x
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        # Standard sinusoidal frequencies: 1 / 10000^(2i / d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, seq_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # The positional encodings are fixed, so they are added without tracking gradients
        x = x + self.pe[:, :x.shape[1], :].requires_grad_(False)
        x = self.dropout(x)
        return x
class LayerNormalization(nn.Module):
    def __init__(self, features: int, eps: float = 1e-6) -> None:
        super(LayerNormalization, self).__init__()
        # Normalize over the feature (embedding) dimension instead of a hard-coded 512
        self.layerNorm = nn.LayerNorm(normalized_shape=features, eps=eps)

    def forward(self, x):
        x = self.layerNorm(x)
        return x
class FFBlock(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super(FFBlock, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(d_model, d_ff),   # W1 + B1
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),   # W2 + B2
        )

    def forward(self, x):
        x = self.layers(x)
        return x
class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model  # Embedding vector size
        self.h = h              # Number of heads
        # Make sure d_model is divisible by h
        assert d_model % h == 0, "d_model is not divisible by h"
        self.d_k = d_model // h  # Dimension of the vector seen by each head
        self.w_q = nn.Linear(d_model, d_model, bias=False)  # Wq
        self.w_k = nn.Linear(d_model, d_model, bias=False)  # Wk
        self.w_v = nn.Linear(d_model, d_model, bias=False)  # Wv
        self.w_o = nn.Linear(d_model, d_model, bias=False)  # Wo
        self.dropout = nn.Dropout(dropout)
    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        d_k = query.shape[-1]
        # Just apply the formula from the paper
        # (batch, h, seq_len, d_k) --> (batch, h, seq_len, seq_len)
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Write a very low value (indicating -inf) to the positions where mask == 0
            attention_scores.masked_fill_(mask == 0, -65504)
        attention_scores = attention_scores.softmax(dim=-1)  # (batch, h, seq_len, seq_len)
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        # (batch, h, seq_len, seq_len) --> (batch, h, seq_len, d_k)
        # Also return the attention scores, which can be used for visualization
        return (attention_scores @ value), attention_scores
    def forward(self, q, k, v, mask):
        query = self.w_q(q)  # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        key = self.w_k(k)    # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        value = self.w_v(v)  # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        # (batch, seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)
        # Calculate attention
        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
        # Combine all the heads together
        # (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
        # Multiply by Wo
        # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        return self.w_o(x)
class ResidualConnection(nn.Module):
    def __init__(self, features: int, dropout: float) -> None:
        super(ResidualConnection, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        # Pre-norm residual connection: x + Dropout(Sublayer(LayerNorm(x)))
        x = x + self.dropout(sublayer(self.norm(x)))
        return x
class EncoderBlock(nn.Module):
    def __init__(
        self,
        features: int,
        self_attention_block: MultiHeadAttentionBlock,
        feed_forward_block: FFBlock,
        dropout: float
    ):
        super(EncoderBlock, self).__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList(
            [
                ResidualConnection(features, dropout)
                for _ in range(2)
            ]
        )

    def forward(self, x, src_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x
class Encoder(nn.Module):
    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super(Encoder, self).__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        x = self.norm(x)
        return x
class Classifier(nn.Module):
    def __init__(self, src_len, d_model):
        super(Classifier, self).__init__()
        self.src_len = src_len
        self.d_model = d_model
        # (batch, src_len, d_model) --> (batch, src_len, 128) --> flatten --> (batch, src_len * 128) --> (batch, 1)
        # Note: Flatten followed by Linear(src_len * 128, 128) requires inputs with exactly src_len positions.
        self.layers = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.LayerNorm(128),
            nn.Flatten(),
            nn.Linear(src_len * 128, 128),
            nn.ReLU(),
            nn.Linear(128, 16),
            nn.Linear(16, 1),
        )

    def forward(self, x):
        x = self.layers(x)
        return x
class Transformer(nn.Module):
    def __init__(
        self,
        encoder: Encoder,
        src_embed: InputEmbeddings,
        src_pos: PositionalEncoding,
        classifier: Classifier,
        lang
    ) -> None:
        super(Transformer, self).__init__()
        self.encoder = encoder
        self.src_embed = src_embed
        self.src_pos = src_pos
        self.classifier = classifier
        self.lang = lang

    def encode(self, src, src_mask):
        # (batch, seq_len, d_model)
        src = self.src_embed(src)
        src = self.src_pos(src)
        src = self.encoder(src, src_mask)
        return src

    def classify(self, enc_out):
        return self.classifier(enc_out)
    def save_train(self, file_path, optimizer):
        # Save model and optimizer state along with the language/tokenizer object
        checkpoint = {
            "model_state_dict": self.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "lang": self.lang
        }
        torch.save(checkpoint, file_path)

    def load_checkpoint_train(self, file_path, optimizer):
        checkpoint = torch.load(file_path, map_location="cuda")
        self.load_state_dict(checkpoint["model_state_dict"])
        self.lang = checkpoint["lang"]
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    def load_checkpoint(self, file_path):
        checkpoint = torch.load(file_path)
        self.load_state_dict(checkpoint["model_state_dict"])
        self.lang = checkpoint["lang"]
def build_transformer(
    src_vocab: int,
    src_len: int,
    lang,
    d_model: int = 512,
    N: int = 6,
    h: int = 8,
    dropout: float = 0.1,
    d_ff: int = 2048) -> Transformer:
    # Create the embedding layer
    src_embed = InputEmbeddings(d_model, src_vocab)
    # Create the positional encoding layer
    src_pos = PositionalEncoding(d_model, src_len, dropout)
    # Create the encoder blocks
    encoder_blocks = []
    for _ in range(N):
        self_attn_block = MultiHeadAttentionBlock(d_model, h, dropout)
        ff_block = FFBlock(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(d_model, self_attn_block, ff_block, dropout)
        encoder_blocks.append(encoder_block)
    # Create the encoder
    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    # Create the classifier head
    classifier = Classifier(src_len, d_model)
    # Create the transformer
    transformer = Transformer(encoder, src_embed, src_pos, classifier, lang)
    # Initialize the parameters with Xavier uniform
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return transformer
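

# Minimal usage sketch: builds the model with placeholder settings and runs a
# forward pass on random token ids, just to show the expected shapes. The
# vocabulary size, sequence length, padding id (assumed to be 0), and the
# `lang=None` argument below are illustrative assumptions, not values taken
# from the rest of this repository.
if __name__ == "__main__":
    src_vocab = 30000   # assumed tokenizer vocabulary size
    src_len = 128       # assumed fixed (padded) source sequence length
    model = build_transformer(src_vocab, src_len, lang=None)

    batch = torch.randint(0, src_vocab, (2, src_len))    # (batch, seq_len) token ids
    src_mask = (batch != 0).unsqueeze(1).unsqueeze(1)    # (batch, 1, 1, seq_len); 0 marks padding
    enc_out = model.encode(batch, src_mask)              # (batch, seq_len, d_model)
    score = model.classify(enc_out)                      # (batch, 1) raw score
    print(score.shape)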