"""Flask service that scores a natural-language statement with the Vera model.

Send a JSON body ``{"statement": "..."}`` to ``/``. The response contains the
raw and temperature-calibrated logit and the sigmoid score of each.
"""
from flask import Flask, render_template, redirect, request, jsonify, make_response
import datetime
import torch
import transformers

device = torch.device('cuda')

MODEL_NAME = 'liujch1998/vera'


class Interactive:
    """Holds the Vera T5 encoder plus its linear scoring head for single-statement inference."""

    def __init__(self):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
        self.model = transformers.T5EncoderModel.from_pretrained(
            MODEL_NAME,
            low_cpu_mem_usage=True,
            device_map='auto',
            torch_dtype='auto',
            offload_folder='offload',
        )
        self.model.D = self.model.shared.embedding_dim
        self.linear = torch.nn.Linear(self.model.D, 1, dtype=self.model.dtype).to(device)
        # The scoring head is read out of reserved rows of the shared embedding
        # matrix:
        # - row 32099 is the weight vector
        # - row 32098 (column 0) is the bias
        # - row 32097 (column 0) is the temperature used for calibration below
        self.linear.weight = torch.nn.Parameter(self.model.shared.weight[32099, :].unsqueeze(0))  # (1, D)
        self.linear.bias = torch.nn.Parameter(self.model.shared.weight[32098, 0].unsqueeze(0))  # (1)
        self.model.eval()
        self.t = self.model.shared.weight[32097, 0].item()

    def run(self, statement):
        """Score a single statement.

        Args:
            statement: the text to score. It is truncated to 128 tokens.

        Returns:
            A dict with a ``%Y%m%d-%H%M%S`` timestamp and the echoed statement.
            It also holds the raw ``logit`` and ``logit_calibrated`` (the logit
            divided by the temperature ``self.t``), plus ``score`` and
            ``score_calibrated``, the sigmoid of each logit. All values are
            plain Python numbers.
        """
        input_ids = self.tokenizer.batch_encode_plus(
            [statement],
            return_tensors='pt',
            padding='longest',
            truncation='longest_first',
            max_length=128,
        ).input_ids.to(device)
        # The head's weights are nn.Parameters, so the head must also run under
        # no_grad; otherwise autograd would record a graph on every request.
        with torch.no_grad():
            output = self.model(input_ids)
            last_hidden_state = output.last_hidden_state.to(device)  # (B=1, L, D)
            # Hidden state at the last sequence position. With a batch of one
            # there is no padding, so this is the final real token.
            hidden = last_hidden_state[0, -1, :]  # (D)
            logit = self.linear(hidden).squeeze(-1)  # ()
            logit_calibrated = logit / self.t
            score = logit.sigmoid()
            score_calibrated = logit_calibrated.sigmoid()
        return {
            'timestamp': datetime.datetime.now().strftime('%Y%m%d-%H%M%S'),
            'statement': statement,
            'logit': logit.item(),
            'logit_calibrated': logit_calibrated.item(),
            'score': score.item(),
            'score_calibrated': score_calibrated.item(),
        }


interactive = Interactive()
app = Flask(__name__)


@app.route('/', methods=['GET', 'POST'])
def main():
    """Score the ``statement`` field of the request's JSON body.

    Returns:
        200 with the dict from ``Interactive.run`` on success.
        400 if the body is not a JSON object with a non-empty string
        ``statement``.
        500 if scoring raises; the traceback is logged server-side.
    """
    print(request)
    # silent=True returns None for missing or invalid JSON instead of raising,
    # so bad input is handled by the explicit check below.
    data = request.get_json(silent=True)
    statement = data.get('statement') if isinstance(data, dict) else None
    # Reject missing or non-string statements here. Otherwise they would reach
    # the tokenizer and be misreported to the client as an internal error.
    if not isinstance(statement, str) or not statement.strip():
        return jsonify({
            'success': False,
            'message': 'Please provide a statement.',
        }), 400

    try:
        result = interactive.run(statement)
    except Exception:
        # Top-level request boundary: record the full traceback for operators,
        # but give the client only a generic message.
        app.logger.exception('Failed to score statement')
        return jsonify({
            'success': False,
            'message': 'Internal error.',
        }), 500

    return jsonify(result)


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8372, threaded=True, ssl_context=('/etc/letsencrypt/live/qa.cs.washington.edu/fullchain.pem', '/etc/letsencrypt/live/qa.cs.washington.edu/privkey.pem'))  # 8372 is when you type Vera on a phone keypad