123456789101112131415161718192021222324252627282930313233343536373839 |
import functools

import torch
from flask import Flask, request

from tokenizer import CustomTokenizer
from model import build_transformer
from config import *
from utils import tensor_from_sentence
app = Flask(__name__)

# Single shared tokenizer instance. Lower-case name so it no longer shadows
# the imported CustomTokenizer class.
tokenizer = CustomTokenizer()

# Probabilities at or above this value are labelled "Правдивая статья"
# ("truthful article"); anything below is labelled "Ложная статья"
# ("fake article").
_TRUE_ARTICLE_THRESHOLD = 0.7


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the vocabulary and classifier from CHECKPOINT_PATH once per process.

    Returns:
        A ``(lang, model)`` tuple: the object stored under the checkpoint's
        ``"lang"`` key, and the transformer built from it, loaded with the
        checkpoint weights, moved to DEVICE and switched to inference mode.
    """
    # torch.load unpickles the checkpoint, so CHECKPOINT_PATH must be trusted.
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which may refuse
    # to unpickle the "lang" object; confirm against the installed torch version.
    lang = torch.load(CHECKPOINT_PATH)["lang"]
    model = build_transformer(lang.n_words, INP_SEQ_LEN, lang).to(DEVICE)
    model.load_checkpoint(CHECKPOINT_PATH)
    # Disable training-only behaviour (e.g. dropout) so predictions are
    # deterministic. Assumes build_transformer returns a torch.nn.Module.
    model.eval()
    return lang, model


@app.route("/", methods=["POST"])
def classify_article():
    """Classify the article text posted as JSON.

    Expects a JSON object body with a string ``"Text"`` field. The text is
    tokenized with ``tokenize_rus``, converted to a tensor, encoded and
    classified. The sigmoid of the classifier output is reported as the
    probability.

    Returns:
        ``({"Probability": float, "Type": str}, 200)`` on success, where
        "Type" is "Правдивая статья" when the probability is at least
        ``_TRUE_ARTICLE_THRESHOLD`` and "Ложная статья" otherwise.
        ``({"error": str}, 400)`` when the body is not a JSON object with a
        string "Text" field.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or not isinstance(payload.get("Text"), str):
        return (
            {"error": "Request body must be a JSON object with a string 'Text' field"},
            400,
        )
    text = payload["Text"]

    lang, model = _load_model()
    tokens = tokenizer.tokenize_rus(text)
    text_tensor = tensor_from_sentence(lang, tokens, INP_SEQ_LEN).to(DEVICE)
    # Padding mask built from the un-batched tensor, then given three leading
    # singleton dimensions (presumably the layout model.encode expects).
    text_mask = (
        (text_tensor != lang.word2index["<pad>"])
        .unsqueeze(0)
        .unsqueeze(0)
        .unsqueeze(0)
        .to(DEVICE)
    )
    # Add a leading (presumably batch) dimension.
    text_tensor = text_tensor.unsqueeze(0)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        enc_out = model.encode(text_tensor, text_mask)
        output = model.classify(enc_out)
    probability = torch.sigmoid(output).item()

    label = (
        "Ложная статья"
        if probability < _TRUE_ARTICLE_THRESHOLD
        else "Правдивая статья"
    )
    response = {
        "Probability": probability,
        "Type": label,
    }
    return response, 200
|