Onnx-converter2 / app.py
Sakalti's picture
Update app.py
2eec75b verified
import os
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import HfApi, create_repo
# モデルのONNXエクスポート関数
def convert_to_onnx_and_deploy(model_repo, input_text, hf_token, repo_name):
try:
# Hugging Faceトークンを設定
os.environ["HUGGINGFACE_HUB_TOKEN"] = hf_token
# モデルとトークナイザーの読み込み
tokenizer = AutoTokenizer.from_pretrained(model_repo)
model = AutoModelForCausalLM.from_pretrained(model_repo)
# 入力テキストをトークナイズ
inputs = tokenizer(input_text, return_tensors="pt")
# ONNXファイルの保存
onnx_file = f"{repo_name}.onnx"
torch.onnx.export(
model,
inputs['input_ids'],
onnx_file,
input_names=['input_ids'],
output_names=['output'],
dynamic_axes={'input_ids': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)
# モデルをHugging Face Hubにデプロイ
api = HfApi()
create_repo(repo_name, private=True) # プライベートリポジトリを作成
api.upload_file(onnx_file, repo_id=repo_name) # ONNXファイルをアップロード
return f"ONNXモデルが作成され、リポジトリにデプロイされました: {repo_name}"
except Exception as e:
return str(e)
# Gradioインターフェース
iface = gr.Interface(
fn=convert_to_onnx_and_deploy,
inputs=[
gr.Textbox(label="モデルリポジトリ(例: rinna/japanese-gpt2-medium)"),
gr.Textbox(label="入力テキスト"),
gr.Textbox(label="Hugging Faceトークン", type="password"), # パスワード入力タイプ
gr.Textbox(label="デプロイ先リポジトリ名") # デプロイ先のリポジトリ名
],
outputs="text",
title="ONNX変換とモデルデプロイ機能",
description="指定したHugging FaceのモデルリポジトリをONNX形式に変換し、デプロイします。"
)
# 使用するポート番号を指定してインターフェースを起動
iface.launch(server_port=7865) # 7865ポートを指定