# import streamlit as st
# import os
# import json
# import re
# import datasets
# import tiktoken
# import zipfile
# from pathlib import Path
#
# # Define the tiktoken encoder
# encoding = tiktoken.get_encoding("cl100k_base")
#
# _CITATION = """\
# @InProceedings{huggingface:dataset,
# title = {MGT detection},
# author={Trustworthy AI Lab},
# year={2024}
# }
# """
#
# _DESCRIPTION = """\
# For detecting machine generated text.
# """
#
# _HOMEPAGE = ""
# _LICENSE = ""
#
# # MGTHuman class
# class MGTHuman(datasets.GeneratorBasedBuilder):
#     VERSION = datasets.Version("1.0.0")
#     BUILDER_CONFIGS = [
#         datasets.BuilderConfig(name="human", version=VERSION, description="This part of human data"),
#         datasets.BuilderConfig(name="Moonshot", version=VERSION, description="Data from the Moonshot model"),
#         datasets.BuilderConfig(name="gpt35", version=VERSION, description="Data from the gpt-3.5-turbo model"),
#         datasets.BuilderConfig(name="Llama3", version=VERSION, description="Data from the Llama3 model"),
#         datasets.BuilderConfig(name="Mixtral", version=VERSION, description="Data from the Mixtral model"),
#         datasets.BuilderConfig(name="Qwen", version=VERSION, description="Data from the Qwen model"),
#     ]
#     DEFAULT_CONFIG_NAME = "human"
#
#     def _info(self):
#         features = datasets.Features(
#             {
#                 "id": datasets.Value("int32"),
#                 "text": datasets.Value("string"),
#                 "file": datasets.Value("string"),
#             }
#         )
#         return datasets.DatasetInfo(
#             description=_DESCRIPTION,
#             features=features,
#             homepage=_HOMEPAGE,
#             license=_LICENSE,
#             citation=_CITATION,
#         )
#
#     def truncate_text(self, text, max_tokens=2048):
#         # Truncate to max_tokens, then cut back to the last sentence boundary
#         tokens = encoding.encode(text, allowed_special={'<|endoftext|>'})
#         if len(tokens) > max_tokens:
#             tokens = tokens[:max_tokens]
#             truncated_text = encoding.decode(tokens)
#             last_period_idx = truncated_text.rfind('。')
#             if last_period_idx == -1:
#                 last_period_idx = truncated_text.rfind('.')
#             if last_period_idx != -1:
#                 truncated_text = truncated_text[:last_period_idx + 1]
#             return truncated_text
#         else:
#             return text
#
#     def get_text_by_index(self, filepath, index, cut_tokens=False, max_tokens=2048):
#         count = 0
#         with open(filepath, 'r') as f:
#             data = json.load(f)
#             for row in data:
#                 if not row["text"].strip():
#                     continue
#                 if count == index:
#                     text = row["text"]
#                     if cut_tokens:
#                         text = self.truncate_text(text, max_tokens)
#                     return text
#                 count += 1
#         return "Index out of range; please enter a valid number."
#
#     def count_entries(self, filepath):
#         """Return the total number of entries in the file, used to build the index range dynamically."""
#         count = 0
#         with open(filepath, 'r') as f:
#             data = json.load(f)
#             for row in data:
#                 if row["text"].strip():
#                     count += 1
#         return count
#
# # Streamlit UI
# st.title("MGTHuman Dataset Viewer")
#
# # Upload a ZIP file containing JSON files
# uploaded_folder = st.file_uploader("Upload a ZIP folder containing JSON files", type=["zip"])
# if uploaded_folder:
#     folder_path = Path("temp")
#     folder_path.mkdir(exist_ok=True)
#     zip_path = folder_path / uploaded_folder.name
#     with open(zip_path, "wb") as f:
#         f.write(uploaded_folder.getbuffer())
#     with zipfile.ZipFile(zip_path, 'r') as zip_ref:
#         zip_ref.extractall(folder_path)
#
#     # Recursively collect all JSON files and group them by domain
#     category = {}
#     for json_file in folder_path.rglob("*.json"):  # use rglob to find JSON files at any depth
#         domain = json_file.stem.split('_task3')[0]
#         category.setdefault(domain, []).append(str(json_file))
#
#     # Show a dropdown of the available domains
#     if category:
#         selected_domain = st.selectbox("Select a data category", options=list(category.keys()))
#
#         # Take the first file of this domain and count its entries
#         file_to_display = category[selected_domain][0]
#         mgt_human = MGTHuman(name=selected_domain)
#         total_entries = mgt_human.count_entries(file_to_display)
#         st.write(f"Valid index range: 0 to {total_entries - 1}")
#
#         # Enter an index to view the corresponding text
#         index_to_view = st.number_input("Enter the index of the text to view", min_value=0, max_value=total_entries - 1, step=1)
#
#         # Checkbox to choose whether to truncate the text by tokens
#         cut_tokens = st.checkbox("Truncate the text by tokens", value=False)
#
#         if st.button("Show text"):
#             text = mgt_human.get_text_by_index(file_to_display, index=index_to_view, cut_tokens=cut_tokens)
#             st.write("Text content:", text)
#     else:
#         st.write("No JSON files found; please check the ZIP file structure.")
#
#     # Clean up the temporary upload directory
#     if st.button("Clear files"):
#         import shutil
#         shutil.rmtree("temp")
#         st.write("Temporary files cleared.")
import streamlit as st
from transformers import pipeline

# Initialize the Hugging Face text classifier.
# Cache the model so it is loaded only once across Streamlit reruns.
@st.cache_resource
def load_model():
    # A pre-trained text classification model (GPT-2 output detector);
    # replace with a more suitable model if necessary.
    classifier = pipeline("text-classification", model="roberta-base-openai-detector")
    return classifier
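
# For reference, a text-classification pipeline returns a list with one dict
# per input, each holding a label and a softmax score. The values below are
# illustrative, not real model output:
#     [{'label': 'Real', 'score': 0.9831}]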
st.title("Machine-Generated Text Detector")
st.write("Enter a text snippet, and I will analyze it to determine if it is likely written by a human or generated by a machine.")

# Load the model
classifier = load_model()

# Input text
input_text = st.text_area("Enter text here:", height=150)

# Button to trigger detection
if st.button("Analyze"):
    if input_text:
        # Make prediction; truncate inputs past the model's 512-token limit
        result = classifier(input_text, truncation=True)

        # Extract label and confidence score
        label = result[0]['label']
        score = result[0]['score'] * 100  # Convert to percentage for readability

        # Display result: this detector labels text "Fake" (machine-generated) or "Real" (human-written)
        if label == "Fake":
            st.write("**Result:** This text is likely **Machine-Generated**.")
        else:
            st.write("**Result:** This text is likely **Human-Written**.")

        # Display confidence score
        st.write(f"**Confidence Score:** {score:.2f}%")
    else:
        st.write("Please enter some text for analysis.")