Spaces:
Running
Running
Merge pull request #5 from CCCBora/pre-defined-references
Browse files- app.py +63 -43
- auto_backgrounds.py +63 -53
- latex_templates/ICLR2022/fig.png +0 -0
- latex_templates/ICLR2022/template.tex +2 -1
- latex_templates/example_references.bib +20 -0
- latex_templates/pre_refs.bib +0 -17
- requirements.txt +0 -0
- section_generator.py +30 -17
- utils/gpt_interaction.py +16 -0
- utils/prompts.py +137 -38
- utils/references.py +124 -124
- utils/tex_processing.py +2 -1
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
import os
|
3 |
import openai
|
4 |
-
from auto_backgrounds import generate_backgrounds,
|
5 |
from utils.file_operations import hash_name
|
6 |
|
7 |
# note: App白屏bug:允许第三方cookie
|
@@ -9,19 +9,21 @@ from utils.file_operations import hash_name
|
|
9 |
# 6. get logs when the procedure is not completed. *
|
10 |
# 7. 自己的文件库; 更多的prompts
|
11 |
# 8. Decide on how to generate the main part of a paper * (Langchain/AutoGPT
|
12 |
-
# 9. Load .bibtex file to generate a pre-defined references list. *
|
13 |
# 1. 把paper改成纯JSON?
|
14 |
# 2. 实现别的功能
|
15 |
# 3. Check API Key GPT-4 Support.
|
16 |
# 8. Re-build some components using `langchain`
|
17 |
-
# - in `references.py`, use PromptTemplates.format -> str
|
18 |
# - in `gpt_interation`, use LLM
|
|
|
19 |
# future:
|
20 |
# 4. add auto_polishing function
|
21 |
# 12. Change link to more appealing color # after the website is built;
|
22 |
# 1. Check if there are any duplicated citations
|
23 |
# 2. Remove potential thebibliography and bibitem in .tex file
|
24 |
|
|
|
|
|
|
|
25 |
openai_key = os.getenv("OPENAI_API_KEY")
|
26 |
access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
|
27 |
secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
|
@@ -43,22 +45,19 @@ else:
|
|
43 |
IS_OPENAI_API_KEY_AVAILABLE = False
|
44 |
|
45 |
|
46 |
-
def clear_inputs(
|
47 |
return "", ""
|
48 |
|
49 |
|
50 |
def wrapped_generator(paper_title, paper_description, openai_api_key=None,
|
51 |
-
|
52 |
-
cache_mode=IS_CACHE_AVAILABLE
|
53 |
# if `cache_mode` is True, then follow the following steps:
|
54 |
# check if "title"+"description" have been generated before
|
55 |
# if so, download from the cloud storage, return it
|
56 |
# if not, generate the result.
|
57 |
-
if
|
58 |
-
|
59 |
-
# generator = generate_backgrounds
|
60 |
-
generator = generate_draft
|
61 |
-
# generator = fake_generator
|
62 |
if openai_api_key is not None:
|
63 |
openai.api_key = openai_api_key
|
64 |
openai.Model.list()
|
@@ -66,9 +65,8 @@ def wrapped_generator(paper_title, paper_description, openai_api_key=None,
|
|
66 |
if cache_mode:
|
67 |
from utils.storage import list_all_files, download_file, upload_file
|
68 |
# check if "title"+"description" have been generated before
|
69 |
-
|
70 |
input_dict = {"title": paper_title, "description": paper_description,
|
71 |
-
"generator": "generate_draft"}
|
72 |
file_name = hash_name(input_dict) + ".zip"
|
73 |
file_list = list_all_files()
|
74 |
# print(f"{file_name} will be generated. Check the file list {file_list}")
|
@@ -79,13 +77,17 @@ def wrapped_generator(paper_title, paper_description, openai_api_key=None,
|
|
79 |
else:
|
80 |
# generate the result.
|
81 |
# output = fake_generate_backgrounds(title, description, openai_key)
|
82 |
-
|
83 |
-
|
|
|
|
|
84 |
upload_file(output)
|
85 |
return output
|
86 |
else:
|
87 |
# output = fake_generate_backgrounds(title, description, openai_key)
|
88 |
-
output =
|
|
|
|
|
89 |
return output
|
90 |
|
91 |
|
@@ -96,21 +98,25 @@ theme = gr.themes.Default(font=gr.themes.GoogleFont("Questrial"))
|
|
96 |
# button_primary_background_fill="#281A39"
|
97 |
# )
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
with gr.Blocks(theme=theme) as demo:
|
100 |
gr.Markdown('''
|
101 |
# Auto-Draft: 文献整理辅助工具
|
102 |
|
103 |
-
本Demo提供对[Auto-Draft](https://github.com/CCCBora/auto-draft)的auto_draft
|
|
|
104 |
|
105 |
***2023-05-03 Update***: 在公开版本中为大家提供了输入OpenAI API Key的地址, 如果有GPT-4的API KEY的话可以在这里体验!
|
106 |
|
107 |
-
在这个Huggingface Organization里也提供一定额度的免费体验: [AUTO-ACADEMIC](https://huggingface.co/
|
108 |
|
109 |
-
如果有更多想法和建议欢迎加入QQ群里交流, 如果我在Space里更新了Key我会第一时间通知大家. 群号: ***249738228***.
|
110 |
-
|
111 |
-
## 用法
|
112 |
-
|
113 |
-
输入想要生成的论文名称(比如Playing Atari with Deep Reinforcement Learning), 点击Submit, 等待大概十分钟, 下载.zip格式的输出,在Overleaf上编译浏览.
|
114 |
''')
|
115 |
|
116 |
with gr.Row():
|
@@ -123,6 +129,8 @@ with gr.Blocks(theme=theme) as demo:
|
|
123 |
|
124 |
# 每个功能做一个tab
|
125 |
with gr.Tab("学术论文"):
|
|
|
|
|
126 |
title = gr.Textbox(value="Playing Atari with Deep Reinforcement Learning", lines=1, max_lines=1,
|
127 |
label="Title", info="论文标题")
|
128 |
|
@@ -130,33 +138,41 @@ with gr.Blocks(theme=theme) as demo:
|
|
130 |
description_pp = gr.Textbox(lines=5, label="Description (Optional)", visible=True,
|
131 |
info="对希望生成的论文的一些描述. 包括这篇论文的创新点, 主要贡献, 等.")
|
132 |
|
133 |
-
interactive = False
|
134 |
-
gr.Markdown('''
|
135 |
-
## 下面的功能我只做了UI, 还没来得及实现功能.
|
136 |
-
''')
|
137 |
with gr.Row():
|
138 |
with gr.Column():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
gr.Markdown('''
|
140 |
-
|
141 |
-
|
142 |
-
通过上传.bib文件来控制GPT-4模型必须参考哪些文献.
|
143 |
''')
|
144 |
bibtex_file = gr.File(label="Upload .bib file", file_types=["text"],
|
145 |
-
interactive=
|
|
|
|
|
|
|
|
|
146 |
with gr.Column():
|
147 |
search_engine = gr.Dropdown(label="Search Engine",
|
148 |
choices=["ArXiv", "Semantic Scholar", "Google Scholar", "None"],
|
149 |
-
value=
|
150 |
-
interactive=
|
151 |
-
info="用于决定GPT-4用什么搜索引擎来搜索文献.
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
|
|
|
|
160 |
|
161 |
with gr.Row():
|
162 |
clear_button_pp = gr.Button("Clear")
|
@@ -195,7 +211,11 @@ with gr.Blocks(theme=theme) as demo:
|
|
195 |
file_output = gr.File(label="Output")
|
196 |
|
197 |
clear_button_pp.click(fn=clear_inputs, inputs=[title, description_pp], outputs=[title, description_pp])
|
198 |
-
submit_button_pp.click(fn=wrapped_generator,
|
|
|
|
|
|
|
|
|
199 |
|
200 |
demo.queue(concurrency_count=1, max_size=5, api_open=False)
|
201 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
import os
|
3 |
import openai
|
4 |
+
from auto_backgrounds import generate_backgrounds, generate_draft
|
5 |
from utils.file_operations import hash_name
|
6 |
|
7 |
# note: App白屏bug:允许第三方cookie
|
|
|
9 |
# 6. get logs when the procedure is not completed. *
|
10 |
# 7. 自己的文件库; 更多的prompts
|
11 |
# 8. Decide on how to generate the main part of a paper * (Langchain/AutoGPT
|
|
|
12 |
# 1. 把paper改成纯JSON?
|
13 |
# 2. 实现别的功能
|
14 |
# 3. Check API Key GPT-4 Support.
|
15 |
# 8. Re-build some components using `langchain`
|
|
|
16 |
# - in `gpt_interation`, use LLM
|
17 |
+
# 5. 从提供的bib文件中 找到cite和citedby的文章, 计算embeddings; 从整个paper list中 根据cos距离进行排序; 选取max_refs的文章
|
18 |
# future:
|
19 |
# 4. add auto_polishing function
|
20 |
# 12. Change link to more appealing color # after the website is built;
|
21 |
# 1. Check if there are any duplicated citations
|
22 |
# 2. Remove potential thebibliography and bibitem in .tex file
|
23 |
|
24 |
+
#######################################################################################################################
|
25 |
+
# Check if openai and cloud storage available
|
26 |
+
#######################################################################################################################
|
27 |
openai_key = os.getenv("OPENAI_API_KEY")
|
28 |
access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
|
29 |
secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
|
|
|
45 |
IS_OPENAI_API_KEY_AVAILABLE = False
|
46 |
|
47 |
|
48 |
+
def clear_inputs(*args):
|
49 |
return "", ""
|
50 |
|
51 |
|
52 |
def wrapped_generator(paper_title, paper_description, openai_api_key=None,
|
53 |
+
paper_template="ICLR2022", tldr=True, max_num_refs=50, selected_sections=None, bib_refs=None, model="gpt-4",
|
54 |
+
cache_mode=IS_CACHE_AVAILABLE):
|
55 |
# if `cache_mode` is True, then follow the following steps:
|
56 |
# check if "title"+"description" have been generated before
|
57 |
# if so, download from the cloud storage, return it
|
58 |
# if not, generate the result.
|
59 |
+
if bib_refs is not None:
|
60 |
+
bib_refs = bib_refs.name
|
|
|
|
|
|
|
61 |
if openai_api_key is not None:
|
62 |
openai.api_key = openai_api_key
|
63 |
openai.Model.list()
|
|
|
65 |
if cache_mode:
|
66 |
from utils.storage import list_all_files, download_file, upload_file
|
67 |
# check if "title"+"description" have been generated before
|
|
|
68 |
input_dict = {"title": paper_title, "description": paper_description,
|
69 |
+
"generator": "generate_draft"}
|
70 |
file_name = hash_name(input_dict) + ".zip"
|
71 |
file_list = list_all_files()
|
72 |
# print(f"{file_name} will be generated. Check the file list {file_list}")
|
|
|
77 |
else:
|
78 |
# generate the result.
|
79 |
# output = fake_generate_backgrounds(title, description, openai_key)
|
80 |
+
output = generate_draft(paper_title, paper_description, template=paper_template,
|
81 |
+
tldr=tldr, max_num_refs=max_num_refs,
|
82 |
+
sections=selected_sections, bib_refs=bib_refs, model=model)
|
83 |
+
# output = generate_draft(paper_title, paper_description, template, "gpt-4")
|
84 |
upload_file(output)
|
85 |
return output
|
86 |
else:
|
87 |
# output = fake_generate_backgrounds(title, description, openai_key)
|
88 |
+
output = generate_draft(paper_title, paper_description, template=paper_template,
|
89 |
+
tldr=tldr, max_num_refs=max_num_refs,
|
90 |
+
sections=selected_sections, bib_refs=bib_refs, model=model)
|
91 |
return output
|
92 |
|
93 |
|
|
|
98 |
# button_primary_background_fill="#281A39"
|
99 |
# )
|
100 |
|
101 |
+
ACADEMIC_PAPER = """## 一键生成论文初稿
|
102 |
+
|
103 |
+
1. 在Title文本框中输入想要生成的论文名称(比如Playing Atari with Deep Reinforcement Learning).
|
104 |
+
2. 点击Submit. 等待大概十分钟.
|
105 |
+
3. 在右侧下载.zip格式的输出,在Overleaf上编译浏览.
|
106 |
+
"""
|
107 |
+
|
108 |
with gr.Blocks(theme=theme) as demo:
|
109 |
gr.Markdown('''
|
110 |
# Auto-Draft: 文献整理辅助工具
|
111 |
|
112 |
+
本Demo提供对[Auto-Draft](https://github.com/CCCBora/auto-draft)的auto_draft功能的测试.
|
113 |
+
通过输入想要生成的论文名称(比如Playing atari with deep reinforcement learning),即可由AI辅助生成论文模板.
|
114 |
|
115 |
***2023-05-03 Update***: 在公开版本中为大家提供了输入OpenAI API Key的地址, 如果有GPT-4的API KEY的话可以在这里体验!
|
116 |
|
117 |
+
在这个Huggingface Organization里也提供一定额度的免费体验: [AUTO-ACADEMIC](https://huggingface.co/auto-academic).
|
118 |
|
119 |
+
如果有更多想法和建议欢迎加入QQ群里交流, 如果我在Space里更新了Key我会第一时间通知大家. 群号: ***249738228***.
|
|
|
|
|
|
|
|
|
120 |
''')
|
121 |
|
122 |
with gr.Row():
|
|
|
129 |
|
130 |
# 每个功能做一个tab
|
131 |
with gr.Tab("学术论文"):
|
132 |
+
gr.Markdown(ACADEMIC_PAPER)
|
133 |
+
|
134 |
title = gr.Textbox(value="Playing Atari with Deep Reinforcement Learning", lines=1, max_lines=1,
|
135 |
label="Title", info="论文标题")
|
136 |
|
|
|
138 |
description_pp = gr.Textbox(lines=5, label="Description (Optional)", visible=True,
|
139 |
info="对希望生成的论文的一些描述. 包括这篇论文的创新点, 主要贡献, 等.")
|
140 |
|
|
|
|
|
|
|
|
|
141 |
with gr.Row():
|
142 |
with gr.Column():
|
143 |
+
with gr.Row():
|
144 |
+
template = gr.Dropdown(label="Template", choices=["ICLR2022"], value="ICLR2022",
|
145 |
+
interactive=False,
|
146 |
+
info="生成论文的参考模板. (暂不支持修改)")
|
147 |
+
model_selection = gr.Dropdown(label="Model", choices=["gpt-4", "gpt-3.5-turbo"],
|
148 |
+
value="gpt-4",
|
149 |
+
interactive=True,
|
150 |
+
info="生成论文用到的语言模型.")
|
151 |
gr.Markdown('''
|
152 |
+
上传.bib文件提供AI需要参考的文献.
|
|
|
|
|
153 |
''')
|
154 |
bibtex_file = gr.File(label="Upload .bib file", file_types=["text"],
|
155 |
+
interactive=True)
|
156 |
+
gr.Examples(
|
157 |
+
examples=["latex_templates/example_references.bib"],
|
158 |
+
inputs=bibtex_file
|
159 |
+
)
|
160 |
with gr.Column():
|
161 |
search_engine = gr.Dropdown(label="Search Engine",
|
162 |
choices=["ArXiv", "Semantic Scholar", "Google Scholar", "None"],
|
163 |
+
value="Semantic Scholar",
|
164 |
+
interactive=False,
|
165 |
+
info="用于决定GPT-4用什么搜索引擎来搜索文献. (暂不支持修改)")
|
166 |
+
tldr_checkbox = gr.Checkbox(value=True, label="TLDR;",
|
167 |
+
info="选择此筐表示将使用Semantic Scholar的TLDR作为文献的总结.",
|
168 |
+
interactive=True)
|
169 |
+
sections = gr.CheckboxGroup(
|
170 |
+
choices=["introduction", "related works", "backgrounds", "methodology", "experiments",
|
171 |
+
"conclusion", "abstract"],
|
172 |
+
type="value", label="生成章节", interactive=True,
|
173 |
+
value=["introduction", "related works"])
|
174 |
+
slider = gr.Slider(minimum=1, maximum=100, value=50, step=1,
|
175 |
+
interactive=True, label="最大参考文献数目")
|
176 |
|
177 |
with gr.Row():
|
178 |
clear_button_pp = gr.Button("Clear")
|
|
|
211 |
file_output = gr.File(label="Output")
|
212 |
|
213 |
clear_button_pp.click(fn=clear_inputs, inputs=[title, description_pp], outputs=[title, description_pp])
|
214 |
+
# submit_button_pp.click(fn=wrapped_generator,
|
215 |
+
# inputs=[title, description_pp, key, template, tldr, slider, sections, bibtex_file], outputs=file_output)
|
216 |
+
submit_button_pp.click(fn=wrapped_generator,
|
217 |
+
inputs=[title, description_pp, key, template, tldr_checkbox, slider, sections, bibtex_file,
|
218 |
+
model_selection], outputs=file_output)
|
219 |
|
220 |
demo.queue(concurrency_count=1, max_size=5, api_open=False)
|
221 |
demo.launch()
|
auto_backgrounds.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import os.path
|
2 |
-
|
3 |
from utils.references import References
|
4 |
from utils.file_operations import hash_name, make_archive, copy_templates
|
5 |
from section_generator import section_generation_bg, keywords_generation, figures_generation, section_generation
|
@@ -25,16 +25,35 @@ def log_usage(usage, generating_target, print_out=True):
|
|
25 |
TOTAL_COMPLETION_TOKENS += completion_tokens
|
26 |
|
27 |
message = f"For generating {generating_target}, {total_tokens} tokens have been used ({prompts_tokens} for prompts; {completion_tokens} for completion). " \
|
28 |
-
f"{TOTAL_TOKENS} tokens have been used in total
|
29 |
if print_out:
|
30 |
print(message)
|
31 |
logging.info(message)
|
32 |
|
33 |
-
def _generation_setup(title, description="", template="ICLR2022",
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
paper = {}
|
39 |
paper_body = {}
|
40 |
|
@@ -45,13 +64,17 @@ def _generation_setup(title, description="", template="ICLR2022", model="gpt-4",
|
|
45 |
# Generate keywords and references
|
46 |
print("Initialize the paper information ...")
|
47 |
input_dict = {"title": title, "description": description}
|
48 |
-
keywords, usage = keywords_generation(input_dict, model="gpt-3.5-turbo", max_kw_refs=max_kw_refs)
|
49 |
-
|
50 |
log_usage(usage, "keywords")
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
|
|
55 |
|
56 |
print(f"The paper information has been initialized. References are saved to {bibtex_path}.")
|
57 |
|
@@ -60,11 +83,12 @@ def _generation_setup(title, description="", template="ICLR2022", model="gpt-4",
|
|
60 |
paper["references"] = ref.to_prompts()
|
61 |
paper["body"] = paper_body
|
62 |
paper["bibtex"] = bibtex_path
|
63 |
-
return paper, destination_folder, all_paper_ids
|
64 |
|
65 |
|
66 |
|
67 |
def generate_backgrounds(title, description="", template="ICLR2022", model="gpt-4"):
|
|
|
68 |
paper, destination_folder, _ = _generation_setup(title, description, template, model)
|
69 |
|
70 |
for section in ["introduction", "related works", "backgrounds"]:
|
@@ -82,54 +106,40 @@ def generate_backgrounds(title, description="", template="ICLR2022", model="gpt-
|
|
82 |
return make_archive(destination_folder, filename)
|
83 |
|
84 |
|
85 |
-
def
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
log_usage(usage, section)
|
107 |
-
except Exception as e:
|
108 |
-
message = f"Failed to generate {section}. {type(e).__name__} was raised: {e}"
|
109 |
-
print(message)
|
110 |
-
logging.info(message)
|
111 |
-
max_attempts = 2
|
112 |
-
# todo: make this part more compact
|
113 |
-
# re-try `max_attempts` time
|
114 |
-
for i in range(max_attempts):
|
115 |
time.sleep(20)
|
116 |
-
try:
|
117 |
-
usage = section_generation(paper, section, destination_folder, model=model)
|
118 |
-
log_usage(usage, section)
|
119 |
-
e = None
|
120 |
-
except Exception as e:
|
121 |
-
pass
|
122 |
-
if e is None:
|
123 |
-
break
|
124 |
-
|
125 |
|
126 |
input_dict = {"title": title, "description": description, "generator": "generate_draft"}
|
127 |
filename = hash_name(input_dict) + ".zip"
|
|
|
128 |
return make_archive(destination_folder, filename)
|
129 |
|
130 |
|
131 |
if __name__ == "__main__":
|
|
|
|
|
|
|
132 |
title = "Using interpretable boosting algorithms for modeling environmental and agricultural data"
|
133 |
description = ""
|
134 |
-
output = generate_draft(title, description,
|
135 |
print(output)
|
|
|
1 |
import os.path
|
2 |
+
import json
|
3 |
from utils.references import References
|
4 |
from utils.file_operations import hash_name, make_archive, copy_templates
|
5 |
from section_generator import section_generation_bg, keywords_generation, figures_generation, section_generation
|
|
|
25 |
TOTAL_COMPLETION_TOKENS += completion_tokens
|
26 |
|
27 |
message = f"For generating {generating_target}, {total_tokens} tokens have been used ({prompts_tokens} for prompts; {completion_tokens} for completion). " \
|
28 |
+
f"{TOTAL_TOKENS} tokens have been used in total.\n\n"
|
29 |
if print_out:
|
30 |
print(message)
|
31 |
logging.info(message)
|
32 |
|
33 |
+
def _generation_setup(title, description="", template="ICLR2022", tldr=False,
|
34 |
+
max_kw_refs=10, max_num_refs=50, bib_refs=None):
|
35 |
+
"""
|
36 |
+
This function handles the setup process for paper generation; it contains three folds
|
37 |
+
1. Copy the template to the outputs folder. Create the log file `generation.log`
|
38 |
+
2. Collect references based on the given `title` and `description`
|
39 |
+
3. Generate the basic `paper` object (a dictionary)
|
40 |
+
|
41 |
+
Parameters:
|
42 |
+
title (str): The title of the paper.
|
43 |
+
description (str, optional): A short description or abstract for the paper. Defaults to an empty string.
|
44 |
+
template (str, optional): The template to be used for paper generation. Defaults to "ICLR2022".
|
45 |
+
tldr (bool, optional): A flag indicating whether a TL;DR (Too Long; Didn't Read) summary should be generated for the collected papers. Defaults to False.
|
46 |
+
max_kw_refs (int, optional): The maximum number of references that can be associated with each keyword. Defaults to 10.
|
47 |
+
max_num_refs (int, optional): The maximum number of references that can be included in the paper. Defaults to 50.
|
48 |
+
bib_refs (list, optional): A list of pre-existing references in BibTeX format. Defaults to None.
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
tuple: A tuple containing the following elements:
|
52 |
+
- paper (dict): A dictionary containing the generated paper information.
|
53 |
+
- destination_folder (str): The path to the destination folder where the generation log is saved.
|
54 |
+
- all_paper_ids (list): A list of all paper IDs collected for the references.
|
55 |
+
"""
|
56 |
+
print("Generation setup...")
|
57 |
paper = {}
|
58 |
paper_body = {}
|
59 |
|
|
|
64 |
# Generate keywords and references
|
65 |
print("Initialize the paper information ...")
|
66 |
input_dict = {"title": title, "description": description}
|
67 |
+
# keywords, usage = keywords_generation(input_dict, model="gpt-3.5-turbo", max_kw_refs=max_kw_refs)
|
68 |
+
keywords, usage = keywords_generation(input_dict)
|
69 |
log_usage(usage, "keywords")
|
70 |
|
71 |
+
# generate keywords dictionary
|
72 |
+
keywords = {keyword:max_kw_refs for keyword in keywords}
|
73 |
+
print(f"keywords: {keywords}\n\n")
|
74 |
+
|
75 |
+
ref = References(title, bib_refs)
|
76 |
+
ref.collect_papers(keywords, tldr=tldr)
|
77 |
+
all_paper_ids = ref.to_bibtex(bibtex_path, max_num_refs) #todo: max_num_refs has not implemented yet
|
78 |
|
79 |
print(f"The paper information has been initialized. References are saved to {bibtex_path}.")
|
80 |
|
|
|
83 |
paper["references"] = ref.to_prompts()
|
84 |
paper["body"] = paper_body
|
85 |
paper["bibtex"] = bibtex_path
|
86 |
+
return paper, destination_folder, all_paper_ids #todo: use `all_paper_ids` to check if all citations are in this list
|
87 |
|
88 |
|
89 |
|
90 |
def generate_backgrounds(title, description="", template="ICLR2022", model="gpt-4"):
|
91 |
+
# todo: to match the current generation setup
|
92 |
paper, destination_folder, _ = _generation_setup(title, description, template, model)
|
93 |
|
94 |
for section in ["introduction", "related works", "backgrounds"]:
|
|
|
106 |
return make_archive(destination_folder, filename)
|
107 |
|
108 |
|
109 |
+
def generate_draft(title, description="", template="ICLR2022",
|
110 |
+
tldr=True, max_kw_refs=10, max_num_refs=30, sections=None, bib_refs=None, model="gpt-4"):
|
111 |
+
# pre-processing `sections` parameter;
|
112 |
+
if sections is None:
|
113 |
+
sections = ["introduction", "related works", "backgrounds", "methodology", "experiments", "conclusion", "abstract"]
|
114 |
+
|
115 |
+
# todo: add more parameters; select which section to generate; select maximum refs.
|
116 |
+
paper, destination_folder, _ = _generation_setup(title, description, template, tldr, max_kw_refs, max_num_refs, bib_refs)
|
117 |
+
for section in sections:
|
118 |
+
max_attempts = 4
|
119 |
+
attempts_count = 0
|
120 |
+
while attempts_count < max_attempts:
|
121 |
+
try:
|
122 |
+
usage = section_generation(paper, section, destination_folder, model=model)
|
123 |
+
log_usage(usage, section)
|
124 |
+
break
|
125 |
+
except Exception as e:
|
126 |
+
message = f"Failed to generate {section}. {type(e).__name__} was raised: {e}"
|
127 |
+
print(message)
|
128 |
+
logging.info(message)
|
129 |
+
attempts_count += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
time.sleep(20)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
input_dict = {"title": title, "description": description, "generator": "generate_draft"}
|
133 |
filename = hash_name(input_dict) + ".zip"
|
134 |
+
print("\nMission completed.\n")
|
135 |
return make_archive(destination_folder, filename)
|
136 |
|
137 |
|
138 |
if __name__ == "__main__":
|
139 |
+
import openai
|
140 |
+
openai.api_key = os.getenv("OPENAI_API_KEY")
|
141 |
+
|
142 |
title = "Using interpretable boosting algorithms for modeling environmental and agricultural data"
|
143 |
description = ""
|
144 |
+
output = generate_draft(title, description, tldr=True, max_kw_refs=10)
|
145 |
print(output)
|
latex_templates/ICLR2022/fig.png
ADDED
latex_templates/ICLR2022/template.tex
CHANGED
@@ -6,7 +6,8 @@
|
|
6 |
\input{math_commands.tex}
|
7 |
\usepackage{hyperref}
|
8 |
\usepackage{url}
|
9 |
-
\usepackage{
|
|
|
10 |
|
11 |
\title{TITLE}
|
12 |
\author{GPT-4}
|
|
|
6 |
\input{math_commands.tex}
|
7 |
\usepackage{hyperref}
|
8 |
\usepackage{url}
|
9 |
+
\usepackage{algorithm}
|
10 |
+
\usepackage{algorithmic}
|
11 |
|
12 |
\title{TITLE}
|
13 |
\author{GPT-4}
|
latex_templates/example_references.bib
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
@inproceedings{ma2020understanding,
|
2 |
+
title={Understanding the impact of model incoherence on convergence of incremental sgd with random reshuffle},
|
3 |
+
author={Ma, Shaocong and Zhou, Yi},
|
4 |
+
booktitle={International Conference on Machine Learning},
|
5 |
+
pages={6565--6574},
|
6 |
+
year={2020},
|
7 |
+
organization={PMLR}
|
8 |
+
}
|
9 |
+
|
10 |
+
@inproceedings{ma2020variance,
|
11 |
+
author = {Ma, Shaocong and Zhou, Yi and Zou, Shaofeng},
|
12 |
+
booktitle = {Advances in Neural Information Processing Systems},
|
13 |
+
editor = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. Lin},
|
14 |
+
pages = {14796--14806},
|
15 |
+
publisher = {Curran Associates, Inc.},
|
16 |
+
title = {Variance-Reduced Off-Policy TDC Learning: Non-Asymptotic Convergence Analysis},
|
17 |
+
url = {https://proceedings.neurips.cc/paper_files/paper/2020/file/a992995ef4f0439b258f2360dbb85511-Paper.pdf},
|
18 |
+
volume = {33},
|
19 |
+
year = {2020}
|
20 |
+
}
|
latex_templates/pre_refs.bib
DELETED
@@ -1,17 +0,0 @@
|
|
1 |
-
|
2 |
-
@article{1512.07669,
|
3 |
-
title = {Reinforcement Learning: Stochastic Approximation Algorithms for Markov
|
4 |
-
Decision Processes},
|
5 |
-
author = {Vikram Krishnamurthy},
|
6 |
-
journal={arXiv preprint arXiv:1512.07669},
|
7 |
-
year = {2015},
|
8 |
-
url = {http://arxiv.org/abs/1512.07669v1}
|
9 |
-
}
|
10 |
-
|
11 |
-
@article{1511.02377,
|
12 |
-
title = {The Value Functions of Markov Decision Processes},
|
13 |
-
author = {Ehud Lehrer , Eilon Solan , Omri N. Solan},
|
14 |
-
journal={arXiv preprint arXiv:1511.02377},
|
15 |
-
year = {2015},
|
16 |
-
url = {http://arxiv.org/abs/1511.02377v1}
|
17 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
|
|
section_generator.py
CHANGED
@@ -3,6 +3,9 @@ from utils.gpt_interaction import get_responses, extract_responses, extract_keyw
|
|
3 |
from utils.figures import generate_random_figures
|
4 |
import time
|
5 |
import os
|
|
|
|
|
|
|
6 |
|
7 |
# three GPT-based content generator:
|
8 |
# 1. section_generation: used to generate main content of the paper
|
@@ -23,7 +26,7 @@ def section_generation_bg(paper, section, save_to_path, model):
|
|
23 |
print(f"Generating {section}...")
|
24 |
prompts = generate_bg_summary_prompts(paper, section)
|
25 |
gpt_response, usage = get_responses(prompts, model)
|
26 |
-
output = extract_responses(gpt_response)
|
27 |
paper["body"][section] = output
|
28 |
tex_file = os.path.join(save_to_path, f"{section}.tex")
|
29 |
# tex_file = save_to_path + f"/{section}.tex"
|
@@ -56,36 +59,46 @@ def section_generation(paper, section, save_to_path, model):
|
|
56 |
print(f"Generating {section}...")
|
57 |
prompts = generate_paper_prompts(paper, section)
|
58 |
gpt_response, usage = get_responses(prompts, model)
|
59 |
-
output = extract_responses(gpt_response)
|
60 |
paper["body"][section] = output
|
61 |
tex_file = os.path.join(save_to_path, f"{section}.tex")
|
62 |
# tex_file = save_to_path + f"/{section}.tex"
|
63 |
if section == "abstract":
|
64 |
with open(tex_file, "w") as f:
|
65 |
-
f.write(r"\begin{abstract}")
|
66 |
-
with open(tex_file, "a") as f:
|
67 |
f.write(output)
|
68 |
-
with open(tex_file, "a") as f:
|
69 |
-
f.write(r"\end{abstract}")
|
70 |
else:
|
71 |
with open(tex_file, "w") as f:
|
72 |
-
f.write(f"\section{{{section.upper()}}}\n")
|
73 |
-
with open(tex_file, "a") as f:
|
74 |
f.write(output)
|
75 |
time.sleep(5)
|
76 |
print(f"{section} has been generated. Saved to {tex_file}.")
|
77 |
return usage
|
78 |
|
79 |
-
def keywords_generation(input_dict, model, max_kw_refs = 10):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
title = input_dict.get("title")
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
def figures_generation(paper, save_to_path, model):
|
91 |
prompts = generate_experiments_prompts(paper)
|
|
|
3 |
from utils.figures import generate_random_figures
|
4 |
import time
|
5 |
import os
|
6 |
+
from utils.prompts import KEYWORDS_SYSTEM
|
7 |
+
from utils.gpt_interaction import get_gpt_responses
|
8 |
+
import json
|
9 |
|
10 |
# three GPT-based content generator:
|
11 |
# 1. section_generation: used to generate main content of the paper
|
|
|
26 |
print(f"Generating {section}...")
|
27 |
prompts = generate_bg_summary_prompts(paper, section)
|
28 |
gpt_response, usage = get_responses(prompts, model)
|
29 |
+
output = gpt_response # extract_responses(gpt_response)
|
30 |
paper["body"][section] = output
|
31 |
tex_file = os.path.join(save_to_path, f"{section}.tex")
|
32 |
# tex_file = save_to_path + f"/{section}.tex"
|
|
|
59 |
print(f"Generating {section}...")
|
60 |
prompts = generate_paper_prompts(paper, section)
|
61 |
gpt_response, usage = get_responses(prompts, model)
|
62 |
+
output = gpt_response # extract_responses(gpt_response)
|
63 |
paper["body"][section] = output
|
64 |
tex_file = os.path.join(save_to_path, f"{section}.tex")
|
65 |
# tex_file = save_to_path + f"/{section}.tex"
|
66 |
if section == "abstract":
|
67 |
with open(tex_file, "w") as f:
|
|
|
|
|
68 |
f.write(output)
|
|
|
|
|
69 |
else:
|
70 |
with open(tex_file, "w") as f:
|
|
|
|
|
71 |
f.write(output)
|
72 |
time.sleep(5)
|
73 |
print(f"{section} has been generated. Saved to {tex_file}.")
|
74 |
return usage
|
75 |
|
76 |
+
# def keywords_generation(input_dict, model, max_kw_refs = 10):
|
77 |
+
# title = input_dict.get("title")
|
78 |
+
# description = input_dict.get("description", "")
|
79 |
+
# if title is not None:
|
80 |
+
# prompts = generate_keywords_prompts(title, description, max_kw_refs)
|
81 |
+
# gpt_response, usage = get_responses(prompts, model)
|
82 |
+
# keywords = extract_keywords(gpt_response)
|
83 |
+
# return keywords, usage
|
84 |
+
# else:
|
85 |
+
# raise ValueError("`input_dict` must include the key 'title'.")
|
86 |
+
|
87 |
+
def keywords_generation(input_dict):
|
88 |
title = input_dict.get("title")
|
89 |
+
max_attempts = 10
|
90 |
+
attempts_count = 0
|
91 |
+
while attempts_count < max_attempts:
|
92 |
+
try:
|
93 |
+
keywords, usage= get_gpt_responses(KEYWORDS_SYSTEM.format(min_refs_num=3, max_refs_num=5), title,
|
94 |
+
model="gpt-3.5-turbo", temperature=0.4)
|
95 |
+
print(keywords)
|
96 |
+
output = json.loads(keywords)
|
97 |
+
return output, usage
|
98 |
+
except json.decoder.JSONDecodeError:
|
99 |
+
attempts_count += 1
|
100 |
+
time.sleep(20)
|
101 |
+
raise RuntimeError("Fail to generate keywords.")
|
102 |
|
103 |
def figures_generation(paper, save_to_path, model):
|
104 |
prompts = generate_experiments_prompts(paper)
|
utils/gpt_interaction.py
CHANGED
@@ -76,6 +76,22 @@ def get_responses(user_message, model="gpt-4", temperature=0.4, openai_key=None)
|
|
76 |
log.info(assistant_message)
|
77 |
return assistant_message, usage
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
if __name__ == "__main__":
|
81 |
test_strings = [r"f.write(r'hello world')", r"f.write(r'''hello world''')", r"f.write(r'''hello world",
|
|
|
76 |
log.info(assistant_message)
|
77 |
return assistant_message, usage
|
78 |
|
79 |
+
def get_gpt_responses(systems, prompts, model="gpt-4", temperature=0.4):
|
80 |
+
conversation_history = [
|
81 |
+
{"role": "system", "content": systems},
|
82 |
+
{"role": "user", "content": prompts}
|
83 |
+
]
|
84 |
+
response = openai.ChatCompletion.create(
|
85 |
+
model=model,
|
86 |
+
messages=conversation_history,
|
87 |
+
n=1, # Number of responses you want to generate
|
88 |
+
temperature=temperature, # Controls the creativity of the generated response
|
89 |
+
)
|
90 |
+
assistant_message = response['choices'][0]["message"]["content"]
|
91 |
+
usage = response['usage']
|
92 |
+
log.info(assistant_message)
|
93 |
+
return assistant_message, usage
|
94 |
+
|
95 |
|
96 |
if __name__ == "__main__":
|
97 |
test_strings = [r"f.write(r'hello world')", r"f.write(r'''hello world''')", r"f.write(r'''hello world",
|
utils/prompts.py
CHANGED
@@ -1,24 +1,13 @@
|
|
1 |
import logging
|
2 |
-
|
3 |
|
4 |
-
INSTRUCTIONS = {"introduction": "Please include five paragraph: Establishing the motivation for the research. Explaining its importance and relevance to the AI community. Clearly state the problem you're addressing, your proposed solution, and the specific research questions or objectives. Briefly mention key related work for context. Explain the main differences from your work. ",
|
5 |
-
"related works": r"Please discuss key publications, methods, and techniques in your research area. Analyze the strengths and weaknesses of existing methods, and present the related works in a logical manner, often chronologically. Consider using a taxonomy or categorization to structure the discussion. Do not use \section{...} or \subsection{...}; use \paragraph{...} instead. ",
|
6 |
-
"backgrounds": r"Please clearly state the central problem in this field. Explain the foundational theories, concepts, and principles that underpin your research using as many as mathematical formulas or equations (written in LaTeX). Introduce any necessary mathematical notations, equations, or algorithms that are central to this field (written them in LaTeX). Do not include \section{...} but you can have \subsection{...}. ",
|
7 |
-
"methodology": "Please read the paper I have written and write the methodology section with three subsections: Concisely describe the techniques, algorithms, and procedures employed to address the research problem (use as many as formulas written in LaTeX). Explain the rationale behind choosing these methods, and provide sufficient detail for replication (use as many as formulas written in LaTeX). Do not make any list steps; instead, just put them in the same paragraph with sufficient explainations. Do not include \section{...} but you can have \subsection{...}. ",
|
8 |
-
"results": "Please write the theoretical results section using LaTeX. Include theorem and corollary to support this paper (with formulas). Explain what assumptions are used and why they are standard and necessary. Do not include \section{...}. ",
|
9 |
-
"experiments": "Please write the experiment section using LaTeX. Include a table to compare with other methods and bold our method. Include one figure comparison.png; this figure compares the loss curve with other methods. Do not include \section{...}. ",
|
10 |
-
"conclusion": "Please read the paper I have written and write the conclusion section.",
|
11 |
-
"abstract": "Please read the paper I have written and write the abstract."}
|
12 |
-
|
13 |
-
INSTRUCTIONS["related works"] = r"Please discuss three to five main related fields to this paper. For each field, select " \
|
14 |
-
r"five to ten key publications from references. For each reference, analyze its strengths and weaknesses in one or two sentences. " \
|
15 |
-
r"Do not use \section{...} or \subsection{...}; use \paragraph{...} to list related fields. "
|
16 |
|
|
|
17 |
|
18 |
-
BG_INSTRUCTIONS = {"introduction": "Please include four paragraph: Establishing the motivation for this survey. Explaining its importance and relevance to the AI community. Clearly state the coverage of this survey and the specific research questions or objectives. Briefly mention key related work for context. ",
|
19 |
-
"related works": r"Please discuss key publications, methods, and techniques in related research area. Analyze the strengths and weaknesses of existing methods, and present the related works in a logical manner, often chronologically. Consider using a taxonomy or categorization to structure the discussion. Do not use \section{...} or \subsection{...}; use \paragraph{...} instead. ",
|
20 |
-
"backgrounds": r"Please clearly state the central problem in this field. Explain the foundational theories, concepts, and principles that underpin your research using as many as mathematical formulas or equations (written in LaTeX). Introduce any necessary mathematical notations, equations, or algorithms that are central to this field (written them in LaTeX). Do not include \section{...} but you can have \subsection{...}. ",}
|
21 |
|
|
|
|
|
|
|
22 |
def generate_keywords_prompts(title, description="", num_refs=5):
|
23 |
prompts = f"I am writing a machine learning paper with the title '{title}'. {description}\n" \
|
24 |
f"Generate three to five keywords. For each keyword, rate it from 1 to {num_refs}; the larger number means more important." \
|
@@ -39,6 +28,83 @@ def generate_experiments_prompts(paper_info):
|
|
39 |
|
40 |
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
def generate_paper_prompts(paper_info, section):
|
43 |
title = paper_info["title"]
|
44 |
description = paper_info["description"]
|
@@ -47,34 +113,57 @@ def generate_paper_prompts(paper_info, section):
|
|
47 |
|
48 |
# fundamental_subprompt - describe the basic information of paper
|
49 |
# instruction_subprompt - tell AI what to do
|
50 |
-
#
|
51 |
# self_subprompt - give AI existing written parts
|
52 |
# output_subprompt - tell AI how to output
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
if section in ["introduction", "related works", "backgrounds"]:
|
70 |
# title + references + instruction
|
71 |
-
prompts =
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
# title + instruction + paper
|
77 |
-
prompts =
|
|
|
|
|
|
|
|
|
78 |
else:
|
79 |
raise NotImplementedError
|
80 |
|
@@ -82,6 +171,16 @@ def generate_paper_prompts(paper_info, section):
|
|
82 |
return prompts
|
83 |
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
def generate_bg_summary_prompts(paper_info, section):
|
86 |
title = paper_info["title"]
|
87 |
description = paper_info["description"]
|
|
|
1 |
import logging
|
2 |
+
from langchain import PromptTemplate
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
+
log = logging.getLogger(__name__)
|
6 |
|
|
|
|
|
|
|
7 |
|
8 |
+
######################################################################################################################
|
9 |
+
# Some basic functions
|
10 |
+
######################################################################################################################
|
11 |
def generate_keywords_prompts(title, description="", num_refs=5):
|
12 |
prompts = f"I am writing a machine learning paper with the title '{title}'. {description}\n" \
|
13 |
f"Generate three to five keywords. For each keyword, rate it from 1 to {num_refs}; the larger number means more important." \
|
|
|
28 |
|
29 |
|
30 |
|
31 |
+
######################################################################################################################
|
32 |
+
# System Message
|
33 |
+
######################################################################################################################
|
34 |
+
|
35 |
+
# two parameters: min_refs_num, max_refs_num
|
36 |
+
# keywords_system_template = """You are an assistant designed to provide accurate and informative keywords of searching academic papers.
|
37 |
+
# Instructions
|
38 |
+
# - Your response should always be a Python list; e.g. ["keyword1", "keyword2", "keyword3"]
|
39 |
+
# - The length of list should between {min_refs_num} and {max_refs_num}
|
40 |
+
# - Use specific phrases as keywords and avoid using too general words (e.g. machine learning)"""
|
41 |
+
keywords_system_template = """You are an assistant designed to provide accurate and informative keywords of searching academic papers.\n
|
42 |
+
Instructions:\n
|
43 |
+
- Your response should follow the following output format: ["field1", "field2", "field3", "field4"]\n
|
44 |
+
- The length of this Python list should between {min_refs_num} and {max_refs_num}."""
|
45 |
+
|
46 |
+
# two parameters: min_refs_num, max_refs_num
|
47 |
+
exp_methods_system_template = """You are an assistant designed to provide most related algorithms or methods to a given paper title.
|
48 |
+
Instructions
|
49 |
+
- Your response should always be a Python list; e.g. ["method_name_1", "method_name_2", "method_name_3"]
|
50 |
+
- The length of list should between {min_exps_num} and {max_exps_num}
|
51 |
+
- Use abbreviation to make each method's name have 5 characters or less."""
|
52 |
+
|
53 |
+
# one parameter: research_field
|
54 |
+
section_generation_system_template = r"""You are an assistant designed to write academic papers in the field of {research_field} using LaTeX.
|
55 |
+
Instructions
|
56 |
+
- Your response should be professional and in academic tone.
|
57 |
+
- Always give a high-level overview at the beginning of each section or subsection.
|
58 |
+
"""
|
59 |
+
|
60 |
+
KEYWORDS_SYSTEM = PromptTemplate(input_variables=["min_refs_num", "max_refs_num"],
|
61 |
+
template=keywords_system_template)
|
62 |
+
EXP_METHODS_SYSTEM = PromptTemplate(input_variables=["min_exps_num", "max_exps_num"],
|
63 |
+
template=exp_methods_system_template)
|
64 |
+
SECTION_GENERATION_SYSTEM = PromptTemplate(input_variables=["research_field"],
|
65 |
+
template=section_generation_system_template)
|
66 |
+
|
67 |
+
|
68 |
+
######################################################################################################################
|
69 |
+
# Academic Paper
|
70 |
+
######################################################################################################################
|
71 |
+
|
72 |
+
INSTRUCTIONS = {"introduction":
|
73 |
+
"- Include five paragraph: Establishing the motivation for the research. Explaining its importance and relevance to the AI community. Clearly state the problem you're addressing, your proposed solution, and the specific research questions or objectives. Briefly mention key related works for context and explain the main differences from this work. List three novel contributions of this paper.",
|
74 |
+
"results":
|
75 |
+
"Write the theoretical results section using LaTeX. Include theorem and corollary to support this paper (with formulas). Explain what assumptions are used and why they are standard and necessary. Do not include \section{...}. ",
|
76 |
+
"conclusion":
|
77 |
+
"- Read the existing parts of paper and write the conclusion section.",
|
78 |
+
"abstract":
|
79 |
+
"- Read the existing parts of paper and write the abstract."}
|
80 |
+
|
81 |
+
|
82 |
+
INSTRUCTIONS["backgrounds"] = "- Start from one high-level paragraph to state the central problem in this field with detailed examples in industrial applications and theoretical challenges. \n" \
|
83 |
+
"- Followed by two to three subsections: Explain the foundational concepts and notations that underpin your research using as many as mathematical formulas (written in LaTeX). " \
|
84 |
+
"Introduce more necessary mathematical notations, equations, or algorithms that are connected to this work. Present detailed discussions on how these concepts are applied in this paper."
|
85 |
+
|
86 |
+
|
87 |
+
INSTRUCTIONS["related works"] = r"- Discuss three to five main related fields to this paper. " \
|
88 |
+
r"For each field, select five to ten key publications from references. " \
|
89 |
+
r"For each reference, analyze its strengths and weaknesses in one or two sentences. " \
|
90 |
+
r"Present the related works in a logical manner, often chronologically. " \
|
91 |
+
r"Consider using a taxonomy or categorization to structure the discussion. " \
|
92 |
+
r"Do not use \section{...} or \subsection{...}; use \paragraph{...} to list related fields. "
|
93 |
+
|
94 |
+
INSTRUCTIONS["methodology"] = "- Provide a high-level overview of the proposed method at the beginning of this section. \n " \
|
95 |
+
"- Assume you have some figures ('fig1.png', 'fig2.png', ...); they can be any figures you need (e.g. flow chart, model architecture, sample output, simulation result, or others you need). Insert figures you need with informative caption. \n" \
|
96 |
+
"- Use one subsection to give a detailed formulation of the proposed method and explain how it overcomes the weakness of existing methods mentioned in this paper. " \
|
97 |
+
" If necessary, write pseudo codes wrapped by \\begin{{algorithm}} ... \\end{{algorithm}} to explain the detailed steps instead of simply listing them. \n" \
|
98 |
+
"- Use one follow-up subsection to highlight the key concepts in the proposed method. " \
|
99 |
+
" Elaborate the novelty of these key concepts using formulas and inserting appropriate figures. \n" \
|
100 |
+
"- Ensure the name of each subsection to be specific. \n"
|
101 |
+
|
102 |
+
INSTRUCTIONS["experiments"] = "- Provide a high-level overview at the beginning of this section.\n " \
|
103 |
+
"- If necessary, include a table to compare with other methods and bold our method.\n" \
|
104 |
+
"- Assume you have some figures ('exp1.png', 'exp2.png', ...); they can be any figures you need (e.g. loss curves, comparison with other methods, visualization, or others you need). Insert figures you need with informative caption. \n" \
|
105 |
+
"- If necessary, use different subsections to distinguish different experimental setup."
|
106 |
+
|
107 |
+
|
108 |
def generate_paper_prompts(paper_info, section):
|
109 |
title = paper_info["title"]
|
110 |
description = paper_info["description"]
|
|
|
113 |
|
114 |
# fundamental_subprompt - describe the basic information of paper
|
115 |
# instruction_subprompt - tell AI what to do
|
116 |
+
# ref_instruction_subprompt - give AI references
|
117 |
# self_subprompt - give AI existing written parts
|
118 |
# output_subprompt - tell AI how to output
|
119 |
+
fundamental_subprompt = "Your task is to write the {section} section of the machine learning paper with the title '{title}'. {description}\n"
|
120 |
+
instruction_subprompt = "\n" \
|
121 |
+
"Your response should follow the following instructions:\n" \
|
122 |
+
"{instruction}\n" \
|
123 |
+
"- Start with \section{{{section}}}\n"
|
124 |
+
ref_instruction_subprompt = "- Read references. " \
|
125 |
+
"Every time you use information from the references, you need to appropriately cite it (using \citep or \citet)." \
|
126 |
+
"For example of \citep, the sentence where you use information from lei2022adaptive \citep{{lei2022adaptive}}. " \
|
127 |
+
"For example of \citet, \citet{{lei2022adaptive}} claims some information.\n" \
|
128 |
+
"- Avoid citing the same reference in a same paragraph.\n" \
|
129 |
+
"\n" \
|
130 |
+
"References:\n" \
|
131 |
+
"{references}"
|
132 |
+
self_subprompt = "The existing parts of this paper is provided here: {paper}.\n"
|
133 |
+
output_subprompt = "Your response should start with \section{{{section}}}. Ensure that it can be directly compiled by LeTaX."
|
134 |
+
abstract_output_subprompt = "Your response should start with \\begin{{abstract}} and should end with \\end{{abstract}}. Ensure that it can be directly compiled by LeTaX."
|
135 |
+
|
136 |
+
reivew_prompts = PromptTemplate(
|
137 |
+
input_variables=["title", "description", "instruction", "section", "references"],
|
138 |
+
template=fundamental_subprompt + instruction_subprompt + ref_instruction_subprompt + output_subprompt)
|
139 |
+
summarization_prompts = PromptTemplate(
|
140 |
+
input_variables=["title", "description", "instruction", "section", "paper"],
|
141 |
+
template=fundamental_subprompt + instruction_subprompt + self_subprompt + output_subprompt)
|
142 |
+
abstract_prompts = PromptTemplate(
|
143 |
+
input_variables=["title", "description", "instruction", "section", "paper"],
|
144 |
+
template=fundamental_subprompt + instruction_subprompt + self_subprompt + abstract_output_subprompt)
|
145 |
|
146 |
if section in ["introduction", "related works", "backgrounds"]:
|
147 |
# title + references + instruction
|
148 |
+
prompts = reivew_prompts.format(title=title,
|
149 |
+
description=description,
|
150 |
+
instruction=INSTRUCTIONS[section],
|
151 |
+
section=section,
|
152 |
+
references=references)
|
153 |
+
elif section in ["abstract"]:
|
154 |
+
# title + instruction + paper
|
155 |
+
prompts = abstract_prompts.format(title=title,
|
156 |
+
description=description,
|
157 |
+
instruction=INSTRUCTIONS[section],
|
158 |
+
section=section,
|
159 |
+
paper=paper)
|
160 |
+
elif section in ["methodology", "experiments", "conclusion"]:
|
161 |
# title + instruction + paper
|
162 |
+
prompts = summarization_prompts.format(title=title,
|
163 |
+
description=description,
|
164 |
+
instruction=INSTRUCTIONS[section],
|
165 |
+
section=section,
|
166 |
+
paper=paper)
|
167 |
else:
|
168 |
raise NotImplementedError
|
169 |
|
|
|
171 |
return prompts
|
172 |
|
173 |
|
174 |
+
######################################################################################################################
|
175 |
+
# Literature Review
|
176 |
+
######################################################################################################################
|
177 |
+
|
178 |
+
BG_INSTRUCTIONS = {"introduction": "Please include four paragraph: Establishing the motivation for this survey. Explaining its importance and relevance to the AI community. Clearly state the coverage of this survey and the specific research questions or objectives. Briefly mention key related work for context. ",
|
179 |
+
"related works": r"Please discuss key publications, methods, and techniques in related research area. Analyze the strengths and weaknesses of existing methods, and present the related works in a logical manner, often chronologically. Consider using a taxonomy or categorization to structure the discussion. Do not use \section{...} or \subsection{...}; use \paragraph{...} instead. ",
|
180 |
+
"backgrounds": r"Please clearly state the central problem in this field. Explain the foundational theories, concepts, and principles that underpin your research using as many as mathematical formulas or equations (written in LaTeX). Introduce any necessary mathematical notations, equations, or algorithms that are central to this field (written them in LaTeX). Do not include \section{...} but you can have \subsection{...}. ",}
|
181 |
+
|
182 |
+
|
183 |
+
|
184 |
def generate_bg_summary_prompts(paper_info, section):
|
185 |
title = paper_info["title"]
|
186 |
description = paper_info["description"]
|
utils/references.py
CHANGED
@@ -1,17 +1,27 @@
|
|
1 |
# Each `paper` is a dictionary containing:
|
2 |
-
# (1) paper_id (2) title (3) authors (4) year (5) link (6) abstract (7) journal
|
3 |
#
|
4 |
# Generate references:
|
5 |
# `Reference` class:
|
6 |
# 1. Read a given .bib file to collect papers; use `search_paper_abstract` method to fill missing abstract.
|
7 |
-
# 2. Given some keywords; use
|
8 |
# 3. Generate bibtex from the selected papers. --> to_bibtex()
|
9 |
# 4. Generate prompts from the selected papers: --> to_prompts()
|
10 |
# A sample prompt: {"paper_id": "paper summary"}
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
import requests
|
13 |
import re
|
14 |
import bibtexparser
|
|
|
15 |
from scholarly import scholarly
|
16 |
from scholarly import ProxyGenerator
|
17 |
|
@@ -31,11 +41,14 @@ def remove_newlines(serie):
|
|
31 |
def search_paper_abstract(title):
|
32 |
pg = ProxyGenerator()
|
33 |
success = pg.ScraperAPI("921b16f94d701308b9d9b4456ddde155")
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
39 |
|
40 |
|
41 |
def load_papers_from_bibtex(bib_file_path):
|
@@ -46,6 +59,7 @@ def load_papers_from_bibtex(bib_file_path):
|
|
46 |
else:
|
47 |
bib_papers = []
|
48 |
for bibitem in bib_database.entries:
|
|
|
49 |
paper_id = bibitem.get("ID")
|
50 |
title = bibitem.get("title")
|
51 |
if title is None:
|
@@ -68,7 +82,6 @@ def load_papers_from_bibtex(bib_file_path):
|
|
68 |
bib_papers.append(result)
|
69 |
return bib_papers
|
70 |
|
71 |
-
|
72 |
######################################################################################################################
|
73 |
# Semantic Scholar (SS) API
|
74 |
######################################################################################################################
|
@@ -124,6 +137,7 @@ def _collect_papers_ss(keyword, counts=3, tldr=False):
|
|
124 |
authors_str = " and ".join(authors)
|
125 |
try:
|
126 |
last_name = authors[0].split()[-1]
|
|
|
127 |
except IndexError:
|
128 |
last_name = "ma"
|
129 |
# pattern = r'^\w+'
|
@@ -131,6 +145,9 @@ def _collect_papers_ss(keyword, counts=3, tldr=False):
|
|
131 |
return authors_str, last_name
|
132 |
|
133 |
def parse_search_results(search_results_ss):
|
|
|
|
|
|
|
134 |
# turn the search result to a list of paper dictionary.
|
135 |
papers_ss = []
|
136 |
for raw_paper in search_results_ss:
|
@@ -140,16 +157,26 @@ def _collect_papers_ss(keyword, counts=3, tldr=False):
|
|
140 |
authors_str, last_name = extract_author_info(raw_paper['authors'])
|
141 |
year_str = str(raw_paper['year'])
|
142 |
title = raw_paper['title']
|
|
|
143 |
# some journal may contain &; replace it. e.g. journal={IEEE Power & Energy Society General Meeting}
|
144 |
journal = raw_paper['venue'].replace("&", "\\&")
|
145 |
if not journal:
|
146 |
journal = "arXiv preprint"
|
|
|
147 |
paper_id = extract_paper_id(last_name, year_str, title).lower()
|
148 |
link = externalIds2link(raw_paper['externalIds'])
|
|
|
149 |
if tldr and raw_paper['tldr'] is not None:
|
150 |
abstract = raw_paper['tldr']['text']
|
151 |
else:
|
152 |
abstract = remove_newlines(raw_paper['abstract'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
result = {
|
154 |
"paper_id": paper_id,
|
155 |
"title": title,
|
@@ -157,134 +184,64 @@ def _collect_papers_ss(keyword, counts=3, tldr=False):
|
|
157 |
"link": link,
|
158 |
"authors": authors_str,
|
159 |
"year": year_str,
|
160 |
-
"journal": journal
|
|
|
161 |
}
|
162 |
papers_ss.append(result)
|
163 |
return papers_ss
|
164 |
|
165 |
raw_results = ss_search(keyword, limit=counts)
|
166 |
if raw_results is not None:
|
167 |
-
search_results = raw_results
|
|
|
|
|
168 |
else:
|
169 |
search_results = []
|
170 |
results = parse_search_results(search_results)
|
171 |
return results
|
172 |
|
173 |
-
|
174 |
-
######################################################################################################################
|
175 |
-
# ArXiv API
|
176 |
-
######################################################################################################################
|
177 |
-
def _collect_papers_arxiv(keyword, counts=3, tldr=False):
|
178 |
-
# Build the arXiv API query URL with the given keyword and other parameters
|
179 |
-
def build_query_url(keyword, results_limit=3, sort_by="relevance", sort_order="descending"):
|
180 |
-
base_url = "http://export.arxiv.org/api/query?"
|
181 |
-
query = f"search_query=all:{keyword}&start=0&max_results={results_limit}"
|
182 |
-
query += f"&sortBy={sort_by}&sortOrder={sort_order}"
|
183 |
-
return base_url + query
|
184 |
-
|
185 |
-
# Fetch search results from the arXiv API using the constructed URL
|
186 |
-
def fetch_search_results(query_url):
|
187 |
-
response = requests.get(query_url)
|
188 |
-
return response.text
|
189 |
-
|
190 |
-
# Parse the XML content of the API response to extract paper information
|
191 |
-
def parse_results(content):
|
192 |
-
from xml.etree import ElementTree as ET
|
193 |
-
|
194 |
-
root = ET.fromstring(content)
|
195 |
-
namespace = "{http://www.w3.org/2005/Atom}"
|
196 |
-
entries = root.findall(f"{namespace}entry")
|
197 |
-
|
198 |
-
results = []
|
199 |
-
for entry in entries:
|
200 |
-
title = entry.find(f"{namespace}title").text
|
201 |
-
link = entry.find(f"{namespace}id").text
|
202 |
-
summary = entry.find(f"{namespace}summary").text
|
203 |
-
summary = remove_newlines(summary)
|
204 |
-
|
205 |
-
# Extract the authors
|
206 |
-
authors = entry.findall(f"{namespace}author")
|
207 |
-
author_list = []
|
208 |
-
for author in authors:
|
209 |
-
name = author.find(f"{namespace}name").text
|
210 |
-
author_list.append(name)
|
211 |
-
authors_str = " and ".join(author_list)
|
212 |
-
|
213 |
-
# Extract the year
|
214 |
-
published = entry.find(f"{namespace}published").text
|
215 |
-
year = published.split("-")[0]
|
216 |
-
|
217 |
-
founds = re.search(r'\d+\.\d+', link)
|
218 |
-
if founds is None:
|
219 |
-
# some links are not standard; such as "https://arxiv.org/abs/cs/0603127v1".
|
220 |
-
# will be solved in the future.
|
221 |
-
continue
|
222 |
-
else:
|
223 |
-
arxiv_id = founds.group(0)
|
224 |
-
journal = f"arXiv preprint arXiv:{arxiv_id}"
|
225 |
-
result = {
|
226 |
-
"paper_id": arxiv_id,
|
227 |
-
"title": title,
|
228 |
-
"link": link,
|
229 |
-
"abstract": summary,
|
230 |
-
"authors": authors_str,
|
231 |
-
"year": year,
|
232 |
-
"journal": journal
|
233 |
-
}
|
234 |
-
results.append(result)
|
235 |
-
|
236 |
-
return results
|
237 |
-
|
238 |
-
query_url = build_query_url(keyword, counts)
|
239 |
-
content = fetch_search_results(query_url)
|
240 |
-
results = parse_results(content)
|
241 |
-
return results
|
242 |
-
|
243 |
-
|
244 |
######################################################################################################################
|
245 |
# References Class
|
246 |
######################################################################################################################
|
247 |
|
248 |
class References:
|
249 |
-
def __init__(self, load_papers
|
250 |
-
if load_papers:
|
251 |
-
|
252 |
-
|
253 |
-
# (3) may use langchain to support long input
|
254 |
-
self.papers = load_papers_from_bibtex(load_papers)
|
255 |
else:
|
256 |
-
self.papers =
|
|
|
|
|
|
|
|
|
257 |
|
258 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
259 |
"""
|
260 |
keywords_dict:
|
261 |
{"machine learning": 5, "language model": 2};
|
262 |
the first is the keyword, the second is how many references are needed.
|
263 |
"""
|
264 |
-
match method:
|
265 |
-
case "arxiv":
|
266 |
-
process = _collect_papers_arxiv
|
267 |
-
case "ss":
|
268 |
-
process = _collect_papers_ss
|
269 |
-
case _:
|
270 |
-
raise NotImplementedError("Other sources have not been not supported yet.")
|
271 |
for key, counts in keywords_dict.items():
|
272 |
-
self.papers =
|
273 |
|
274 |
-
|
275 |
-
|
276 |
-
for paper in self.papers:
|
277 |
-
paper_id = paper["paper_id"]
|
278 |
-
if paper_id not in seen:
|
279 |
-
seen.add(paper_id)
|
280 |
-
papers.append(paper)
|
281 |
-
self.papers = papers
|
282 |
-
|
283 |
-
def to_bibtex(self, path_to_bibtex="ref.bib"):
|
284 |
"""
|
285 |
Turn the saved paper list into bibtex file "ref.bib". Return a list of all `paper_id`.
|
286 |
"""
|
287 |
-
|
|
|
|
|
|
|
|
|
|
|
288 |
|
289 |
# clear the bibtex file
|
290 |
with open(path_to_bibtex, "w", encoding="utf-8") as file:
|
@@ -292,7 +249,12 @@ class References:
|
|
292 |
|
293 |
bibtex_entries = []
|
294 |
paper_ids = []
|
|
|
295 |
for paper in papers:
|
|
|
|
|
|
|
|
|
296 |
bibtex_entry = f"""@article{{{paper["paper_id"]},
|
297 |
title = {{{paper["title"]}}},
|
298 |
author = {{{paper["authors"]}}},
|
@@ -308,31 +270,69 @@ class References:
|
|
308 |
file.write("\n\n")
|
309 |
return paper_ids
|
310 |
|
311 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
# `prompts`:
|
313 |
# {"paper1_bibtex_id": "paper_1_abstract", "paper2_bibtex_id": "paper2_abstract"}
|
314 |
# this will be used to instruct GPT model to cite the correct bibtex entry.
|
|
|
315 |
prompts = {}
|
316 |
-
for paper in
|
317 |
prompts[paper["paper_id"]] = paper["abstract"]
|
318 |
return prompts
|
319 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
|
321 |
if __name__ == "__main__":
|
322 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
323 |
# keywords_dict = {
|
324 |
-
# "Deep Q-Networks":
|
325 |
-
# "Policy Gradient Methods": 24,
|
326 |
# "Actor-Critic Algorithms": 4,
|
327 |
-
# "
|
328 |
-
# "Exploration-Exploitation Trade-off": 7
|
329 |
# }
|
330 |
-
# refs.collect_papers(keywords_dict,
|
331 |
-
# for
|
332 |
-
#
|
333 |
-
#
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# Each `paper` is a dictionary containing:
|
2 |
+
# (1) paper_id (2) title (3) authors (4) year (5) link (6) abstract (7) journal (8) embeddings
|
3 |
#
|
4 |
# Generate references:
|
5 |
# `Reference` class:
|
6 |
# 1. Read a given .bib file to collect papers; use `search_paper_abstract` method to fill missing abstract.
|
7 |
+
# 2. Given some keywords; use Semantic Scholar API to find papers.
|
8 |
# 3. Generate bibtex from the selected papers. --> to_bibtex()
|
9 |
# 4. Generate prompts from the selected papers: --> to_prompts()
|
10 |
# A sample prompt: {"paper_id": "paper summary"}
|
11 |
|
12 |
+
# todo: (1) citations & citedby of provided papers:
|
13 |
+
# load the pre-defined papers; use S2 to find all related works
|
14 |
+
# add all citations to `bib_papers`
|
15 |
+
# add all citedby to `bib_papers`
|
16 |
+
# use Semantic Scholar to find their embeddings
|
17 |
+
# (2) separate references:
|
18 |
+
# divide references into different groups to reduce the tokens count
|
19 |
+
# for generating different paragraph of related works, use different set of references
|
20 |
+
|
21 |
import requests
|
22 |
import re
|
23 |
import bibtexparser
|
24 |
+
import random
|
25 |
from scholarly import scholarly
|
26 |
from scholarly import ProxyGenerator
|
27 |
|
|
|
41 |
def search_paper_abstract(title):
|
42 |
pg = ProxyGenerator()
|
43 |
success = pg.ScraperAPI("921b16f94d701308b9d9b4456ddde155")
|
44 |
+
if success:
|
45 |
+
scholarly.use_proxy(pg)
|
46 |
+
# input the title of a paper, return its abstract
|
47 |
+
search_query = scholarly.search_pubs(title)
|
48 |
+
found_paper = next(search_query)
|
49 |
+
else:
|
50 |
+
raise RuntimeError("ScraperAPI fails.")
|
51 |
+
return remove_newlines(found_paper['bib']['abstract'])
|
52 |
|
53 |
|
54 |
def load_papers_from_bibtex(bib_file_path):
|
|
|
59 |
else:
|
60 |
bib_papers = []
|
61 |
for bibitem in bib_database.entries:
|
62 |
+
# Add each paper to `bib_papers`
|
63 |
paper_id = bibitem.get("ID")
|
64 |
title = bibitem.get("title")
|
65 |
if title is None:
|
|
|
82 |
bib_papers.append(result)
|
83 |
return bib_papers
|
84 |
|
|
|
85 |
######################################################################################################################
|
86 |
# Semantic Scholar (SS) API
|
87 |
######################################################################################################################
|
|
|
137 |
authors_str = " and ".join(authors)
|
138 |
try:
|
139 |
last_name = authors[0].split()[-1]
|
140 |
+
last_name = last_name.replace("'", "")
|
141 |
except IndexError:
|
142 |
last_name = "ma"
|
143 |
# pattern = r'^\w+'
|
|
|
145 |
return authors_str, last_name
|
146 |
|
147 |
def parse_search_results(search_results_ss):
|
148 |
+
if len(search_results_ss) == 0:
|
149 |
+
return []
|
150 |
+
|
151 |
# turn the search result to a list of paper dictionary.
|
152 |
papers_ss = []
|
153 |
for raw_paper in search_results_ss:
|
|
|
157 |
authors_str, last_name = extract_author_info(raw_paper['authors'])
|
158 |
year_str = str(raw_paper['year'])
|
159 |
title = raw_paper['title']
|
160 |
+
|
161 |
# some journal may contain &; replace it. e.g. journal={IEEE Power & Energy Society General Meeting}
|
162 |
journal = raw_paper['venue'].replace("&", "\\&")
|
163 |
if not journal:
|
164 |
journal = "arXiv preprint"
|
165 |
+
|
166 |
paper_id = extract_paper_id(last_name, year_str, title).lower()
|
167 |
link = externalIds2link(raw_paper['externalIds'])
|
168 |
+
|
169 |
if tldr and raw_paper['tldr'] is not None:
|
170 |
abstract = raw_paper['tldr']['text']
|
171 |
else:
|
172 |
abstract = remove_newlines(raw_paper['abstract'])
|
173 |
+
|
174 |
+
# some papers have no embeddings; handle this case
|
175 |
+
embeddings_dict = raw_paper.get('embedding')
|
176 |
+
if embeddings_dict is None:
|
177 |
+
continue
|
178 |
+
else:
|
179 |
+
embeddings = raw_paper['embedding']['vector']
|
180 |
result = {
|
181 |
"paper_id": paper_id,
|
182 |
"title": title,
|
|
|
184 |
"link": link,
|
185 |
"authors": authors_str,
|
186 |
"year": year_str,
|
187 |
+
"journal": journal,
|
188 |
+
"embeddings": embeddings
|
189 |
}
|
190 |
papers_ss.append(result)
|
191 |
return papers_ss
|
192 |
|
193 |
raw_results = ss_search(keyword, limit=counts)
|
194 |
if raw_results is not None:
|
195 |
+
search_results = raw_results.get("data")
|
196 |
+
if search_results is None:
|
197 |
+
search_results = []
|
198 |
else:
|
199 |
search_results = []
|
200 |
results = parse_search_results(search_results)
|
201 |
return results
|
202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
######################################################################################################################
|
204 |
# References Class
|
205 |
######################################################################################################################
|
206 |
|
207 |
class References:
|
208 |
+
def __init__(self, title, load_papers):
|
209 |
+
if load_papers is not None:
|
210 |
+
self.papers = {}
|
211 |
+
self.papers["customized_refs"] = load_papers_from_bibtex(load_papers)
|
|
|
|
|
212 |
else:
|
213 |
+
self.papers = {}
|
214 |
+
self.title = title
|
215 |
+
|
216 |
+
def load_papers(self, bibtex, keyword):
|
217 |
+
self.papers[keyword] = load_papers_from_bibtex(bibtex)
|
218 |
|
219 |
+
def generate_keywords_dict(self):
|
220 |
+
keywords_dict = {}
|
221 |
+
for k in self.papers:
|
222 |
+
keywords_dict[k] = len(self.papers[k])
|
223 |
+
return keywords_dict
|
224 |
+
|
225 |
+
def collect_papers(self, keywords_dict, tldr=False):
|
226 |
"""
|
227 |
keywords_dict:
|
228 |
{"machine learning": 5, "language model": 2};
|
229 |
the first is the keyword, the second is how many references are needed.
|
230 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
for key, counts in keywords_dict.items():
|
232 |
+
self.papers[key] = _collect_papers_ss(key, counts, tldr)
|
233 |
|
234 |
+
|
235 |
+
def to_bibtex(self, path_to_bibtex="ref.bib", max_num_refs=50):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
"""
|
237 |
Turn the saved paper list into bibtex file "ref.bib". Return a list of all `paper_id`.
|
238 |
"""
|
239 |
+
# todo:
|
240 |
+
# use embeddings to evaluate; keep top k relevant references in papers
|
241 |
+
# send (title, .bib file) to evaluate embeddings; recieve truncated papers
|
242 |
+
papers = self._get_papers(keyword = "_all")
|
243 |
+
random.shuffle(papers)
|
244 |
+
papers = papers[:max_num_refs]
|
245 |
|
246 |
# clear the bibtex file
|
247 |
with open(path_to_bibtex, "w", encoding="utf-8") as file:
|
|
|
249 |
|
250 |
bibtex_entries = []
|
251 |
paper_ids = []
|
252 |
+
seen = set()
|
253 |
for paper in papers:
|
254 |
+
if paper["paper_id"] in seen:
|
255 |
+
continue
|
256 |
+
else:
|
257 |
+
seen.add(paper["paper_id"])
|
258 |
bibtex_entry = f"""@article{{{paper["paper_id"]},
|
259 |
title = {{{paper["title"]}}},
|
260 |
author = {{{paper["authors"]}}},
|
|
|
270 |
file.write("\n\n")
|
271 |
return paper_ids
|
272 |
|
273 |
+
def _get_papers(self, keyword = "_all"):
|
274 |
+
if keyword == "_all":
|
275 |
+
papers = []
|
276 |
+
for k, v in self.papers.items():
|
277 |
+
papers = papers + v
|
278 |
+
else:
|
279 |
+
papers = self.papers["keyword"]
|
280 |
+
return papers
|
281 |
+
|
282 |
+
def to_prompts(self, keyword = "_all"):
|
283 |
# `prompts`:
|
284 |
# {"paper1_bibtex_id": "paper_1_abstract", "paper2_bibtex_id": "paper2_abstract"}
|
285 |
# this will be used to instruct GPT model to cite the correct bibtex entry.
|
286 |
+
papers = self._get_papers(keyword)
|
287 |
prompts = {}
|
288 |
+
for paper in papers:
|
289 |
prompts[paper["paper_id"]] = paper["abstract"]
|
290 |
return prompts
|
291 |
|
292 |
+
def to_json(self, keyword = "_all"):
|
293 |
+
papers = self._get_papers(keyword)
|
294 |
+
papers_json = {}
|
295 |
+
for paper in papers:
|
296 |
+
papers_json[paper["paper_id"]] = paper
|
297 |
+
return papers_json
|
298 |
+
|
299 |
+
|
300 |
|
301 |
if __name__ == "__main__":
|
302 |
+
# testing search results
|
303 |
+
r = ss_search("Deep Q-Networks", limit=1) # a list of raw papers
|
304 |
+
if r['total'] > 0:
|
305 |
+
paper = r['data'][0]
|
306 |
+
# print(paper)
|
307 |
+
|
308 |
+
# resting References
|
309 |
+
refs = References()
|
310 |
# keywords_dict = {
|
311 |
+
# "Deep Q-Networks": 5,
|
|
|
312 |
# "Actor-Critic Algorithms": 4,
|
313 |
+
# "Exploration-Exploitation Trade-off": 3
|
|
|
314 |
# }
|
315 |
+
# refs.collect_papers(keywords_dict, tldr=True)
|
316 |
+
# for k in refs.papers:
|
317 |
+
# papers = refs.papers[k] # for each keyword, there is a list of papers
|
318 |
+
# print("keyword: ", k)
|
319 |
+
# for paper in papers:
|
320 |
+
# print(paper["paper_id"])
|
321 |
+
#
|
322 |
+
# refs.to_bibtex()
|
323 |
+
# papers_json = refs.to_json() # this json can be used to find the most relevant papers
|
324 |
+
# with open("papers.json", "w", encoding='utf-8') as text_file:
|
325 |
+
# text_file.write(f"{papers_json}")
|
326 |
+
#
|
327 |
+
# prompts = refs.to_prompts()
|
328 |
+
# print(prompts)
|
329 |
+
|
330 |
+
bib = "test.bib"
|
331 |
+
refs.load_papers(bib, "variance-reduction rl")
|
332 |
+
print(refs.papers)
|
333 |
+
|
334 |
+
prompts = refs.to_prompts()
|
335 |
+
for k in prompts:
|
336 |
+
print(f"{k}: {prompts[k]}\n")
|
337 |
+
# for paper in papers:
|
338 |
+
# print(paper)
|
utils/tex_processing.py
CHANGED
@@ -19,10 +19,11 @@ def replace_title(save_to_path, title):
|
|
19 |
|
20 |
# check if citations are in bibtex.
|
21 |
|
22 |
-
|
23 |
# replace citations
|
24 |
|
25 |
# sometimes the output may include thebibliography and bibitem . remove all of it.
|
26 |
|
|
|
|
|
27 |
|
28 |
|
|
|
19 |
|
20 |
# check if citations are in bibtex.
|
21 |
|
|
|
22 |
# replace citations
|
23 |
|
24 |
# sometimes the output may include thebibliography and bibitem . remove all of it.
|
25 |
|
26 |
+
# return all .png and replace it using placeholder.
|
27 |
+
|
28 |
|
29 |
|