MGTbenchmark / app.py
Evan73's picture
update the app.py
6f96319
# import streamlit as st
# import os
# import json
# import re
# import datasets
# import tiktoken
# import zipfile
# from pathlib import Path
# # 定义 tiktoken 编码器
# encoding = tiktoken.get_encoding("cl100k_base")
# _CITATION = """\
# @InProceedings{huggingface:dataset,
# title = {MGT detection},
# author={Trustworthy AI Lab},
# year={2024}
# }
# """
# _DESCRIPTION = """\
# For detecting machine generated text.
# """
# _HOMEPAGE = ""
# _LICENSE = ""
# # MGTHuman 类
# class MGTHuman(datasets.GeneratorBasedBuilder):
# VERSION = datasets.Version("1.0.0")
# BUILDER_CONFIGS = [
# datasets.BuilderConfig(name="human", version=VERSION, description="This part of human data"),
# datasets.BuilderConfig(name="Moonshot", version=VERSION, description="Data from the Moonshot model"),
# datasets.BuilderConfig(name="gpt35", version=VERSION, description="Data from the gpt-3.5-turbo model"),
# datasets.BuilderConfig(name="Llama3", version=VERSION, description="Data from the Llama3 model"),
# datasets.BuilderConfig(name="Mixtral", version=VERSION, description="Data from the Mixtral model"),
# datasets.BuilderConfig(name="Qwen", version=VERSION, description="Data from the Qwen model"),
# ]
# DEFAULT_CONFIG_NAME = "human"
# def _info(self):
# features = datasets.Features(
# {
# "id": datasets.Value("int32"),
# "text": datasets.Value("string"),
# "file": datasets.Value("string"),
# }
# )
# return datasets.DatasetInfo(
# description=_DESCRIPTION,
# features=features,
# homepage=_HOMEPAGE,
# license=_LICENSE,
# citation=_CITATION,
# )
# def truncate_text(self, text, max_tokens=2048):
# tokens = encoding.encode(text, allowed_special={'<|endoftext|>'})
# if len(tokens) > max_tokens:
# tokens = tokens[:max_tokens]
# truncated_text = encoding.decode(tokens)
# last_period_idx = truncated_text.rfind('。')
# if last_period_idx == -1:
# last_period_idx = truncated_text.rfind('.')
# if last_period_idx != -1:
# truncated_text = truncated_text[:last_period_idx + 1]
# return truncated_text
# else:
# return text
# def get_text_by_index(self, filepath, index, cut_tokens=False, max_tokens=2048):
# count = 0
# with open(filepath, 'r') as f:
# data = json.load(f)
# for row in data:
# if not row["text"].strip():
# continue
# if count == index:
# text = row["text"]
# if cut_tokens:
# text = self.truncate_text(text, max_tokens)
# return text
# count += 1
# return "Index 超出范围,请输入有效的数字。"
# def count_entries(self, filepath):
# """返回文件中的总条数,用于动态生成索引范围"""
# count = 0
# with open(filepath, 'r') as f:
# data = json.load(f)
# for row in data:
# if row["text"].strip():
# count += 1
# return count
# # Streamlit UI
# st.title("MGTHuman Dataset Viewer")
# # 上传包含 JSON 文件的 ZIP 文件
# uploaded_folder = st.file_uploader("上传包含 JSON 文件的 ZIP 文件夹", type=["zip"])
# if uploaded_folder:
# folder_path = Path("temp")
# folder_path.mkdir(exist_ok=True)
# zip_path = folder_path / uploaded_folder.name
# with open(zip_path, "wb") as f:
# f.write(uploaded_folder.getbuffer())
# with zipfile.ZipFile(zip_path, 'r') as zip_ref:
# zip_ref.extractall(folder_path)
# # 递归获取所有 JSON 文件并分类到不同的 domain
# category = {}
# for json_file in folder_path.rglob("*.json"): # 使用 rglob 递归查找所有 JSON 文件
# domain = json_file.stem.split('_task3')[0]
# category.setdefault(domain, []).append(str(json_file))
# # 显示可用的 domain 下拉框
# if category:
# selected_domain = st.selectbox("选择数据种类", options=list(category.keys()))
# # 确定该 domain 的第一个文件路径并获取条目数量
# file_to_display = category[selected_domain][0]
# mgt_human = MGTHuman(name=selected_domain)
# total_entries = mgt_human.count_entries(file_to_display)
# st.write(f"可用的索引范围: 0 到 {total_entries - 1}")
# # 输入序号查看文本
# index_to_view = st.number_input("输入要查看的文本序号", min_value=0, max_value=total_entries - 1, step=1)
# # 添加复选框以选择是否切割文本
# cut_tokens = st.checkbox("是否对文本进行token切割", value=False)
# if st.button("显示文本"):
# text = mgt_human.get_text_by_index(file_to_display, index=index_to_view, cut_tokens=cut_tokens)
# st.write("对应的文本内容为:", text)
# else:
# st.write("未找到任何 JSON 文件,请检查 ZIP 文件结构。")
# # 清理上传文件的临时目录
# if st.button("清除文件"):
# import shutil
# shutil.rmtree("temp")
# st.write("临时文件已清除。")
import streamlit as st
from transformers import pipeline
# Initialize Hugging Face text classifier
@st.cache_resource  # Cache the loaded pipeline so Streamlit reruns don't reload the model
def load_model(model_name: str = "roberta-base-openai-detector"):
    """Load and cache a Hugging Face text-classification pipeline.

    Args:
        model_name: Hub id of the classification model to load. Defaults to
            the RoBERTa-based OpenAI GPT-2 output detector used by this app.

    Returns:
        A ``transformers`` text-classification pipeline ready for inference.
    """
    return pipeline("text-classification", model=model_name)
# --- Streamlit UI: page body ---
st.title("Machine-Generated Text Detector")
st.write("Enter a text snippet, and I will analyze it to determine if it is likely written by a human or generated by a machine.")

# Load the (cached) classifier once per session.
classifier = load_model()

# Input text
input_text = st.text_area("Enter text here:", height=150)

# Button to trigger detection
if st.button("Analyze"):
    if input_text.strip():  # reject empty and whitespace-only input
        # truncation=True guards against inputs longer than the model's
        # 512-token limit, which would otherwise raise at inference time.
        result = classifier(input_text, truncation=True)

        # Extract label and confidence score from the first (only) prediction.
        label = result[0]['label']
        score = result[0]['score'] * 100  # convert to percentage for readability

        # NOTE(review): the roberta-base-openai-detector checkpoint reports
        # its labels as "Real"/"Fake", not "LABEL_0"/"LABEL_1", so the old
        # `label == "LABEL_1"` test never matched and everything displayed as
        # Human-Written. Accept either naming scheme here.
        # TODO: confirm label names against the deployed model's config.
        if label.upper() in ("LABEL_1", "FAKE"):
            st.write("**Result:** This text is likely **Machine-Generated**.")
        else:
            st.write("**Result:** This text is likely **Human-Written**.")

        # Display confidence score
        st.write(f"**Confidence Score:** {score:.2f}%")
    else:
        st.write("Please enter some text for analysis.")