{
 "cells": [
  {
   "cell_type": "raw",
   "id": "b77c4818-9d92-4581-a2fb-2b38da776a8b",
   "metadata": {},
   "source": [
    "# python >= 3.8\n",
    "import sys\n",
    "!{sys.executable} -m pip install langchain gradio tiktoken unstructured"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "80b2033d-b985-440d-a01f-e7f8547a6801",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "import logging\n",
    "from typing import Any\n",
    "\n",
    "import gradio as gr\n",
    "import requests\n",
    "import tiktoken\n",
    "from langchain.document_loaders import UnstructuredPDFLoader\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "720927dd-848e-47ff-a847-4e1097768561",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Token counter used to size OpenAI requests; `content` feeds the demo input of the UI cell.\n",
    "turbo_encoding = tiktoken.encoding_for_model(\"gpt-3.5-turbo\")\n",
    "with open(\"sample.tex\", \"r\", encoding=\"utf-8\") as f:\n",
    "    content = f.read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "13bc48fd-a048-4635-99c4-79ddb62d5c4d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class LatexTextSplitter(RecursiveCharacterTextSplitter):\n",
    "    \"\"\"Attempts to split the text along Latex-formatted layout elements.\"\"\"\n",
    "\n",
    "    def __init__(self, **kwargs: Any):\n",
    "        \"\"\"Initialize a LatexTextSplitter.\n",
    "\n",
    "        Separators are tried in order: LaTeX sectioning commands, the start of\n",
    "        any environment, spaces, and finally single characters. All keyword\n",
    "        arguments are forwarded to RecursiveCharacterTextSplitter.\n",
    "        \"\"\"\n",
    "        separators = [\n",
    "            # First, try to split along LaTeX sectioning commands.\n",
    "            # Raw strings keep the backslash literal: in a plain string \"\\b\" is a\n",
    "            # backspace character and \"\\c\"/\"\\s\" are invalid escape sequences.\n",
    "            r\"\\chapter{\",\n",
    "            r\"\\section{\",\n",
    "            r\"\\subsection{\",\n",
    "            r\"\\subsubsection{\",\n",
    "            # Now split at the start of any environment.\n",
    "            r\"\\begin{\",\n",
    "            # Finally, fall back to words and then to single characters.\n",
    "            \" \",\n",
    "            \"\",\n",
    "        ]\n",
    "        # NOTE(review): if the installed langchain treats separators as regular\n",
    "        # expressions without escaping them, wrap the LaTeX ones in re.escape().\n",
    "        super().__init__(separators=separators, **kwargs)\n",
    "\n",
    "\n",
    "def _strip_code_fence(text: str) -> str:\n",
    "    \"\"\"Return ``text`` without a surrounding Markdown code fence such as ```json ... ```.\"\"\"\n",
    "    stripped = text.strip()\n",
    "    if stripped.startswith(\"```\"):\n",
    "        # Drop the opening fence line, which may carry a language tag like \"json\".\n",
    "        _, _, stripped = stripped.partition(\"\\n\")\n",
    "        stripped = stripped.rstrip()\n",
    "        if stripped.endswith(\"```\"):\n",
    "            stripped = stripped[:-3]\n",
    "    return stripped.strip()\n",
    "\n",
    "\n",
    "def json_validator(text: str, openai_key: str, retry: int = 3):\n",
    "    \"\"\"Parse model output as JSON, asking OpenAI to repair it when parsing fails.\n",
    "\n",
    "    Args:\n",
    "        text: Raw model output that should contain JSON, optionally wrapped in a\n",
    "            Markdown code fence.\n",
    "        openai_key: OpenAI API key used for the repair requests.\n",
    "        retry: Maximum number of repair rounds, and of HTTP attempts per round.\n",
    "\n",
    "    Returns:\n",
    "        The parsed JSON value on success; the ``error`` field of the last\n",
    "        response body if a repair request never returns HTTP 200; otherwise\n",
    "        the last (still unparseable) text.\n",
    "    \"\"\"\n",
    "    headers = {\n",
    "        \"Content-Type\": \"application/json\",\n",
    "        \"Authorization\": f\"Bearer {openai_key}\"\n",
    "    }\n",
    "    # retry + 1 parse attempts, so the output of the last repair is parsed too.\n",
    "    for attempt in range(retry + 1):\n",
    "        try:\n",
    "            return json.loads(_strip_code_fence(text))\n",
    "        except json.JSONDecodeError:\n",
    "            if attempt >= retry:\n",
    "                break\n",
    "\n",
    "        prompt = f\"Modify the following into a valid json format:\\n{text}\"\n",
    "        # Completion budget: 4097 tokens (presumably the model's context size)\n",
    "        # minus the prompt and a 64-token margin.\n",
    "        # NOTE(review): tokens are counted with the gpt-3.5-turbo encoding, which\n",
    "        # only approximates the tokenizer of text-davinci-003.\n",
    "        max_tokens = 4097 - len(turbo_encoding.encode(prompt)) - 64\n",
    "        if max_tokens <= 0:\n",
    "            # Too long to be repaired within a single completion request.\n",
    "            break\n",
    "\n",
    "        data = {\n",
    "            \"model\": \"text-davinci-003\",\n",
    "            \"prompt\": prompt,\n",
    "            \"max_tokens\": max_tokens\n",
    "        }\n",
    "        for _ in range(retry):\n",
    "            response = requests.post(\n",
    "                'https://api.openai.com/v1/completions',\n",
    "                json=data,\n",
    "                headers=headers,\n",
    "                timeout=300\n",
    "            )\n",
    "            if response.status_code == 200:\n",
    "                break\n",
    "            logging.warning(f'fetch openai chat retry: {response.text}')\n",
    "        else:\n",
    "            # No repair request succeeded: surface the API error to the caller.\n",
    "            return response.json()['error']\n",
    "        text = response.json()['choices'][0]['text']\n",
    "\n",
    "    return text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a34bf526-fb0c-4b5e-8c3d-bab7a13e5fe7",
   "metadata": {
    "tags": 
[]
   },
   "outputs": [],
   "source": [
    "def analyze(latex_whole_document: str, openai_key: str, progress):\n",
    "    \"\"\"Ask the chat model for improvement suggestions on a (LaTeX) paper.\n",
    "\n",
    "    The document is split with LatexTextSplitter and every chunk is sent to the\n",
    "    model separately; the suggestions of all chunks are concatenated.\n",
    "\n",
    "    Args:\n",
    "        latex_whole_document: Full text of the paper (LaTeX source, or text\n",
    "            extracted from a PDF by read_file).\n",
    "        openai_key: OpenAI API key.\n",
    "        progress: Gradio progress tracker (``gr.Progress``) used by the UI.\n",
    "\n",
    "    Returns:\n",
    "        A list of suggestion dicts. The model is asked for the keys \"title\",\n",
    "        \"thought\", \"action\", \"original\" and \"improved\"; only dict entries\n",
    "        with a non-empty \"title\" are kept.\n",
    "\n",
    "    Raises:\n",
    "        gr.Error: If the document is empty, or if the reply for a chunk is not\n",
    "            a JSON list (e.g. fetch_chat returned an OpenAI error payload).\n",
    "    \"\"\"\n",
    "    logging.info(\"start analysis\")\n",
    "\n",
    "    if not (latex_whole_document or \"\").strip():\n",
    "        raise gr.Error(\"Please provide a document to analyze !\")\n",
    "\n",
    "    # Plain (non-f) string: the braces reach the model exactly as written.\n",
    "    output_format = \"\"\"\n",
    "    ```json\n",
    "    [\n",
    "        // Potential point for improvement 1\n",
    "        {\n",
    "            \"title\": string // What this modification is about\n",
    "            \"thought\": string // The reason why this should be improved\n",
    "            \"action\": string // How to make the improvement\n",
    "            \"original\": string // The original latex snippet that can be improved\n",
    "            \"improved\": string // The improved latex snippet which addresses your point\n",
    "        },\n",
    "        ...\n",
    "    ]\n",
    "    ```\n",
    "    \"\"\"\n",
    "\n",
    "    chunk_size = 1000\n",
    "    latex_splitter = LatexTextSplitter(\n",
    "        chunk_size=min(chunk_size, len(latex_whole_document)),\n",
    "        chunk_overlap=0,\n",
    "    )\n",
    "    docs = latex_splitter.create_documents([latex_whole_document])\n",
    "\n",
    "    progress(0.05)\n",
    "    ideas = []\n",
    "    for doc in progress.tqdm(docs):\n",
    "        prompt = f\"\"\"\n",
    "        I'm a computer science student.\n",
    "        You are my editor.\n",
    "        Your goal is to improve the quality of my paper as best you can.\n",
    "\n",
    "\n",
    "        ```\n",
    "        {doc.page_content}\n",
    "        ```\n",
    "        The above is a segment of my research paper. If the end of the segment is not complete, just ignore it.\n",
    "        Point out the parts that can be improved.\n",
    "        Focus on grammar, writing, content, section structure.\n",
    "        Ignore comments and those that are outside the document environment.\n",
    "        List out all the points with a latex snippet which is the improved version addressing your point.\n",
    "        Each paragraph should be addressed only once.\n",
    "        Output the response in the following valid json format:\n",
    "        {output_format}\n",
    "\n",
    "        \"\"\"\n",
    "\n",
    "        idea = fetch_chat(prompt, openai_key)\n",
    "        if not isinstance(idea, list):\n",
    "            raise gr.Error(str(idea))\n",
    "        # Collect the suggestions of every chunk (no early exit after the first\n",
    "        # one) and skip entries the UI cannot show, such as an empty {}.\n",
    "        ideas += [i for i in idea if isinstance(i, dict) and i.get(\"title\")]\n",
    "\n",
    "    logging.info('complete analysis')\n",
    "    return ideas\n",
    "\n",
    "\n",
    "def fetch_chat(prompt: str, openai_key: str, retry: int = 3):\n",
    "    \"\"\"Send ``prompt`` to the OpenAI chat completions API and parse the reply.\n",
    "\n",
    "    Args:\n",
    "        prompt: User message sent to the model.\n",
    "        openai_key: OpenAI API key.\n",
    "        retry: Maximum number of HTTP attempts (must be at least 1).\n",
    "\n",
    "    Returns:\n",
    "        Whatever json_validator returns for the model's reply (normally the\n",
    "        parsed JSON), or the ``error`` field of the last response body when no\n",
    "        attempt returned HTTP 200.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``retry`` is smaller than 1.\n",
    "    \"\"\"\n",
    "    if retry < 1:\n",
    "        raise ValueError(\"retry must be at least 1\")\n",
    "\n",
    "    # Named `payload` so the `json` module is not shadowed inside this function.\n",
    "    payload = {\n",
    "        \"model\": \"gpt-3.5-turbo-16k\",\n",
    "        \"messages\": [{\"role\": \"user\", \"content\": prompt}]\n",
    "    }\n",
    "    headers = {\n",
    "        \"Content-Type\": \"application/json\",\n",
    "        \"Authorization\": f\"Bearer {openai_key}\"\n",
    "    }\n",
    "    for _ in range(retry):\n",
    "        response = requests.post(\n",
    "            'https://api.openai.com/v1/chat/completions',\n",
    "            json=payload,\n",
    "            headers=headers,\n",
    "            timeout=300\n",
    "        )\n",
    "        if response.status_code != 200:\n",
    "            logging.warning(f'fetch openai chat retry: {response.text}')\n",
    "            continue\n",
    "        result = response.json()['choices'][0]['message']['content']\n",
    "        return json_validator(result, openai_key)\n",
    "\n",
    "    return response.json()[\"error\"]\n",
    "\n",
    "\n",
    "def read_file(f: Any) -> str:\n",
    "    \"\"\"Return the text of an uploaded .tex or .pdf file for the input textbox.\n",
    "\n",
    "    Args:\n",
    "        f: The upload passed by ``gr.File`` (an object with a ``.name`` path),\n",
    "            a plain path string, or None when the upload is cleared.\n",
    "\n",
    "    Returns:\n",
    "        The file's text; \"\" when ``f`` is None; a notice for other file types.\n",
    "    \"\"\"\n",
    "    if f is None:\n",
    "        return \"\"\n",
    "    path = getattr(f, \"name\", f)\n",
    "    lowered = path.lower()\n",
    "    if lowered.endswith('.pdf'):\n",
    "        pages = UnstructuredPDFLoader(path).load_and_split()\n",
    "        return \"\\n\".join(p.page_content for p in pages)\n",
    "    if lowered.endswith('.tex'):\n",
    "        # Separate name so the argument `f` is not shadowed by the file handle.\n",
    "        with open(path, \"r\", encoding=\"utf-8\") as tex_file:\n",
    "            return tex_file.read()\n",
    "    return \"Only support .tex & .pdf\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": 
"cec63e87-9741-4596-a3f1-a901830e3771",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "idea_list = []  # latest analysis result, kept for inspection from the notebook\n",
    "max_ideas = 20\n",
    "\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "\n",
    "    def generate(txt: str, openai_key: str, progress=gr.Progress()):\n",
    "        \"\"\"Analyze ``txt`` and show up to ``max_ideas`` suggestion buttons.\n",
    "\n",
    "        Returns Gradio updates for, in order: the \"Suggestions\" header, the\n",
    "        Analyze button, the four detail textboxes, the ``max_ideas`` suggestion\n",
    "        buttons and, last, the suggestion list stored in the session state.\n",
    "        \"\"\"\n",
    "        global idea_list\n",
    "\n",
    "        if not openai_key:\n",
    "            raise gr.Error(\"Please provide openai key !\")\n",
    "\n",
    "        try:\n",
    "            idea_list = analyze(txt, openai_key, progress)\n",
    "\n",
    "            idea_buttons = [\n",
    "                gr.Button.update(visible=True, value=idea.get(\"title\", \"\"))\n",
    "                for idea in idea_list[:max_ideas]\n",
    "            ]\n",
    "            idea_buttons += [\n",
    "                gr.Button.update(visible=False)\n",
    "            ] * (max_ideas - len(idea_buttons))\n",
    "\n",
    "            idea_details = [\n",
    "                gr.Textbox.update(value=\"\", label=\"thought\", visible=True),\n",
    "                gr.Textbox.update(value=\"\", label=\"action\", visible=True),\n",
    "                gr.Textbox.update(value=\"\", label=\"original\", visible=True, max_lines=5, lines=5),\n",
    "                gr.Textbox.update(value=\"\", label=\"improved\", visible=True, max_lines=5, lines=5)\n",
    "            ]\n",
    "\n",
    "            return [\n",
    "                gr.Textbox.update(\"Suggestions\", interactive=False, show_label=False),\n",
    "                gr.Button.update(visible=True, value=\"Analyze\")\n",
    "            ] + idea_details + idea_buttons + [idea_list]\n",
    "        except gr.Error:\n",
    "            # Already user-facing; wrapping it again would stringify the message twice.\n",
    "            raise\n",
    "        except Exception as e:\n",
    "            raise gr.Error(str(e))\n",
    "\n",
    "    def select(name: str, ideas: list = None):\n",
    "        \"\"\"Return detail-textbox updates for the suggestion titled ``name``.\n",
    "\n",
    "        ``ideas`` is the session's suggestion list; when omitted, the\n",
    "        module-level ``idea_list`` is searched instead.\n",
    "        \"\"\"\n",
    "        for idea in (idea_list if ideas is None else ideas):\n",
    "            if idea.get(\"title\") == name:\n",
    "                return [\n",
    "                    gr.Textbox.update(value=idea.get(\"thought\", \"\"), label=\"thought\", visible=True),\n",
    "                    gr.Textbox.update(value=idea.get(\"action\", \"\"), label=\"action\", visible=True),\n",
    "                    gr.Textbox.update(value=idea.get(\"original\", \"\"), label=\"original\", visible=True, max_lines=5, lines=5),\n",
    "                    gr.Textbox.update(value=idea.get(\"improved\", \"\"), label=\"improved\", visible=True, max_lines=5, lines=5)\n",
    "                ]\n",
    "        # No matching suggestion: leave the four detail textboxes unchanged.\n",
    "        return [gr.Textbox.update() for _ in range(4)]\n",
    "\n",
    "    title = gr.Button(\"PaperGPT\", interactive=False).style(size=10)\n",
    "    # Mask the key: the app is also served on a public share link.\n",
    "    key = gr.Textbox(label=\"openai_key\", type=\"password\")\n",
    "    with gr.Row().style(equal_height=True):\n",
    "        with gr.Column(scale=0.95):\n",
    "            # NOTE(review): preloads an excerpt of sample.tex as demo input; the\n",
    "            # offset looks arbitrary - confirm what it is meant to skip.\n",
    "            txt_in = gr.Textbox(label=\"Input\", lines=11, max_lines=11, value=content[2048+2048+256-45:])\n",
    "        with gr.Column(scale=0.05):\n",
    "            upload = gr.File(file_count=\"single\", file_types=[\".tex\", \".pdf\"])\n",
    "            btn = gr.Button(\"Analyze\")\n",
    "    upload.change(read_file, inputs=upload, outputs=txt_in)\n",
    "\n",
    "    # Per-session suggestion list; the global `idea_list` alone would be shared\n",
    "    # by every visitor of the public link.\n",
    "    ideas_state = gr.State([])\n",
    "\n",
    "    textboxes = []  # the suggestion buttons, one per displayable idea\n",
    "    sug = gr.Textbox(\"Suggestions\", interactive=False, show_label=False).style(text_align=\"center\")\n",
    "    with gr.Row():\n",
    "        with gr.Column(scale=0.4):\n",
    "            for _ in range(max_ideas):\n",
    "                textboxes.append(gr.Button(\"\", visible=False))\n",
    "        with gr.Column(scale=0.6):\n",
    "            thought = gr.Textbox(label=\"thought\", visible=False, interactive=False)\n",
    "            action = gr.Textbox(label=\"action\", visible=False, interactive=False)\n",
    "            original = gr.Textbox(label=\"original\", visible=False, max_lines=5, lines=5, interactive=False)\n",
    "            improved = gr.Textbox(label=\"improved\", visible=False, max_lines=5, lines=5, interactive=False)\n",
    "\n",
    "    btn.click(\n",
    "        generate,\n",
    "        inputs=[txt_in, key],\n",
    "        outputs=[sug, btn, thought, action, original, improved] + textboxes + [ideas_state],\n",
    "    )\n",
    "    for idea_button in textboxes:\n",
    "        idea_button.click(select, inputs=[idea_button, ideas_state], outputs=[thought, action, original, improved])\n",
    "\n",
    "# share=True also publishes the app on a public *.gradio.live URL.\n",
    "demo.launch(server_name=\"0.0.0.0\", server_port=7653, share=True, enable_queue=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 
10,
   "id": "8ac8aa92-f7a6-480c-a1b9-2f1c61426846",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Closing server running on port: 7653\n"
     ]
    }
   ],
   "source": [
    "# Stop the Gradio server started by the UI cell (port 7653).\n",
    "demo.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}