api.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. from flask import Flask, request
  2. from tokenizer import CustomTokenizer
  3. from model import build_transformer
  4. from config import *
  5. from utils import tensor_from_sentence
  6. import torch
  7. app = Flask(__name__)
  8. CustomTokenizer = CustomTokenizer()
  9. @app.route("/", methods=["POST"])
  10. def classify_article():
  11. if request.method == "POST":
  12. Lang = torch.load(CHECKPOINT_PATH)["lang"]
  13. model = build_transformer(Lang.n_words, INP_SEQ_LEN, Lang).to(DEVICE)
  14. model.load_checkpoint(CHECKPOINT_PATH)
  15. json = request.json
  16. text = json["Text"]
  17. tokens = CustomTokenizer.tokenize_rus(text)
  18. text_tensor = tensor_from_sentence(Lang, tokens, INP_SEQ_LEN).to(DEVICE)
  19. text_mask = (text_tensor != Lang.word2index["<pad>"]).unsqueeze(0).unsqueeze(0).unsqueeze(0).to(DEVICE)
  20. text_tensor = text_tensor.unsqueeze(0)
  21. enc_out = model.encode(text_tensor, text_mask)
  22. output = model.classify(enc_out)
  23. probability = torch.sigmoid(output).item()
  24. response = {
  25. "Probability": probability,
  26. "Type": "Ложная статья" if probability < 0.7 else "Правдивая статья"
  27. }
  28. return response, 200