Spaces:
Runtime error
Runtime error
pseudotensor
commited on
Commit
•
efe0924
1
Parent(s):
2204f8f
Add application file and dependencies
Browse files- LICENSE +201 -0
- app.py +1513 -0
- client_test.py +121 -0
- finetune.py +930 -0
- h2o-logo.svg +1 -0
- prompter.py +106 -0
- requirements.txt +44 -0
- stopping.py +139 -0
- utils.py +39 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
app.py
ADDED
@@ -0,0 +1,1513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import inspect
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
import traceback
|
6 |
+
import typing
|
7 |
+
|
8 |
+
from utils import set_seed, flatten_list, clear_torch_cache
|
9 |
+
|
10 |
+
SEED = 1236
|
11 |
+
set_seed(SEED)
|
12 |
+
|
13 |
+
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
14 |
+
from typing import Union
|
15 |
+
import numpy as np
|
16 |
+
import pandas as pd
|
17 |
+
|
18 |
+
import fire
|
19 |
+
import torch
|
20 |
+
from peft import PeftModel
|
21 |
+
from transformers import GenerationConfig, StoppingCriteriaList, AutoModel
|
22 |
+
from accelerate import init_empty_weights, infer_auto_device_map
|
23 |
+
|
24 |
+
from prompter import Prompter
|
25 |
+
|
26 |
+
from finetune import get_loaders, example_data_points, generate_prompt, get_githash, prompt_types_strings, \
|
27 |
+
human, bot, prompt_type_to_model_name, inv_prompt_type_to_model_lower
|
28 |
+
from stopping import CallbackToGenerator, Stream, StoppingCriteriaSub
|
29 |
+
|
30 |
+
|
31 |
+
def main(
|
32 |
+
load_8bit: bool = False,
|
33 |
+
load_half: bool = True,
|
34 |
+
infer_devices: bool = True,
|
35 |
+
base_model: str = '',
|
36 |
+
tokenizer_base_model: str = '',
|
37 |
+
lora_weights: str = "",
|
38 |
+
force_1_gpu: bool = True,
|
39 |
+
|
40 |
+
prompt_type: Union[int, str] = None,
|
41 |
+
# input to generation
|
42 |
+
temperature: float = None,
|
43 |
+
top_p: float = None,
|
44 |
+
top_k: int = None,
|
45 |
+
num_beams: int = None,
|
46 |
+
repetition_penalty: float = None,
|
47 |
+
num_return_sequences: int = None,
|
48 |
+
do_sample: bool = None,
|
49 |
+
max_new_tokens: int = None,
|
50 |
+
min_new_tokens: int = None,
|
51 |
+
early_stopping: Union[bool, str] = None,
|
52 |
+
max_time: float = None,
|
53 |
+
|
54 |
+
llama_type: bool = None,
|
55 |
+
debug: bool = False,
|
56 |
+
share: bool = True,
|
57 |
+
local_files_only: bool = False,
|
58 |
+
resume_download: bool = True,
|
59 |
+
use_auth_token: Union[str, bool] = False, # True requires CLI did huggingface-cli login before running
|
60 |
+
|
61 |
+
src_lang: str = "English",
|
62 |
+
tgt_lang: str = "Russian",
|
63 |
+
|
64 |
+
gradio: bool = True,
|
65 |
+
gradio_avoid_processing_markdown: bool = True,
|
66 |
+
chat: bool = True,
|
67 |
+
chat_history: int = 4096, # character length of chat context/history
|
68 |
+
stream_output: bool = True,
|
69 |
+
show_examples: bool = None,
|
70 |
+
verbose: bool = False,
|
71 |
+
h2ocolors: bool = True,
|
72 |
+
height: int = 400,
|
73 |
+
show_lora: bool = True,
|
74 |
+
# set to True to load --base_model after client logs in,
|
75 |
+
# to be able to free GPU memory when model is swapped
|
76 |
+
login_mode_if_model0: bool = False,
|
77 |
+
|
78 |
+
sanitize_user_prompt: bool = True,
|
79 |
+
sanitize_bot_response: bool = True,
|
80 |
+
|
81 |
+
extra_model_options: typing.List[str] = [],
|
82 |
+
extra_lora_options: typing.List[str] = [],
|
83 |
+
|
84 |
+
score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
|
85 |
+
auto_score: bool = True,
|
86 |
+
|
87 |
+
eval_sharegpt_prompts_only: int = 0,
|
88 |
+
eval_sharegpt_prompts_only_seed: int = 1234,
|
89 |
+
eval_sharegpt_as_output: bool = False,
|
90 |
+
):
|
91 |
+
# allow set token directly
|
92 |
+
use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
|
93 |
+
# override share if in spaces
|
94 |
+
if os.environ.get("HUGGINGFACE_SPACES"):
|
95 |
+
share = False
|
96 |
+
base_model = 'h2oai/h2ogpt-oasst1-512-12b'
|
97 |
+
load_8bit = True
|
98 |
+
|
99 |
+
# get defaults
|
100 |
+
model_lower = base_model.lower()
|
101 |
+
if not gradio:
|
102 |
+
# force, else not single response like want to look at
|
103 |
+
stream_output = False
|
104 |
+
# else prompt removal can mess up output
|
105 |
+
chat = False
|
106 |
+
|
107 |
+
placeholder_instruction, placeholder_input, \
|
108 |
+
stream_output, show_examples, \
|
109 |
+
prompt_type, temperature, top_p, top_k, num_beams, \
|
110 |
+
max_new_tokens, min_new_tokens, early_stopping, max_time, \
|
111 |
+
repetition_penalty, num_return_sequences, \
|
112 |
+
do_sample, \
|
113 |
+
src_lang, tgt_lang, \
|
114 |
+
examples, \
|
115 |
+
task_info = \
|
116 |
+
get_generate_params(model_lower, chat,
|
117 |
+
stream_output, show_examples,
|
118 |
+
prompt_type, temperature, top_p, top_k, num_beams,
|
119 |
+
max_new_tokens, min_new_tokens, early_stopping, max_time,
|
120 |
+
repetition_penalty, num_return_sequences,
|
121 |
+
do_sample,
|
122 |
+
)
|
123 |
+
|
124 |
+
if not gradio:
|
125 |
+
if eval_sharegpt_prompts_only > 0:
|
126 |
+
# override default examples with shareGPT ones for human-level eval purposes only
|
127 |
+
filename = 'ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json'
|
128 |
+
if not os.path.isfile(filename):
|
129 |
+
os.system('wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % filename)
|
130 |
+
import json
|
131 |
+
data = json.load(open(filename, 'rt'))
|
132 |
+
# focus on data that starts with human, else likely chopped from other data
|
133 |
+
turn_start = 0 # odd in general
|
134 |
+
data = [x for x in data if len(x['conversations']) > turn_start + 1 and
|
135 |
+
x['conversations'][turn_start]['from'] == 'human' and
|
136 |
+
x['conversations'][turn_start + 1]['from'] == 'gpt']
|
137 |
+
np.random.seed(eval_sharegpt_prompts_only_seed)
|
138 |
+
example1 = examples[-1] # pick reference example
|
139 |
+
examples = []
|
140 |
+
responses = []
|
141 |
+
for i in list(np.random.randint(0, len(data), size=eval_sharegpt_prompts_only)):
|
142 |
+
assert data[i]['conversations'][turn_start]['from'] == 'human'
|
143 |
+
instruction = data[i]['conversations'][turn_start]['value']
|
144 |
+
assert data[i]['conversations'][turn_start + 1]['from'] == 'gpt'
|
145 |
+
output = data[i]['conversations'][turn_start + 1]['value']
|
146 |
+
examplenew = example1.copy()
|
147 |
+
examplenew[0] = instruction
|
148 |
+
examplenew[1] = '' # no input
|
149 |
+
examplenew[2] = '' # no context
|
150 |
+
examples.append(examplenew)
|
151 |
+
responses.append(output)
|
152 |
+
|
153 |
+
with torch.device("cuda"):
|
154 |
+
# ensure was set right above before examples generated
|
155 |
+
assert not stream_output, "stream_output=True does not make sense with example loop"
|
156 |
+
import time
|
157 |
+
from functools import partial
|
158 |
+
|
159 |
+
# get score model
|
160 |
+
smodel, stokenizer, sdevice = get_score_model(**locals())
|
161 |
+
|
162 |
+
if not eval_sharegpt_as_output:
|
163 |
+
model, tokenizer, device = get_model(**locals())
|
164 |
+
model_state = [model, tokenizer, device, base_model]
|
165 |
+
fun = partial(evaluate, model_state, debug=debug, chat=chat)
|
166 |
+
else:
|
167 |
+
assert eval_sharegpt_prompts_only > 0
|
168 |
+
|
169 |
+
def get_response(*args, exi=0):
|
170 |
+
# assumes same ordering of examples and responses
|
171 |
+
yield responses[exi]
|
172 |
+
|
173 |
+
fun = get_response
|
174 |
+
t0 = time.time()
|
175 |
+
score_dump = []
|
176 |
+
num_examples = len(examples)
|
177 |
+
|
178 |
+
import matplotlib.pyplot as plt
|
179 |
+
|
180 |
+
for exi, ex in enumerate(examples):
|
181 |
+
clear_torch_cache()
|
182 |
+
print("")
|
183 |
+
print("START" + "=" * 100)
|
184 |
+
print("Question: %s %s" % (ex[0], ('input=%s' % ex[1] if ex[1] else '')))
|
185 |
+
print("-" * 105)
|
186 |
+
# fun yields as generator, so have to iterate over it
|
187 |
+
# Also means likely do NOT want --stream_output=True, else would show all generations
|
188 |
+
for res in fun(*tuple(ex), exi=exi):
|
189 |
+
print(res)
|
190 |
+
if smodel:
|
191 |
+
score_with_prompt = False
|
192 |
+
if score_with_prompt:
|
193 |
+
data_point = dict(instruction=ex[0], input=ex[1])
|
194 |
+
prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
|
195 |
+
prompt = prompter.generate_prompt(data_point)
|
196 |
+
else:
|
197 |
+
# just raw input and output
|
198 |
+
assert ex[1] in [None, ''] # should be no iinput
|
199 |
+
assert ex[2] in [None, ''] # should be no context
|
200 |
+
prompt = ex[0]
|
201 |
+
cutoff_len = 768 if os.environ.get("HUGGINGFACE_SPACES") else 2048
|
202 |
+
inputs = stokenizer(prompt, res,
|
203 |
+
return_tensors="pt",
|
204 |
+
truncation=True,
|
205 |
+
max_length=cutoff_len)
|
206 |
+
try:
|
207 |
+
score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
|
208 |
+
except torch.cuda.OutOfMemoryError as e:
|
209 |
+
print("GPU OOM: question: %s answer: %s exception: %s" % (prompt, res, str(e)), flush=True)
|
210 |
+
traceback.print_exc()
|
211 |
+
score = 0.0
|
212 |
+
clear_torch_cache()
|
213 |
+
print("SCORE %s: %s" % (exi, score), flush=True)
|
214 |
+
score_dump.append(ex + [prompt, res, score])
|
215 |
+
# dump every score in case abort
|
216 |
+
scoring_path = 'scoring'
|
217 |
+
os.makedirs(scoring_path, exist_ok=True)
|
218 |
+
if eval_sharegpt_as_output:
|
219 |
+
used_base_model = 'gpt35'
|
220 |
+
used_lora_weights = ''
|
221 |
+
else:
|
222 |
+
used_base_model = str(base_model.split('/')[-1])
|
223 |
+
used_lora_weights = str(lora_weights.split('/')[-1])
|
224 |
+
df_scores = pd.DataFrame(score_dump, columns=eval_func_param_names + ['prompt', 'response', 'score'])
|
225 |
+
filename = "df_scores_%s_%s_%s_%s_%s_%s.parquet" % (num_examples, eval_sharegpt_prompts_only,
|
226 |
+
eval_sharegpt_prompts_only_seed,
|
227 |
+
eval_sharegpt_as_output,
|
228 |
+
used_base_model,
|
229 |
+
used_lora_weights)
|
230 |
+
filename = os.path.join(scoring_path, filename)
|
231 |
+
df_scores.to_parquet(filename, index=False)
|
232 |
+
# plot histogram so far
|
233 |
+
plt.figure(figsize=(10, 10))
|
234 |
+
plt.hist(df_scores['score'], bins=20)
|
235 |
+
score_avg = np.mean(df_scores['score'])
|
236 |
+
score_median = np.median(df_scores['score'])
|
237 |
+
plt.title("Score avg: %s median: %s" % (score_avg, score_median))
|
238 |
+
plt.savefig(filename.replace('.parquet', '.png'))
|
239 |
+
plt.close()
|
240 |
+
|
241 |
+
print("END" + "=" * 102)
|
242 |
+
print("")
|
243 |
+
t2 = time.time()
|
244 |
+
print("Time taken so far: %.4f about %.4g per example" % (t2 - t0, (t2 - t0) / (1 + exi)))
|
245 |
+
t1 = time.time()
|
246 |
+
print("Total time taken: %.4f about %.4g per example" % (t1 - t0, (t1 - t0) / num_examples))
|
247 |
+
return
|
248 |
+
if gradio:
|
249 |
+
go_gradio(**locals())
|
250 |
+
|
251 |
+
|
252 |
+
def get_device():
|
253 |
+
if torch.cuda.is_available():
|
254 |
+
device = "cuda"
|
255 |
+
else:
|
256 |
+
raise RuntimeError("only cuda supported")
|
257 |
+
|
258 |
+
return device
|
259 |
+
|
260 |
+
|
261 |
+
def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type, force_1_gpu=True, use_auth_token=False):
|
262 |
+
"""
|
263 |
+
Ensure model gets on correct device
|
264 |
+
:param base_model:
|
265 |
+
:param model_loader:
|
266 |
+
:param load_half:
|
267 |
+
:param model_kwargs:
|
268 |
+
:param reward_type:
|
269 |
+
:return:
|
270 |
+
"""
|
271 |
+
with init_empty_weights():
|
272 |
+
from transformers import AutoConfig
|
273 |
+
config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token)
|
274 |
+
model = AutoModel.from_config(
|
275 |
+
config,
|
276 |
+
)
|
277 |
+
|
278 |
+
# NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
|
279 |
+
# NOTE: Some models require avoiding sharding some layers,
|
280 |
+
# then would pass no_split_module_classes and give list of those layers.
|
281 |
+
device_map = infer_auto_device_map(
|
282 |
+
model,
|
283 |
+
dtype=torch.float16 if load_half else torch.float32,
|
284 |
+
)
|
285 |
+
if hasattr(model, 'model'):
|
286 |
+
device_map_model = infer_auto_device_map(
|
287 |
+
model.model,
|
288 |
+
dtype=torch.float16 if load_half else torch.float32,
|
289 |
+
)
|
290 |
+
device_map.update(device_map_model)
|
291 |
+
print('device_map: %s' % device_map, flush=True)
|
292 |
+
|
293 |
+
if force_1_gpu:
|
294 |
+
# FIXME: If really distributes model, tend to get things like: ValueError: gpt_neox.embed_in.weight doesn't have any device set.
|
295 |
+
# So avoid for now, just put on first GPU, unless score_model, put on last
|
296 |
+
n_gpus = torch.cuda.device_count()
|
297 |
+
if reward_type:
|
298 |
+
device_map = {'': n_gpus - 1}
|
299 |
+
else:
|
300 |
+
device_map = {'': 0}
|
301 |
+
|
302 |
+
load_in_8bit = model_kwargs.get('load_in_8bit', False)
|
303 |
+
model_kwargs['device_map'] = device_map
|
304 |
+
|
305 |
+
if load_in_8bit or not load_half:
|
306 |
+
model = model_loader.from_pretrained(
|
307 |
+
base_model,
|
308 |
+
**model_kwargs,
|
309 |
+
)
|
310 |
+
else:
|
311 |
+
model = model_loader.from_pretrained(
|
312 |
+
base_model,
|
313 |
+
**model_kwargs,
|
314 |
+
).half()
|
315 |
+
return model
|
316 |
+
|
317 |
+
|
318 |
+
def get_model(
|
319 |
+
load_8bit: bool = False,
|
320 |
+
load_half: bool = True,
|
321 |
+
infer_devices: bool = True,
|
322 |
+
base_model: str = '',
|
323 |
+
tokenizer_base_model: str = '',
|
324 |
+
lora_weights: str = "",
|
325 |
+
force_1_gpu: bool = False,
|
326 |
+
|
327 |
+
llama_type: bool = None,
|
328 |
+
reward_type: bool = None,
|
329 |
+
local_files_only: bool = False,
|
330 |
+
resume_download: bool = True,
|
331 |
+
use_auth_token: Union[str, bool] = False,
|
332 |
+
compile: bool = True,
|
333 |
+
**kwargs,
|
334 |
+
):
|
335 |
+
"""
|
336 |
+
|
337 |
+
:param load_8bit: load model in 8-bit, not supported by all models
|
338 |
+
:param load_half: load model in 16-bit
|
339 |
+
:param infer_devices: Use torch infer of optimal placement of layers on devices (for non-lora case)
|
340 |
+
For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
|
341 |
+
So it is not the default
|
342 |
+
:param base_model: name/path of base model
|
343 |
+
:param tokenizer_base_model: name/path of tokenizer
|
344 |
+
:param lora_weights: name/path
|
345 |
+
:param force_1_gpu:
|
346 |
+
:param llama_type: whether LLaMa type model
|
347 |
+
:param reward_type: reward type model for sequence classification
|
348 |
+
:param local_files_only: use local files instead of from HF
|
349 |
+
:param resume_download: resume downloads from HF
|
350 |
+
:param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
|
351 |
+
:parm compile: whether to compile torch model
|
352 |
+
:param kwargs:
|
353 |
+
:return:
|
354 |
+
"""
|
355 |
+
print("Get %s model" % base_model, flush=True)
|
356 |
+
if lora_weights is not None and lora_weights.strip():
|
357 |
+
print("Get %s lora weights" % lora_weights, flush=True)
|
358 |
+
device = get_device()
|
359 |
+
|
360 |
+
if 'gpt2' in base_model.lower():
|
361 |
+
# RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half
|
362 |
+
load_8bit = False
|
363 |
+
|
364 |
+
assert base_model.strip(), (
|
365 |
+
"Please choose a base model with --base_model (CLI) or in Models Tab (gradio)"
|
366 |
+
)
|
367 |
+
llama_type = llama_type or "llama" in base_model
|
368 |
+
model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=reward_type)
|
369 |
+
if not tokenizer_base_model:
|
370 |
+
tokenizer_base_model = base_model
|
371 |
+
|
372 |
+
if tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
|
373 |
+
tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
|
374 |
+
local_files_only=local_files_only,
|
375 |
+
resume_download=resume_download,
|
376 |
+
use_auth_token=use_auth_token,
|
377 |
+
)
|
378 |
+
else:
|
379 |
+
tokenizer = tokenizer_loader
|
380 |
+
|
381 |
+
if isinstance(tokenizer, str):
|
382 |
+
# already a pipeline, tokenizer_loader is string for task
|
383 |
+
model = model_loader(tokenizer,
|
384 |
+
model=base_model,
|
385 |
+
device=0 if device == "cuda" else -1,
|
386 |
+
torch_dtype=torch.float16)
|
387 |
+
else:
|
388 |
+
assert device == "cuda", "Unsupported device %s" % device
|
389 |
+
model_kwargs = dict(local_files_only=local_files_only,
|
390 |
+
torch_dtype=torch.float16,
|
391 |
+
resume_download=resume_download,
|
392 |
+
use_auth_token=use_auth_token)
|
393 |
+
if 'mbart-' not in base_model.lower():
|
394 |
+
model_kwargs.update(dict(load_in_8bit=load_8bit,
|
395 |
+
device_map={"": 0} if load_8bit else "auto",
|
396 |
+
))
|
397 |
+
if 'OpenAssistant/reward-model'.lower() in base_model.lower():
|
398 |
+
# could put on other GPUs
|
399 |
+
model_kwargs['device_map'] = {"": 0}
|
400 |
+
model_kwargs.pop('torch_dtype', None)
|
401 |
+
|
402 |
+
if not lora_weights:
|
403 |
+
with torch.device("cuda"):
|
404 |
+
if infer_devices:
|
405 |
+
model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
|
406 |
+
force_1_gpu=force_1_gpu, use_auth_token=use_auth_token)
|
407 |
+
else:
|
408 |
+
if load_half and not load_8bit:
|
409 |
+
model = model_loader.from_pretrained(
|
410 |
+
base_model,
|
411 |
+
**model_kwargs).half()
|
412 |
+
else:
|
413 |
+
model = model_loader.from_pretrained(
|
414 |
+
base_model,
|
415 |
+
**model_kwargs)
|
416 |
+
elif load_8bit:
|
417 |
+
model = model_loader.from_pretrained(
|
418 |
+
base_model,
|
419 |
+
**model_kwargs
|
420 |
+
)
|
421 |
+
model = PeftModel.from_pretrained(
|
422 |
+
model,
|
423 |
+
lora_weights,
|
424 |
+
torch_dtype=torch.float16,
|
425 |
+
local_files_only=local_files_only,
|
426 |
+
resume_download=resume_download,
|
427 |
+
use_auth_token=use_auth_token,
|
428 |
+
device_map={"": 0}, # seems to be required
|
429 |
+
)
|
430 |
+
else:
|
431 |
+
with torch.device("cuda"):
|
432 |
+
model = model_loader.from_pretrained(
|
433 |
+
base_model,
|
434 |
+
**model_kwargs
|
435 |
+
)
|
436 |
+
model = PeftModel.from_pretrained(
|
437 |
+
model,
|
438 |
+
lora_weights,
|
439 |
+
torch_dtype=torch.float16,
|
440 |
+
local_files_only=local_files_only,
|
441 |
+
resume_download=resume_download,
|
442 |
+
use_auth_token=use_auth_token,
|
443 |
+
device_map="auto",
|
444 |
+
)
|
445 |
+
if load_half:
|
446 |
+
model.half()
|
447 |
+
|
448 |
+
# unwind broken decapoda-research config
|
449 |
+
if llama_type:
|
450 |
+
model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
|
451 |
+
model.config.bos_token_id = 1
|
452 |
+
model.config.eos_token_id = 2
|
453 |
+
if 'gpt2' in base_model.lower():
|
454 |
+
# add special tokens that otherwise all share the same id
|
455 |
+
tokenizer.add_special_tokens({'bos_token': '<bos>',
|
456 |
+
'eos_token': '<eos>',
|
457 |
+
'pad_token': '<pad>'})
|
458 |
+
|
459 |
+
if not isinstance(tokenizer, str):
|
460 |
+
model.eval()
|
461 |
+
if torch.__version__ >= "2" and sys.platform != "win32" and compile:
|
462 |
+
model = torch.compile(model)
|
463 |
+
|
464 |
+
return model, tokenizer, device
|
465 |
+
|
466 |
+
|
467 |
+
def get_score_model(**kwargs):
|
468 |
+
# score model
|
469 |
+
if kwargs.get('score_model') is not None and kwargs.get('score_model').strip():
|
470 |
+
score_all_kwargs = kwargs.copy()
|
471 |
+
score_all_kwargs['load_8bit'] = False
|
472 |
+
score_all_kwargs['load_half'] = False
|
473 |
+
score_all_kwargs['base_model'] = kwargs.get('score_model').strip()
|
474 |
+
score_all_kwargs['tokenizer_base_model'] = ''
|
475 |
+
score_all_kwargs['lora_weights'] = ''
|
476 |
+
score_all_kwargs['llama_type'] = False
|
477 |
+
score_all_kwargs['compile'] = False
|
478 |
+
smodel, stokenizer, sdevice = get_model(**score_all_kwargs)
|
479 |
+
else:
|
480 |
+
smodel, stokenizer, sdevice = None, None, None
|
481 |
+
return smodel, stokenizer, sdevice
|
482 |
+
|
483 |
+
|
484 |
+
def go_gradio(**kwargs):
|
485 |
+
|
486 |
+
# get default model
|
487 |
+
all_kwargs = kwargs.copy()
|
488 |
+
all_kwargs.update(locals())
|
489 |
+
if kwargs.get('base_model') and not kwargs['login_mode_if_model0']:
|
490 |
+
model0, tokenizer0, device = get_model(**all_kwargs)
|
491 |
+
else:
|
492 |
+
# if empty model, then don't load anything, just get gradio up
|
493 |
+
model0, tokenizer0, device = None, None, None
|
494 |
+
model_state0 = [model0, tokenizer0, device, kwargs['base_model']]
|
495 |
+
|
496 |
+
# get score model
|
497 |
+
smodel, stokenizer, sdevice = get_score_model(**all_kwargs)
|
498 |
+
|
499 |
+
if 'mbart-' in kwargs['model_lower']:
|
500 |
+
instruction_label = "Text to translate"
|
501 |
+
else:
|
502 |
+
instruction_label = "Instruction"
|
503 |
+
if kwargs['chat']:
|
504 |
+
instruction_label = "You (Shift-Enter or push Submit to send message)"
|
505 |
+
|
506 |
+
title = 'h2oGPT'
|
507 |
+
if kwargs['verbose']:
|
508 |
+
description = f"""Model {kwargs['base_model']} Instruct dataset.
|
509 |
+
For more information, visit [the project's website](https://github.com/h2oai/h2ogpt).
|
510 |
+
Command: {str(' '.join(sys.argv))}
|
511 |
+
Hash: {get_githash()}
|
512 |
+
"""
|
513 |
+
else:
|
514 |
+
description = "For more information, visit [the project's website](https://github.com/h2oai/h2ogpt).<br>"
|
515 |
+
if os.environ.get("HUGGINGFACE_SPACES"):
|
516 |
+
description += """<p><b> DISCLAIMERS: </b><ul><i><li>The data used to train this model include The Pile and other sources. These may contain objectionable content, so the model may reproduce that material. Use application and responses at own risk.</i></li>"""
|
517 |
+
if kwargs['load_8bit']:
|
518 |
+
description += """<i><li> Model is loaded in 8-bit and 768 token context length to fit on HF GPUs, so model may perform worse than 16-bit with 2048 token limit.</i></li>"""
|
519 |
+
description += """<i><li>Model loading and unloading disabled on HF SPACES to avoid GPU OOM for multi-user environment.</i></li></ul></p>"""
|
520 |
+
|
521 |
+
if kwargs['verbose']:
|
522 |
+
task_info_md = f"""
|
523 |
+
### Task: {kwargs['task_info']}"""
|
524 |
+
else:
|
525 |
+
task_info_md = ''
|
526 |
+
|
527 |
+
css_code = """footer {visibility: hidden}
|
528 |
+
body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en/site/header/master/_jcr_content/root/container/header_copy/logo.coreimg.svg/1678976605175/h2o-logo.svg");}}"""
|
529 |
+
|
530 |
+
from gradio.themes.utils import colors, fonts, sizes
|
531 |
+
if kwargs['h2ocolors']:
|
532 |
+
colors_dict = dict(primary_hue=colors.yellow,
|
533 |
+
secondary_hue=colors.yellow,
|
534 |
+
neutral_hue=colors.gray,
|
535 |
+
spacing_size=sizes.spacing_md,
|
536 |
+
radius_size=sizes.radius_md,
|
537 |
+
text_size=sizes.text_md,
|
538 |
+
)
|
539 |
+
else:
|
540 |
+
colors_dict = dict(primary_hue=colors.indigo,
|
541 |
+
secondary_hue=colors.indigo,
|
542 |
+
neutral_hue=colors.gray,
|
543 |
+
spacing_size=sizes.spacing_md,
|
544 |
+
radius_size=sizes.radius_md,
|
545 |
+
text_size=sizes.text_md,
|
546 |
+
)
|
547 |
+
|
548 |
+
import gradio as gr
|
549 |
+
|
550 |
+
if kwargs['gradio_avoid_processing_markdown']:
|
551 |
+
from gradio_client import utils as client_utils
|
552 |
+
from gradio.components import Chatbot
|
553 |
+
|
554 |
+
# gradio has issue with taking too long to process input/output for markdown etc.
|
555 |
+
# Avoid for now, allow raw html to render, good enough for chatbot.
|
556 |
+
def _postprocess_chat_messages(self, chat_message: str):
|
557 |
+
if chat_message is None:
|
558 |
+
return None
|
559 |
+
elif isinstance(chat_message, (tuple, list)):
|
560 |
+
filepath = chat_message[0]
|
561 |
+
mime_type = client_utils.get_mimetype(filepath)
|
562 |
+
filepath = self.make_temp_copy_if_needed(filepath)
|
563 |
+
return {
|
564 |
+
"name": filepath,
|
565 |
+
"mime_type": mime_type,
|
566 |
+
"alt_text": chat_message[1] if len(chat_message) > 1 else None,
|
567 |
+
"data": None, # These last two fields are filled in by the frontend
|
568 |
+
"is_file": True,
|
569 |
+
}
|
570 |
+
elif isinstance(chat_message, str):
|
571 |
+
return chat_message
|
572 |
+
else:
|
573 |
+
raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
|
574 |
+
Chatbot._postprocess_chat_messages = _postprocess_chat_messages
|
575 |
+
|
576 |
+
demo = gr.Blocks(theme=gr.themes.Soft(**colors_dict), css=css_code, title="h2oGPT", analytics_enabled=False)
|
577 |
+
callback = gr.CSVLogger()
|
578 |
+
# css_code = 'body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en/site/header/master/_jcr_content/root/container/header_copy/logo.coreimg.svg/1678976605175/h2o-logo.svg");}'
|
579 |
+
# demo = gr.Blocks(theme='gstaff/xkcd', css=css_code)
|
580 |
+
|
581 |
+
model_options = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
|
582 |
+
if kwargs['base_model'].strip() not in model_options:
|
583 |
+
lora_options = [kwargs['base_model'].strip()] + model_options
|
584 |
+
lora_options = kwargs['extra_lora_options']
|
585 |
+
if kwargs['lora_weights'].strip() not in lora_options:
|
586 |
+
lora_options = [kwargs['lora_weights'].strip()] + lora_options
|
587 |
+
# always add in no lora case
|
588 |
+
# add fake space so doesn't go away in gradio dropdown
|
589 |
+
lora_options = [' '] + kwargs['extra_lora_options']
|
590 |
+
|
591 |
+
output_label0 = f'h2oGPT [Model: {kwargs.get("base_model")}]' if kwargs.get('base_model') else 'h2oGPT [ !!! Please Load Model in Models Tab !!! ]'
|
592 |
+
|
593 |
+
with demo:
|
594 |
+
# avoid actual model/tokenizer here or anything that would be bad to deepcopy
|
595 |
+
# https://github.com/gradio-app/gradio/issues/3558
|
596 |
+
model_state = gr.State(['model', 'tokenizer', device, kwargs['base_model']])
|
597 |
+
model_options_state = gr.State([model_options])
|
598 |
+
lora_options_state = gr.State([lora_options])
|
599 |
+
gr.Markdown(
|
600 |
+
f"""
|
601 |
+
<h1 align="center"> {title}</h1>
|
602 |
+
|
603 |
+
{description}
|
604 |
+
{task_info_md}
|
605 |
+
""")
|
606 |
+
|
607 |
+
# go button visible if
|
608 |
+
base_wanted = bool(kwargs['base_model']) and kwargs['login_mode_if_model0']
|
609 |
+
go_btn = gr.Button(value="LOGIN", visible=base_wanted, variant="primary")
|
610 |
+
normal_block = gr.Row(visible=not base_wanted)
|
611 |
+
with normal_block:
|
612 |
+
with gr.Tabs():
|
613 |
+
with gr.Row():
|
614 |
+
if not kwargs['chat']:
|
615 |
+
with gr.Column():
|
616 |
+
instruction = gr.Textbox(
|
617 |
+
lines=4, label=instruction_label,
|
618 |
+
placeholder=kwargs['placeholder_instruction'],
|
619 |
+
)
|
620 |
+
iinput = gr.Textbox(lines=4, label="Input",
|
621 |
+
placeholder=kwargs['placeholder_input'])
|
622 |
+
flag_btn = gr.Button("Flag")
|
623 |
+
if kwargs['score_model']:
|
624 |
+
if not kwargs['auto_score']:
|
625 |
+
with gr.Column():
|
626 |
+
score_btn = gr.Button("Score last prompt & response")
|
627 |
+
score_text = gr.Textbox("Response Score: NA", show_label=False)
|
628 |
+
else:
|
629 |
+
score_text = gr.Textbox("Response Score: NA", show_label=False)
|
630 |
+
with gr.Column():
|
631 |
+
if kwargs['chat']:
|
632 |
+
text_output = gr.Chatbot(label=output_label0).style(height=kwargs['height'] or 400)
|
633 |
+
with gr.Row():
|
634 |
+
with gr.Column(scale=50):
|
635 |
+
instruction = gr.Textbox(
|
636 |
+
lines=4, label=instruction_label,
|
637 |
+
placeholder=kwargs['placeholder_instruction'],
|
638 |
+
)
|
639 |
+
with gr.Row(): # .style(equal_height=False, equal_width=False):
|
640 |
+
submit = gr.Button(value='Submit').style(full_width=False, size='sm')
|
641 |
+
stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
|
642 |
+
with gr.Row():
|
643 |
+
clear = gr.Button("New Conversation")
|
644 |
+
flag_btn = gr.Button("Flag")
|
645 |
+
if kwargs['score_model']:
|
646 |
+
if not kwargs['auto_score']:
|
647 |
+
with gr.Column():
|
648 |
+
score_btn = gr.Button("Score last prompt & response").style(full_width=False, size='sm')
|
649 |
+
score_text = gr.Textbox("Response Score: NA", show_label=False)
|
650 |
+
else:
|
651 |
+
score_text = gr.Textbox("Response Score: NA", show_label=False)
|
652 |
+
retry = gr.Button("Regenerate")
|
653 |
+
undo = gr.Button("Undo")
|
654 |
+
else:
|
655 |
+
text_output = gr.Textbox(lines=5, label=output_label0)
|
656 |
+
with gr.TabItem("Input/Output"):
|
657 |
+
with gr.Row():
|
658 |
+
if 'mbart-' in kwargs['model_lower']:
|
659 |
+
src_lang = gr.Dropdown(list(languages_covered().keys()),
|
660 |
+
value=kwargs['src_lang'],
|
661 |
+
label="Input Language")
|
662 |
+
tgt_lang = gr.Dropdown(list(languages_covered().keys()),
|
663 |
+
value=kwargs['tgt_lang'],
|
664 |
+
label="Output Language")
|
665 |
+
with gr.TabItem("Expert"):
|
666 |
+
with gr.Row():
|
667 |
+
with gr.Column():
|
668 |
+
stream_output = gr.components.Checkbox(label="Stream output",
|
669 |
+
value=kwargs['stream_output'])
|
670 |
+
prompt_type = gr.Dropdown(prompt_types_strings,
|
671 |
+
value=kwargs['prompt_type'], label="Prompt Type")
|
672 |
+
temperature = gr.Slider(minimum=0, maximum=3,
|
673 |
+
value=kwargs['temperature'],
|
674 |
+
label="Temperature",
|
675 |
+
info="Lower is deterministic (but may lead to repeats), Higher more creative (but may lead to hallucinations)")
|
676 |
+
top_p = gr.Slider(minimum=0, maximum=1,
|
677 |
+
value=kwargs['top_p'], label="Top p",
|
678 |
+
info="Cumulative probability of tokens to sample from")
|
679 |
+
top_k = gr.Slider(
|
680 |
+
minimum=0, maximum=100, step=1,
|
681 |
+
value=kwargs['top_k'], label="Top k",
|
682 |
+
info='Num. tokens to sample from'
|
683 |
+
)
|
684 |
+
num_beams = gr.Slider(minimum=1, maximum=8, step=1,
|
685 |
+
value=kwargs['num_beams'], label="Beams",
|
686 |
+
info="Number of searches for optimal overall probability. Uses more GPU memory/compute")
|
687 |
+
max_new_tokens = gr.Slider(
|
688 |
+
minimum=1, maximum=2048, step=1,
|
689 |
+
value=kwargs['max_new_tokens'], label="Max output length"
|
690 |
+
)
|
691 |
+
min_new_tokens = gr.Slider(
|
692 |
+
minimum=0, maximum=2048, step=1,
|
693 |
+
value=kwargs['min_new_tokens'], label="Min output length"
|
694 |
+
)
|
695 |
+
early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
|
696 |
+
value=kwargs['early_stopping'])
|
697 |
+
max_time = gr.Slider(minimum=0, maximum=60 * 5, step=1,
|
698 |
+
value=kwargs['max_time'], label="Max. time",
|
699 |
+
info="Max. time to search optimal output.")
|
700 |
+
repetition_penalty = gr.Slider(minimum=0.01, maximum=3.0,
|
701 |
+
value=kwargs['repetition_penalty'],
|
702 |
+
label="Repetition Penalty")
|
703 |
+
num_return_sequences = gr.Slider(minimum=1, maximum=10, step=1,
|
704 |
+
value=kwargs['num_return_sequences'],
|
705 |
+
label="Number Returns", info="Must be <= num_beams")
|
706 |
+
do_sample = gr.Checkbox(label="Sample", info="Sample, for diverse output(s)",
|
707 |
+
value=kwargs['do_sample'])
|
708 |
+
if kwargs['chat']:
|
709 |
+
iinput = gr.Textbox(lines=4, label="Input",
|
710 |
+
placeholder=kwargs['placeholder_input'])
|
711 |
+
context = gr.Textbox(lines=1, label="Context",
|
712 |
+
info="Ignored in chat mode.") # nominally empty for chat mode
|
713 |
+
|
714 |
+
with gr.TabItem("Models"):
|
715 |
+
with gr.Row():
|
716 |
+
with gr.Column():
|
717 |
+
with gr.Row(scale=1):
|
718 |
+
with gr.Column(scale=50):
|
719 |
+
model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Model", value=kwargs['base_model'])
|
720 |
+
lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA", value=kwargs['lora_weights'], visible=kwargs['show_lora'])
|
721 |
+
with gr.Column(scale=1):
|
722 |
+
load_msg = "Load Model/LORA" if not os.environ.get("HUGGINGFACE_SPACES") \
|
723 |
+
else "LOAD DISABLED ON HF SPACES"
|
724 |
+
load_model_button = gr.Button(load_msg)
|
725 |
+
model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'])
|
726 |
+
lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'], visible=kwargs['show_lora'])
|
727 |
+
with gr.Row(scale=1):
|
728 |
+
with gr.Column(scale=50):
|
729 |
+
new_model = gr.Textbox(label="New Model HF name/path")
|
730 |
+
new_lora = gr.Textbox(label="New LORA HF name/path", visible=kwargs['show_lora'])
|
731 |
+
with gr.Column(scale=1):
|
732 |
+
add_model_button = gr.Button("Add new model name")
|
733 |
+
add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora'])
|
734 |
+
|
735 |
+
inputs_list = get_inputs_list(locals(), kwargs['model_lower'])
|
736 |
+
from functools import partial
|
737 |
+
all_kwargs = kwargs.copy()
|
738 |
+
all_kwargs.update(locals())
|
739 |
+
kwargs_evaluate = {k: v for k, v in all_kwargs.items() if k in inputs_kwargs_list}
|
740 |
+
fun = partial(evaluate,
|
741 |
+
**kwargs_evaluate)
|
742 |
+
|
743 |
+
dark_mode_btn = gr.Button("Dark Mode", variant="primary").style(
|
744 |
+
size="sm",
|
745 |
+
)
|
746 |
+
dark_mode_btn.click(
|
747 |
+
None,
|
748 |
+
None,
|
749 |
+
None,
|
750 |
+
_js="""() => {
|
751 |
+
if (document.querySelectorAll('.dark').length) {
|
752 |
+
document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
|
753 |
+
} else {
|
754 |
+
document.querySelector('body').classList.add('dark');
|
755 |
+
}
|
756 |
+
}""",
|
757 |
+
api_name="dark",
|
758 |
+
)
|
759 |
+
if not kwargs['chat']:
|
760 |
+
submit = gr.Button("Submit")
|
761 |
+
submit_event = submit.click(fun, inputs=[model_state] + inputs_list, outputs=text_output, api_name='submit')
|
762 |
+
|
763 |
+
# examples after submit or any other buttons for chat or no chat
|
764 |
+
if kwargs['examples'] is not None and kwargs['show_examples']:
|
765 |
+
gr.Examples(examples=kwargs['examples'], inputs=inputs_list)
|
766 |
+
|
767 |
+
# Score
|
768 |
+
def score_last_response(*args):
|
769 |
+
""" Similar to user() """
|
770 |
+
args_list = list(args)
|
771 |
+
history = args_list[-1]
|
772 |
+
if history is None:
|
773 |
+
print("Bad history in scoring last response, fix for now", flush=True)
|
774 |
+
history = []
|
775 |
+
if smodel is not None and \
|
776 |
+
stokenizer is not None and \
|
777 |
+
sdevice is not None and \
|
778 |
+
history is not None and len(history) > 0 and \
|
779 |
+
history[-1] is not None and \
|
780 |
+
len(history[-1]) >= 2:
|
781 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
782 |
+
|
783 |
+
max_length_tokenize = 512 if os.environ.get("HUGGINGFACE_SPACES") else 2048
|
784 |
+
cutoff_len = max_length_tokenize*4 # restrict deberta related to max for LLM
|
785 |
+
|
786 |
+
question = history[-1][0]
|
787 |
+
question = question[-cutoff_len:]
|
788 |
+
|
789 |
+
answer = history[-1][1]
|
790 |
+
answer = answer[-cutoff_len:]
|
791 |
+
|
792 |
+
inputs = stokenizer(question, answer,
|
793 |
+
return_tensors="pt",
|
794 |
+
truncation=True,
|
795 |
+
max_length=max_length_tokenize).to(smodel.device)
|
796 |
+
try:
|
797 |
+
score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
|
798 |
+
except torch.cuda.OutOfMemoryError as e:
|
799 |
+
print("GPU OOM: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
|
800 |
+
del inputs
|
801 |
+
traceback.print_exc()
|
802 |
+
clear_torch_cache()
|
803 |
+
return 'Response Score: GPU OOM'
|
804 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
|
805 |
+
return 'Response Score: {:.1%}'.format(score)
|
806 |
+
else:
|
807 |
+
return 'Response Score: NA'
|
808 |
+
|
809 |
+
if kwargs['score_model']:
|
810 |
+
score_args = dict(fn=score_last_response,
|
811 |
+
inputs=inputs_list + [text_output],
|
812 |
+
outputs=[score_text],
|
813 |
+
)
|
814 |
+
if not kwargs['auto_score']:
|
815 |
+
score_event = score_btn.click(**score_args, queue=stream_output, api_name='score')
|
816 |
+
|
817 |
+
if kwargs['chat']:
|
818 |
+
def user(*args, undo=False, sanitize_user_prompt=True):
|
819 |
+
args_list = list(args)
|
820 |
+
user_message = args_list[0]
|
821 |
+
input1 = args_list[1]
|
822 |
+
context1 = args_list[2]
|
823 |
+
if input1 and not user_message.endswith(':'):
|
824 |
+
user_message1 = user_message + ":" + input1
|
825 |
+
elif input1:
|
826 |
+
user_message1 = user_message + input1
|
827 |
+
else:
|
828 |
+
user_message1 = user_message
|
829 |
+
if sanitize_user_prompt:
|
830 |
+
from better_profanity import profanity
|
831 |
+
user_message1 = profanity.censor(user_message1)
|
832 |
+
|
833 |
+
history = args_list[-1]
|
834 |
+
if undo and history:
|
835 |
+
history.pop()
|
836 |
+
args_list = args_list[:-1]
|
837 |
+
if history is None:
|
838 |
+
print("Bad history, fix for now", flush=True)
|
839 |
+
history = []
|
840 |
+
if undo:
|
841 |
+
return "", history
|
842 |
+
else:
|
843 |
+
return "", history + [[user_message1, None]]
|
844 |
+
|
845 |
+
def bot(*args, retry=False):
|
846 |
+
args_list = list(args)
|
847 |
+
history = args_list[-1]
|
848 |
+
if retry and history:
|
849 |
+
history.pop()
|
850 |
+
if not history:
|
851 |
+
print("No history", flush=True)
|
852 |
+
return
|
853 |
+
instruction1 = history[-1][0]
|
854 |
+
context1 = ''
|
855 |
+
if kwargs['chat_history'] > 0:
|
856 |
+
prompt_type1 = args_list[prompt_type_arg_id]
|
857 |
+
context1 = ''
|
858 |
+
for histi in range(len(history) - 1):
|
859 |
+
data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
|
860 |
+
context1 += generate_prompt(data_point, prompt_type1, kwargs['chat'], reduced=True)[0].replace(
|
861 |
+
'<br>', '\n')
|
862 |
+
if not context1.endswith('\n'):
|
863 |
+
context1 += '\n'
|
864 |
+
if context1 and not context1.endswith('\n'):
|
865 |
+
context1 += '\n' # ensure if terminates abruptly, then human continues on next line
|
866 |
+
args_list[0] = instruction1
|
867 |
+
# only include desired chat history
|
868 |
+
args_list[2] = context1[-kwargs['chat_history']:]
|
869 |
+
model_state1 = args_list[-2]
|
870 |
+
args_list = args_list[:-2]
|
871 |
+
fun1 = partial(evaluate,
|
872 |
+
model_state1,
|
873 |
+
**kwargs_evaluate)
|
874 |
+
try:
|
875 |
+
for output in fun1(*tuple(args_list)):
|
876 |
+
bot_message = output
|
877 |
+
history[-1][1] = bot_message
|
878 |
+
yield history
|
879 |
+
except StopIteration:
|
880 |
+
yield history
|
881 |
+
except RuntimeError as e:
|
882 |
+
if "generator raised StopIteration" in str(e):
|
883 |
+
# assume last entry was bad, undo
|
884 |
+
history.pop()
|
885 |
+
yield history
|
886 |
+
raise
|
887 |
+
except Exception as e:
|
888 |
+
# put error into user input
|
889 |
+
history[-1][0] = "Exception: %s" % str(e)
|
890 |
+
yield history
|
891 |
+
raise
|
892 |
+
return
|
893 |
+
|
894 |
+
user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
|
895 |
+
inputs=inputs_list + [text_output],
|
896 |
+
outputs=[instruction, text_output],
|
897 |
+
)
|
898 |
+
bot_args = dict(fn=bot,
|
899 |
+
inputs=inputs_list + [model_state] + [text_output],
|
900 |
+
outputs=[text_output],
|
901 |
+
)
|
902 |
+
retry_bot_args = dict(fn=functools.partial(bot, retry=True),
|
903 |
+
inputs=inputs_list + [model_state] + [text_output],
|
904 |
+
outputs=[text_output],
|
905 |
+
)
|
906 |
+
undo_user_args = dict(fn=functools.partial(user, undo=True),
|
907 |
+
inputs=inputs_list + [text_output],
|
908 |
+
outputs=[instruction, text_output],
|
909 |
+
)
|
910 |
+
|
911 |
+
if kwargs['auto_score']:
|
912 |
+
submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction').then(
|
913 |
+
**bot_args, api_name='instruction_bot',
|
914 |
+
).then(**score_args, api_name='instruction_bot_score').then(clear_torch_cache)
|
915 |
+
submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit').then(
|
916 |
+
**bot_args, api_name='submit_bot',
|
917 |
+
).then(**score_args, api_name='submit_bot_score').then(clear_torch_cache)
|
918 |
+
submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry').then(
|
919 |
+
**retry_bot_args, api_name='retry_bot',
|
920 |
+
).then(**score_args, api_name='retry_bot_score').then(clear_torch_cache)
|
921 |
+
submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo').then(**score_args, api_name='undo_score')
|
922 |
+
else:
|
923 |
+
submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction').then(
|
924 |
+
**bot_args, api_name='instruction_bot',
|
925 |
+
).then(clear_torch_cache)
|
926 |
+
submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit').then(
|
927 |
+
**bot_args, api_name='submit_bot',
|
928 |
+
).then(clear_torch_cache)
|
929 |
+
submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry').then(
|
930 |
+
**retry_bot_args, api_name='retry_bot',
|
931 |
+
).then(clear_torch_cache)
|
932 |
+
submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo')
|
933 |
+
clear.click(lambda: None, None, text_output, queue=False, api_name='clear')
|
934 |
+
|
935 |
+
def load_model(model_name, lora_weights, model_state_old, prompt_type_old):
|
936 |
+
# ensure old model removed from GPU memory
|
937 |
+
if kwargs['debug']:
|
938 |
+
print("Pre-switch pre-del GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
|
939 |
+
|
940 |
+
if isinstance(model_state_old[0], str) and model0 is not None:
|
941 |
+
# best can do, move model loaded at first to CPU
|
942 |
+
model0.cpu()
|
943 |
+
|
944 |
+
if model_state_old[0] is not None and not isinstance(model_state_old[0], str):
|
945 |
+
try:
|
946 |
+
model_state_old[0].cpu()
|
947 |
+
except Exception as e:
|
948 |
+
# sometimes hit NotImplementedError: Cannot copy out of meta tensor; no data!
|
949 |
+
print("Unable to put model on CPU: %s" % str(e), flush=True)
|
950 |
+
del model_state_old[0]
|
951 |
+
model_state_old[0] = None
|
952 |
+
|
953 |
+
if model_state_old[1] is not None and not isinstance(model_state_old[1], str):
|
954 |
+
del model_state_old[1]
|
955 |
+
model_state_old[1] = None
|
956 |
+
|
957 |
+
clear_torch_cache()
|
958 |
+
if kwargs['debug']:
|
959 |
+
print("Pre-switch post-del GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
|
960 |
+
all_kwargs['base_model'] = model_name.strip()
|
961 |
+
model_lower = model_name.strip().lower()
|
962 |
+
if model_lower in inv_prompt_type_to_model_lower:
|
963 |
+
prompt_type1 = inv_prompt_type_to_model_lower[model_lower]
|
964 |
+
else:
|
965 |
+
prompt_type1 = prompt_type_old
|
966 |
+
|
967 |
+
all_kwargs['lora_weights'] = lora_weights.strip()
|
968 |
+
model1, tokenizer1, device1 = get_model(**all_kwargs)
|
969 |
+
clear_torch_cache()
|
970 |
+
|
971 |
+
if kwargs['debug']:
|
972 |
+
print("Post-switch GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
|
973 |
+
return {model_state: [model1, tokenizer1, device1, model_name],
|
974 |
+
model_used: model_name,
|
975 |
+
lora_used: lora_weights,
|
976 |
+
prompt_type: prompt_type1}
|
977 |
+
|
978 |
+
def dropdown_prompt_type_list(x):
|
979 |
+
return gr.Dropdown.update(value=x)
|
980 |
+
|
981 |
+
def chatbot_list(x, model_used_in):
|
982 |
+
return gr.Textbox.update(label=f'h2oGPT [Model: {model_used_in}]')
|
983 |
+
|
984 |
+
load_model_args = dict(fn=load_model,
|
985 |
+
inputs=[model_choice, lora_choice, model_state, prompt_type],
|
986 |
+
outputs=[model_state, model_used, lora_used, prompt_type])
|
987 |
+
prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
|
988 |
+
chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
|
989 |
+
if not os.environ.get("HUGGINGFACE_SPACES"):
|
990 |
+
load_model_event = load_model_button.click(**load_model_args).then(**prompt_update_args).then(**chatbot_update_args).then(clear_torch_cache)
|
991 |
+
|
992 |
+
def dropdown_model_list(list0, x):
|
993 |
+
new_state = [list0[0] + [x]]
|
994 |
+
new_options = [*new_state[0]]
|
995 |
+
return gr.Dropdown.update(value=x, choices=new_options), '', new_state
|
996 |
+
|
997 |
+
add_model_event = add_model_button.click(fn=dropdown_model_list,
|
998 |
+
inputs=[model_options_state, new_model],
|
999 |
+
outputs=[model_choice, new_model, model_options_state])
|
1000 |
+
|
1001 |
+
def dropdown_lora_list(list0, x):
|
1002 |
+
new_state = [list0[0] + [x]]
|
1003 |
+
new_options = [*new_state[0]]
|
1004 |
+
return gr.Dropdown.update(value=x, choices=new_options), '', new_state
|
1005 |
+
|
1006 |
+
add_lora_event = add_lora_button.click(fn=dropdown_lora_list,
|
1007 |
+
inputs=[lora_options_state, new_lora],
|
1008 |
+
outputs=[lora_choice, new_lora, lora_options_state])
|
1009 |
+
|
1010 |
+
go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go") \
|
1011 |
+
.then(lambda: gr.update(visible=True), None, normal_block) \
|
1012 |
+
.then(**load_model_args).then(**prompt_update_args)
|
1013 |
+
|
1014 |
+
# callback for logging flagged input/output
|
1015 |
+
callback.setup(inputs_list + [text_output], "flagged_data_points")
|
1016 |
+
flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output], None, preprocess=False,
|
1017 |
+
api_name='flag')
|
1018 |
+
if kwargs['chat']:
|
1019 |
+
|
1020 |
+
# don't pass text_output, don't want to clear output, just stop it
|
1021 |
+
# FIXME: have to click once to stop output and second time to stop GPUs going
|
1022 |
+
stop_btn.click(lambda: None, None, None, cancels=[submit_event, submit_event2, submit_event3],
|
1023 |
+
queue=False, api_name='stop').then(clear_torch_cache)
|
1024 |
+
|
1025 |
+
demo.queue(concurrency_count=1)
|
1026 |
+
favicon_path = "h2o-logo.svg"
|
1027 |
+
demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
|
1028 |
+
favicon_path=favicon_path, prevent_thread_lock=True) # , enable_queue=True)
|
1029 |
+
print("Started GUI", flush=True)
|
1030 |
+
demo.block_thread()
|
1031 |
+
|
1032 |
+
|
1033 |
+
input_args_list = ['model_state']
|
1034 |
+
inputs_kwargs_list = ['debug', 'chat', 'hard_stop_list', 'sanitize_bot_response', 'model_state0']
|
1035 |
+
|
1036 |
+
|
1037 |
+
def get_inputs_list(inputs_dict, model_lower):
|
1038 |
+
inputs_list_names = list(inspect.signature(evaluate).parameters)
|
1039 |
+
inputs_list = []
|
1040 |
+
for k in inputs_list_names:
|
1041 |
+
if k == 'kwargs':
|
1042 |
+
continue
|
1043 |
+
if k in input_args_list + inputs_kwargs_list:
|
1044 |
+
# these are added via partial, not taken as input
|
1045 |
+
continue
|
1046 |
+
if 'mbart-' not in model_lower and k in ['src_lang', 'tgt_lang']:
|
1047 |
+
continue
|
1048 |
+
inputs_list.append(inputs_dict[k])
|
1049 |
+
return inputs_list
|
1050 |
+
|
1051 |
+
|
1052 |
+
# index of prompt_type in evaluate function, after model_state
|
1053 |
+
prompt_type_arg_id = 4
|
1054 |
+
|
1055 |
+
eval_func_param_names = ['instruction',
|
1056 |
+
'iinput',
|
1057 |
+
'context',
|
1058 |
+
'stream_output',
|
1059 |
+
'prompt_type',
|
1060 |
+
'temperature',
|
1061 |
+
'top_p',
|
1062 |
+
'top_k',
|
1063 |
+
'num_beams',
|
1064 |
+
'max_new_tokens',
|
1065 |
+
'min_new_tokens',
|
1066 |
+
'early_stopping',
|
1067 |
+
'max_time',
|
1068 |
+
'repetition_penalty',
|
1069 |
+
'num_return_sequences',
|
1070 |
+
'do_sample',
|
1071 |
+
]
|
1072 |
+
|
1073 |
+
|
1074 |
+
def evaluate(
|
1075 |
+
model_state,
|
1076 |
+
# START NOTE: Examples must have same order of parameters
|
1077 |
+
instruction,
|
1078 |
+
iinput,
|
1079 |
+
context,
|
1080 |
+
stream_output,
|
1081 |
+
prompt_type,
|
1082 |
+
temperature,
|
1083 |
+
top_p,
|
1084 |
+
top_k,
|
1085 |
+
num_beams,
|
1086 |
+
max_new_tokens,
|
1087 |
+
min_new_tokens,
|
1088 |
+
early_stopping,
|
1089 |
+
max_time,
|
1090 |
+
repetition_penalty,
|
1091 |
+
num_return_sequences,
|
1092 |
+
do_sample,
|
1093 |
+
# END NOTE: Examples must have same order of parameters
|
1094 |
+
src_lang=None,
|
1095 |
+
tgt_lang=None,
|
1096 |
+
debug=False,
|
1097 |
+
chat=False,
|
1098 |
+
hard_stop_list=None,
|
1099 |
+
sanitize_bot_response=True,
|
1100 |
+
model_state0=None,
|
1101 |
+
**kwargs,
|
1102 |
+
):
|
1103 |
+
if debug:
|
1104 |
+
locals_dict = locals().copy()
|
1105 |
+
locals_dict.pop('model_state', None)
|
1106 |
+
print(locals_dict)
|
1107 |
+
|
1108 |
+
no_model_msg = "Please choose a base model with --base_model (CLI) or in Models Tab (gradio).\nThen start New Conversation"
|
1109 |
+
|
1110 |
+
if model_state is not None and len(model_state) == 4 and not isinstance(model_state[0], str):
|
1111 |
+
# try to free-up original model (i.e. list was passed as reference)
|
1112 |
+
if model_state0 is not None and model_state0[0] is not None:
|
1113 |
+
model_state0[0].cpu()
|
1114 |
+
model_state0[0] = None
|
1115 |
+
# try to free-up original tokenizer (i.e. list was passed as reference)
|
1116 |
+
if model_state0 is not None and model_state0[1] is not None:
|
1117 |
+
model_state0[1] = None
|
1118 |
+
clear_torch_cache()
|
1119 |
+
model, tokenizer, device, base_model = model_state
|
1120 |
+
elif model_state0 is not None and len(model_state0) == 4 and model_state0[0] is not None:
|
1121 |
+
assert isinstance(model_state[0], str)
|
1122 |
+
model, tokenizer, device, base_model = model_state0
|
1123 |
+
else:
|
1124 |
+
raise AssertionError(no_model_msg)
|
1125 |
+
|
1126 |
+
assert base_model.strip(), no_model_msg
|
1127 |
+
assert model, "Model is missing"
|
1128 |
+
assert tokenizer, "Tokenizer is missing"
|
1129 |
+
|
1130 |
+
data_point = dict(context=context, instruction=instruction, input=iinput)
|
1131 |
+
prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
|
1132 |
+
prompt = prompter.generate_prompt(data_point)
|
1133 |
+
|
1134 |
+
if hard_stop_list is None:
|
1135 |
+
# acts like undo on user entry and bot response
|
1136 |
+
hard_stop_list = []
|
1137 |
+
|
1138 |
+
if isinstance(tokenizer, str):
|
1139 |
+
# pipeline
|
1140 |
+
if tokenizer == "summarization":
|
1141 |
+
key = 'summary_text'
|
1142 |
+
else:
|
1143 |
+
raise RuntimeError("No such task type %s" % tokenizer)
|
1144 |
+
# NOTE: uses max_length only
|
1145 |
+
yield model(prompt, max_length=max_new_tokens)[0][key]
|
1146 |
+
|
1147 |
+
if 'mbart-' in base_model.lower():
|
1148 |
+
assert src_lang is not None
|
1149 |
+
tokenizer.src_lang = languages_covered()[src_lang]
|
1150 |
+
|
1151 |
+
if chat:
|
1152 |
+
# override, ignore user change
|
1153 |
+
num_return_sequences = 1
|
1154 |
+
if prompt_type in ['human_bot', 'instruct_vicuna', 'instruct_with_end']:
|
1155 |
+
if prompt_type == 'human_bot':
|
1156 |
+
# encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
|
1157 |
+
# stopping only starts once output is beyond prompt
|
1158 |
+
# 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
|
1159 |
+
stop_words = [human, bot]
|
1160 |
+
encounters = [1, 2]
|
1161 |
+
elif prompt_type == 'instruct_vicuna':
|
1162 |
+
# even below is not enough, generic strings and many ways to encode
|
1163 |
+
stop_words = [
|
1164 |
+
'### Human:',
|
1165 |
+
"""
|
1166 |
+
### Human:""",
|
1167 |
+
"""
|
1168 |
+
### Human:
|
1169 |
+
""",
|
1170 |
+
'### Assistant:',
|
1171 |
+
"""
|
1172 |
+
### Assistant:""",
|
1173 |
+
"""
|
1174 |
+
### Assistant:
|
1175 |
+
""",
|
1176 |
+
]
|
1177 |
+
encounters = [1, 2]
|
1178 |
+
else:
|
1179 |
+
# some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
|
1180 |
+
stop_words = ['### End']
|
1181 |
+
encounters = [1]
|
1182 |
+
stop_words_ids = [
|
1183 |
+
tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
|
1184 |
+
# handle single token case
|
1185 |
+
stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
|
1186 |
+
stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
|
1187 |
+
# avoid padding in front of tokens
|
1188 |
+
if tokenizer.pad_token:
|
1189 |
+
stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
|
1190 |
+
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters)])
|
1191 |
+
else:
|
1192 |
+
stopping_criteria = StoppingCriteriaList()
|
1193 |
+
|
1194 |
+
# help to avoid errors like:
|
1195 |
+
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
|
1196 |
+
# RuntimeError: expected scalar type Half but found Float
|
1197 |
+
# with - 256
|
1198 |
+
max_length_tokenize = 768 - 256 if os.environ.get("HUGGINGFACE_SPACES") else 2048 - 256
|
1199 |
+
cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
|
1200 |
+
output_smallest = 30 * 4
|
1201 |
+
prompt = prompt[-cutoff_len - output_smallest:]
|
1202 |
+
inputs = tokenizer(prompt,
|
1203 |
+
return_tensors="pt",
|
1204 |
+
truncation=True,
|
1205 |
+
max_length=max_length_tokenize)
|
1206 |
+
if debug and len(inputs["input_ids"]) > 0:
|
1207 |
+
print('input_ids length', len(inputs["input_ids"][0]), flush=True)
|
1208 |
+
input_ids = inputs["input_ids"].to(device)
|
1209 |
+
generation_config = GenerationConfig(
|
1210 |
+
temperature=float(temperature),
|
1211 |
+
top_p=float(top_p),
|
1212 |
+
top_k=top_k,
|
1213 |
+
num_beams=num_beams,
|
1214 |
+
do_sample=do_sample,
|
1215 |
+
repetition_penalty=float(repetition_penalty),
|
1216 |
+
num_return_sequences=num_return_sequences,
|
1217 |
+
renormalize_logits=True,
|
1218 |
+
remove_invalid_values=True,
|
1219 |
+
**kwargs,
|
1220 |
+
)
|
1221 |
+
|
1222 |
+
gen_kwargs = dict(input_ids=input_ids,
|
1223 |
+
generation_config=generation_config,
|
1224 |
+
return_dict_in_generate=True,
|
1225 |
+
output_scores=True,
|
1226 |
+
max_new_tokens=max_new_tokens, # prompt + new
|
1227 |
+
min_new_tokens=min_new_tokens, # prompt + new
|
1228 |
+
early_stopping=early_stopping, # False, True, "never"
|
1229 |
+
max_time=max_time,
|
1230 |
+
stopping_criteria=stopping_criteria,
|
1231 |
+
)
|
1232 |
+
if 'gpt2' in base_model.lower():
|
1233 |
+
gen_kwargs.update(dict(bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.eos_token_id))
|
1234 |
+
elif 'mbart-' in base_model.lower():
|
1235 |
+
assert tgt_lang is not None
|
1236 |
+
tgt_lang = languages_covered()[tgt_lang]
|
1237 |
+
gen_kwargs.update(dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang]))
|
1238 |
+
else:
|
1239 |
+
gen_kwargs.update(dict(pad_token_id=tokenizer.eos_token_id))
|
1240 |
+
|
1241 |
+
decoder = functools.partial(tokenizer.decode,
|
1242 |
+
skip_special_tokens=True,
|
1243 |
+
clean_up_tokenization_spaces=True,
|
1244 |
+
)
|
1245 |
+
decoder_raw = functools.partial(tokenizer.decode,
|
1246 |
+
skip_special_tokens=False,
|
1247 |
+
clean_up_tokenization_spaces=True,
|
1248 |
+
)
|
1249 |
+
|
1250 |
+
with torch.no_grad():
|
1251 |
+
# decoded tokenized prompt can deviate from prompt due to special characters
|
1252 |
+
inputs_decoded = decoder(input_ids[0])
|
1253 |
+
inputs_decoded_raw = decoder_raw(input_ids[0])
|
1254 |
+
if inputs_decoded == prompt:
|
1255 |
+
# normal
|
1256 |
+
pass
|
1257 |
+
elif inputs_decoded.lstrip() == prompt.lstrip():
|
1258 |
+
# sometimes extra space in front, make prompt same for prompt removal
|
1259 |
+
prompt = inputs_decoded
|
1260 |
+
elif inputs_decoded_raw == prompt:
|
1261 |
+
# some models specify special tokens that are part of normal prompt, so can't skip them
|
1262 |
+
inputs_decoded_raw = inputs_decoded
|
1263 |
+
decoder = decoder_raw
|
1264 |
+
else:
|
1265 |
+
print("WARNING: Special characters in prompt", flush=True)
|
1266 |
+
if stream_output:
|
1267 |
+
def generate(callback=None, **kwargs):
|
1268 |
+
# re-order stopping so Stream first and get out all chunks before stop for other reasons
|
1269 |
+
stopping_criteria0 = kwargs.get('stopping_criteria', StoppingCriteriaList()).copy()
|
1270 |
+
kwargs['stopping_criteria'] = StoppingCriteriaList()
|
1271 |
+
kwargs['stopping_criteria'].append(Stream(func=callback))
|
1272 |
+
for stopping_criteria1 in stopping_criteria0:
|
1273 |
+
kwargs['stopping_criteria'].append(stopping_criteria1)
|
1274 |
+
|
1275 |
+
try:
|
1276 |
+
model.generate(**kwargs)
|
1277 |
+
except torch.cuda.OutOfMemoryError as e:
|
1278 |
+
print("GPU OOM: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)), flush=True)
|
1279 |
+
if kwargs['input_ids'] is not None:
|
1280 |
+
kwargs['input_ids'].cpu()
|
1281 |
+
kwargs['input_ids'] = None
|
1282 |
+
traceback.print_exc()
|
1283 |
+
clear_torch_cache()
|
1284 |
+
return
|
1285 |
+
|
1286 |
+
for output in CallbackToGenerator(generate, callback=None, **gen_kwargs):
|
1287 |
+
decoded_output = decoder(output)
|
1288 |
+
if output[-1] in [tokenizer.eos_token_id]:
|
1289 |
+
if debug:
|
1290 |
+
print("HIT EOS", flush=True)
|
1291 |
+
break
|
1292 |
+
if any(ele in decoded_output for ele in hard_stop_list):
|
1293 |
+
raise StopIteration
|
1294 |
+
yield prompter.get_response(decoded_output, prompt=inputs_decoded,
|
1295 |
+
sanitize_bot_response=sanitize_bot_response)
|
1296 |
+
return
|
1297 |
+
else:
|
1298 |
+
outputs = model.generate(**gen_kwargs)
|
1299 |
+
outputs = [decoder(s) for s in outputs.sequences]
|
1300 |
+
yield prompter.get_response(outputs, prompt=inputs_decoded,
|
1301 |
+
sanitize_bot_response=sanitize_bot_response)
|
1302 |
+
|
1303 |
+
|
1304 |
+
def get_generate_params(model_lower, chat,
|
1305 |
+
stream_output, show_examples,
|
1306 |
+
prompt_type, temperature, top_p, top_k, num_beams,
|
1307 |
+
max_new_tokens, min_new_tokens, early_stopping, max_time,
|
1308 |
+
repetition_penalty, num_return_sequences,
|
1309 |
+
do_sample):
|
1310 |
+
use_defaults = False
|
1311 |
+
use_default_examples = True
|
1312 |
+
examples = []
|
1313 |
+
task_info = f"{prompt_type}"
|
1314 |
+
if model_lower:
|
1315 |
+
print(f"Using Model {model_lower}", flush=True)
|
1316 |
+
else:
|
1317 |
+
print("No model defined yet", flush=True)
|
1318 |
+
|
1319 |
+
min_new_tokens = min_new_tokens if min_new_tokens is not None else 0
|
1320 |
+
early_stopping = early_stopping if early_stopping is not None else False
|
1321 |
+
max_time_defaults = 60 * 3
|
1322 |
+
max_time = max_time if max_time is not None else max_time_defaults
|
1323 |
+
|
1324 |
+
if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
|
1325 |
+
prompt_type = inv_prompt_type_to_model_lower[model_lower]
|
1326 |
+
|
1327 |
+
if show_examples is None:
|
1328 |
+
if chat:
|
1329 |
+
show_examples = False
|
1330 |
+
else:
|
1331 |
+
show_examples = True
|
1332 |
+
|
1333 |
+
summarize_example1 = """Jeff: Can I train a ? Transformers model on Amazon SageMaker?
|
1334 |
+
Philipp: Sure you can use the new Hugging Face Deep Learning Container.
|
1335 |
+
Jeff: ok.
|
1336 |
+
Jeff: and how can I get started?
|
1337 |
+
Jeff: where can I find documentation?
|
1338 |
+
Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-partnership-amazon-sagemaker-and-hugging-face"""
|
1339 |
+
|
1340 |
+
if 'bart-large-cnn-samsum' in model_lower or 'flan-t5-base-samsum' in model_lower:
|
1341 |
+
placeholder_instruction = summarize_example1
|
1342 |
+
placeholder_input = ""
|
1343 |
+
use_defaults = True
|
1344 |
+
use_default_examples = False
|
1345 |
+
examples += [
|
1346 |
+
[placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
|
1347 |
+
1.0, 1,
|
1348 |
+
False]]
|
1349 |
+
task_info = "Summarization"
|
1350 |
+
elif 't5-' in model_lower or 't5' == model_lower or 'flan-' in model_lower:
|
1351 |
+
placeholder_instruction = "The square root of x is the cube root of y. What is y to the power of 2, if x = 4?"
|
1352 |
+
placeholder_input = ""
|
1353 |
+
use_defaults = True
|
1354 |
+
use_default_examples = True
|
1355 |
+
task_info = "Multi-Task: Q/A, translation, Chain-of-Thought, Logical Reasoning, Summarization, etc. Best to use task prefix as trained on, e.g. `translate English to German: ` (space after colon)"
|
1356 |
+
elif 'mbart-' in model_lower:
|
1357 |
+
placeholder_instruction = "The girl has long hair."
|
1358 |
+
placeholder_input = ""
|
1359 |
+
use_defaults = True
|
1360 |
+
use_default_examples = False
|
1361 |
+
examples += [
|
1362 |
+
[placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
|
1363 |
+
1.0, 1,
|
1364 |
+
False]]
|
1365 |
+
elif 'gpt2' in model_lower:
|
1366 |
+
placeholder_instruction = "The sky is"
|
1367 |
+
placeholder_input = ""
|
1368 |
+
prompt_type = prompt_type or 'plain'
|
1369 |
+
use_default_examples = True # some will be odd "continuations" but can be ok
|
1370 |
+
examples += [
|
1371 |
+
[placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
|
1372 |
+
1.0, 1,
|
1373 |
+
False]]
|
1374 |
+
task_info = "Auto-complete phrase, code, etc."
|
1375 |
+
use_defaults = True
|
1376 |
+
else:
|
1377 |
+
if chat:
|
1378 |
+
placeholder_instruction = "Enter a question or imperative."
|
1379 |
+
else:
|
1380 |
+
placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter."
|
1381 |
+
placeholder_input = ""
|
1382 |
+
if model_lower:
|
1383 |
+
prompt_type = prompt_type or 'human_bot'
|
1384 |
+
else:
|
1385 |
+
prompt_type = ''
|
1386 |
+
examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else '', "",
|
1387 |
+
stream_output, prompt_type or 'plain', 0.1, 0.75, 40, 4, 256, 0, False, max_time_defaults, 1.0, 1, False]]
|
1388 |
+
task_info = "No task"
|
1389 |
+
if prompt_type == 'instruct':
|
1390 |
+
task_info = "Answer question or follow imperative as instruction with optionally input."
|
1391 |
+
elif prompt_type == 'plain':
|
1392 |
+
task_info = "Auto-complete phrase, code, etc."
|
1393 |
+
elif prompt_type == 'human_bot':
|
1394 |
+
if chat:
|
1395 |
+
task_info = "Chat (Shift-Enter to give question/imperative, input concatenated with instruction)"
|
1396 |
+
else:
|
1397 |
+
task_info = "Ask question/imperative (input concatenated with instruction)"
|
1398 |
+
|
1399 |
+
# revert to plain if still nothing
|
1400 |
+
prompt_type = prompt_type or 'plain'
|
1401 |
+
if use_defaults:
|
1402 |
+
temperature = 1.0 if temperature is None else temperature
|
1403 |
+
top_p = 1.0 if top_p is None else top_p
|
1404 |
+
top_k = 40 if top_k is None else top_k
|
1405 |
+
num_beams = num_beams or 1
|
1406 |
+
max_new_tokens = max_new_tokens or 128
|
1407 |
+
repetition_penalty = repetition_penalty or 1.07
|
1408 |
+
num_return_sequences = min(num_beams, num_return_sequences or 1)
|
1409 |
+
do_sample = False if do_sample is None else do_sample
|
1410 |
+
else:
|
1411 |
+
temperature = 0.1 if temperature is None else temperature
|
1412 |
+
top_p = 0.75 if top_p is None else top_p
|
1413 |
+
top_k = 40 if top_k is None else top_k
|
1414 |
+
if chat:
|
1415 |
+
num_beams = num_beams or 1
|
1416 |
+
else:
|
1417 |
+
num_beams = num_beams or 4
|
1418 |
+
max_new_tokens = max_new_tokens or 256
|
1419 |
+
repetition_penalty = repetition_penalty or 1.07
|
1420 |
+
num_return_sequences = min(num_beams, num_return_sequences or 1)
|
1421 |
+
do_sample = False if do_sample is None else do_sample
|
1422 |
+
params_list = ["", stream_output, prompt_type, temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens,
|
1423 |
+
early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample]
|
1424 |
+
|
1425 |
+
if use_default_examples:
|
1426 |
+
examples += [
|
1427 |
+
["Translate English to French", "Good morning"] + params_list,
|
1428 |
+
["Give detailed answer for whether Einstein or Newton is smarter.", ''] + params_list,
|
1429 |
+
["Explain in detailed list, all the best practices for coding in python.", ''] + params_list,
|
1430 |
+
[
|
1431 |
+
"Create a markdown table with 3 rows for the primary colors, and 2 columns, with color name and hex codes.",
|
1432 |
+
''] + params_list,
|
1433 |
+
['Translate to German: My name is Arthur', ''] + params_list,
|
1434 |
+
["Please answer to the following question. Who is going to be the next Ballon d'or?", ''] + params_list,
|
1435 |
+
['Can Geoffrey Hinton have a conversation with George Washington? Give the rationale before answering.',
|
1436 |
+
''] + params_list,
|
1437 |
+
['Please answer the following question. What is the boiling point of Nitrogen?', ''] + params_list,
|
1438 |
+
['Answer the following yes/no question. Can you write a whole Haiku in a single tweet?', ''] + params_list,
|
1439 |
+
["Simplify the following expression: (False or False and True). Explain your answer.", ''] + params_list,
|
1440 |
+
[
|
1441 |
+
"Premise: At my age you will probably have learnt one lesson. Hypothesis: It's not certain how many lessons you'll learn by your thirties. Does the premise entail the hypothesis?",
|
1442 |
+
''] + params_list,
|
1443 |
+
['The square root of x is the cube root of y. What is y to the power of 2, if x = 4?', ''] + params_list,
|
1444 |
+
[
|
1445 |
+
'Answer the following question by reasoning step by step. The cafeteria had 23 apples. If they used 20 for lunch, and bought 6 more, how many apple do they have?',
|
1446 |
+
''] + params_list,
|
1447 |
+
["""def area_of_rectangle(a: float, b: float):
|
1448 |
+
\"\"\"Return the area of the rectangle.\"\"\"""", ''] + params_list,
|
1449 |
+
["""# a function in native python:
|
1450 |
+
def mean(a):
|
1451 |
+
return sum(a)/len(a)
|
1452 |
+
|
1453 |
+
# the same function using numpy:
|
1454 |
+
import numpy as np
|
1455 |
+
def mean(a):""", ''] + params_list,
|
1456 |
+
["""X = np.random.randn(100, 100)
|
1457 |
+
y = np.random.randint(0, 1, 100)
|
1458 |
+
|
1459 |
+
# fit random forest classifier with 20 estimators""", ''] + params_list,
|
1460 |
+
]
|
1461 |
+
|
1462 |
+
src_lang = "English"
|
1463 |
+
tgt_lang = "Russian"
|
1464 |
+
|
1465 |
+
return placeholder_instruction, placeholder_input, \
|
1466 |
+
stream_output, show_examples, \
|
1467 |
+
prompt_type, temperature, top_p, top_k, num_beams, \
|
1468 |
+
max_new_tokens, min_new_tokens, early_stopping, max_time, \
|
1469 |
+
repetition_penalty, num_return_sequences, \
|
1470 |
+
do_sample, \
|
1471 |
+
src_lang, tgt_lang, \
|
1472 |
+
examples, \
|
1473 |
+
task_info
|
1474 |
+
|
1475 |
+
|
1476 |
+
def languages_covered():
|
1477 |
+
# https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt#languages-covered
|
1478 |
+
covered = """Arabic (ar_AR), Czech (cs_CZ), German (de_DE), English (en_XX), Spanish (es_XX), Estonian (et_EE), Finnish (fi_FI), French (fr_XX), Gujarati (gu_IN), Hindi (hi_IN), Italian (it_IT), Japanese (ja_XX), Kazakh (kk_KZ), Korean (ko_KR), Lithuanian (lt_LT), Latvian (lv_LV), Burmese (my_MM), Nepali (ne_NP), Dutch (nl_XX), Romanian (ro_RO), Russian (ru_RU), Sinhala (si_LK), Turkish (tr_TR), Vietnamese (vi_VN), Chinese (zh_CN), Afrikaans (af_ZA), Azerbaijani (az_AZ), Bengali (bn_IN), Persian (fa_IR), Hebrew (he_IL), Croatian (hr_HR), Indonesian (id_ID), Georgian (ka_GE), Khmer (km_KH), Macedonian (mk_MK), Malayalam (ml_IN), Mongolian (mn_MN), Marathi (mr_IN), Polish (pl_PL), Pashto (ps_AF), Portuguese (pt_XX), Swedish (sv_SE), Swahili (sw_KE), Tamil (ta_IN), Telugu (te_IN), Thai (th_TH), Tagalog (tl_XX), Ukrainian (uk_UA), Urdu (ur_PK), Xhosa (xh_ZA), Galician (gl_ES), Slovene (sl_SI)"""
|
1479 |
+
covered = covered.split(', ')
|
1480 |
+
covered = {x.split(' ')[0]: x.split(' ')[1].replace(')', '').replace('(', '') for x in covered}
|
1481 |
+
return covered
|
1482 |
+
|
1483 |
+
|
1484 |
+
def test_test_prompt(prompt_type='instruct', data_point=0):
|
1485 |
+
example_data_point = example_data_points[data_point]
|
1486 |
+
example_data_point.pop('output', None)
|
1487 |
+
return generate_prompt(example_data_point, prompt_type, False, False)
|
1488 |
+
|
1489 |
+
|
1490 |
+
if __name__ == "__main__":
|
1491 |
+
print("""
|
1492 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 --master_port=1234 generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights=lora-alpaca_6B
|
1493 |
+
python generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights='lora-alpaca_6B'
|
1494 |
+
python generate.py --base_model='EleutherAI/gpt-neox-20b' --lora_weights='lora-alpaca_20B'
|
1495 |
+
|
1496 |
+
# generate without lora weights, no prompt
|
1497 |
+
python generate.py --base_model='EleutherAI/gpt-neox-20b' --prompt_type='plain'
|
1498 |
+
python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq'
|
1499 |
+
|
1500 |
+
python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq' --lora_weights='lora_20B_daifaq'
|
1501 |
+
# OpenChatKit settings:
|
1502 |
+
python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0
|
1503 |
+
|
1504 |
+
python generate.py --base_model='distilgpt2' --prompt_type='plain' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0 --share=False
|
1505 |
+
python generate.py --base_model='t5-large' --prompt_type='simple_instruct'
|
1506 |
+
python generate.py --base_model='philschmid/bart-large-cnn-samsum'
|
1507 |
+
python generate.py --base_model='philschmid/flan-t5-base-samsum'
|
1508 |
+
python generate.py --base_model='facebook/mbart-large-50-many-to-many-mmt'
|
1509 |
+
|
1510 |
+
python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28'
|
1511 |
+
|
1512 |
+
""", flush=True)
|
1513 |
+
fire.Fire(main)
|
client_test.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Client test. Simplest case is chat=False and stream_output=False
|
3 |
+
|
4 |
+
Run server with same choices:
|
5 |
+
|
6 |
+
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-256-6.9b --chat=False --stream_output=False
|
7 |
+
|
8 |
+
NOTE: For private models, add --use-auth_token=True
|
9 |
+
|
10 |
+
NOTE: --infer_devices=True (default) must be used for multi-GPU in case see failures with cuda:x cuda:y mismatches.
|
11 |
+
Currently, this will force model to be on a single GPU.
|
12 |
+
|
13 |
+
Then run this client as:
|
14 |
+
|
15 |
+
python client_test.py
|
16 |
+
"""
|
17 |
+
|
18 |
+
debug = False
|
19 |
+
|
20 |
+
import time
|
21 |
+
import os
|
22 |
+
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
23 |
+
from gradio_client import Client
|
24 |
+
|
25 |
+
client = Client("http://localhost:7860")
|
26 |
+
if debug:
|
27 |
+
print(client.view_api(all_endpoints=True))
|
28 |
+
|
29 |
+
instruction = "Who are you?"
|
30 |
+
iinput = ''
|
31 |
+
context = ''
|
32 |
+
# streaming output is supported, loops over and outputs each generation in streaming mode
|
33 |
+
# but leave stream_output=False for simple input/output mode
|
34 |
+
stream_output = False
|
35 |
+
prompt_type = 'human_bot'
|
36 |
+
temperature = 0.1
|
37 |
+
top_p = 0.75
|
38 |
+
top_k = 40
|
39 |
+
num_beams = 1
|
40 |
+
max_new_tokens = 500
|
41 |
+
min_new_tokens = 0
|
42 |
+
early_stopping = False
|
43 |
+
max_time = 180
|
44 |
+
repetition_penalty = 1.0
|
45 |
+
num_return_sequences = 1
|
46 |
+
do_sample = True
|
47 |
+
|
48 |
+
# CHOOSE: must match server
|
49 |
+
# NOTE chat mode works through files on gradio
|
50 |
+
# and client currently would have to work through those files
|
51 |
+
# in tmp, so not best for client. So default to False
|
52 |
+
chat = False
|
53 |
+
|
54 |
+
|
55 |
+
def test_client_basic():
|
56 |
+
args = [instruction,
|
57 |
+
iinput,
|
58 |
+
context,
|
59 |
+
stream_output,
|
60 |
+
prompt_type,
|
61 |
+
temperature,
|
62 |
+
top_p,
|
63 |
+
top_k,
|
64 |
+
num_beams,
|
65 |
+
max_new_tokens,
|
66 |
+
min_new_tokens,
|
67 |
+
early_stopping,
|
68 |
+
max_time,
|
69 |
+
repetition_penalty,
|
70 |
+
num_return_sequences,
|
71 |
+
do_sample]
|
72 |
+
|
73 |
+
if not chat:
|
74 |
+
# requires generate.py to run with --chat=False
|
75 |
+
api_name = '/submit'
|
76 |
+
res = client.predict(
|
77 |
+
*tuple(args),
|
78 |
+
api_name=api_name,
|
79 |
+
)
|
80 |
+
print(md_to_text(res))
|
81 |
+
else:
|
82 |
+
api_name = '/instruction'
|
83 |
+
import json
|
84 |
+
foofile = '/tmp/foo.json'
|
85 |
+
with open(foofile, 'wt') as f:
|
86 |
+
json.dump([['', None]], f)
|
87 |
+
args += [foofile]
|
88 |
+
if not stream_output:
|
89 |
+
for res in client.predict(
|
90 |
+
*tuple(args),
|
91 |
+
api_name=api_name,
|
92 |
+
):
|
93 |
+
print(res)
|
94 |
+
res_file = client.predict(*tuple(args), api_name='/instruction_bot')
|
95 |
+
res = json.load(open(res_file, "rt"))[-1][-1]
|
96 |
+
print(md_to_text(res))
|
97 |
+
else:
|
98 |
+
print("streaming instruction_bot", flush=True)
|
99 |
+
job = client.submit(*tuple(args), api_name='/instruction_bot')
|
100 |
+
while not job.done():
|
101 |
+
outputs_list = job.communicator.job.outputs
|
102 |
+
if outputs_list:
|
103 |
+
res_file = job.communicator.job.outputs[-1]
|
104 |
+
res = json.load(open(res_file, "rt"))[-1][-1]
|
105 |
+
print(md_to_text(res))
|
106 |
+
time.sleep(0.1)
|
107 |
+
print(job.outputs())
|
108 |
+
|
109 |
+
|
110 |
+
import markdown # pip install markdown
|
111 |
+
from bs4 import BeautifulSoup # pip install beautifulsoup4
|
112 |
+
|
113 |
+
|
114 |
+
def md_to_text(md):
|
115 |
+
html = markdown.markdown(md)
|
116 |
+
soup = BeautifulSoup(html, features='html.parser')
|
117 |
+
return soup.get_text()
|
118 |
+
|
119 |
+
|
120 |
+
if __name__ == '__main__':
|
121 |
+
test_client_basic()
|
finetune.py
ADDED
@@ -0,0 +1,930 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pathlib
|
3 |
+
import random
|
4 |
+
import shutil
|
5 |
+
import subprocess
|
6 |
+
import sys
|
7 |
+
import time
|
8 |
+
from datetime import datetime
|
9 |
+
from typing import List, Union
|
10 |
+
import fire
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
from datasets import load_dataset, concatenate_datasets
|
14 |
+
import transformers
|
15 |
+
import torch.distributed as dist
|
16 |
+
|
17 |
+
from peft import (
|
18 |
+
prepare_model_for_int8_training,
|
19 |
+
LoraConfig,
|
20 |
+
get_peft_model,
|
21 |
+
get_peft_model_state_dict,
|
22 |
+
set_peft_model_state_dict,
|
23 |
+
)
|
24 |
+
|
25 |
+
from peft import mapping
|
26 |
+
lora_mappings = mapping.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
|
27 |
+
|
28 |
+
|
29 |
+
def log(*args, **kwargs):
|
30 |
+
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
|
31 |
+
print(*args, **kwargs)
|
32 |
+
|
33 |
+
|
34 |
+
try:
|
35 |
+
import neptune
|
36 |
+
from transformers.integrations import NeptuneCallback
|
37 |
+
|
38 |
+
neptune_run = neptune.init_run(
|
39 |
+
source_files=[],
|
40 |
+
)
|
41 |
+
log("Connected to Neptune.")
|
42 |
+
except ImportError:
|
43 |
+
neptune_run = None
|
44 |
+
log("Please pip install neptune for tracking.")
|
45 |
+
except neptune.exceptions.NeptuneMissingApiTokenException:
|
46 |
+
neptune_run = None
|
47 |
+
os.environ["NEPTUNE_MODE"] = 'debug'
|
48 |
+
log("No neptune configured, set NEPTUNE_API_TOKEN env var.")
|
49 |
+
|
50 |
+
from enum import Enum
|
51 |
+
|
52 |
+
|
53 |
+
class PromptType(Enum):
|
54 |
+
plain = 0
|
55 |
+
instruct = 1
|
56 |
+
quality = 2
|
57 |
+
human_bot = 3
|
58 |
+
dai_faq = 4
|
59 |
+
summarize = 5
|
60 |
+
simple_instruct = 6
|
61 |
+
instruct_vicuna = 7
|
62 |
+
instruct_with_end = 8
|
63 |
+
human_bot_orig = 9
|
64 |
+
|
65 |
+
|
66 |
+
prompt_type_to_model_name = {
|
67 |
+
'plain': [
|
68 |
+
'EleutherAI/gpt-j-6B',
|
69 |
+
'EleutherAI/pythia-6.9b',
|
70 |
+
'EleutherAI/pythia-12b',
|
71 |
+
'EleutherAI/pythia-12b-deduped',
|
72 |
+
'EleutherAI/gpt-neox-20b',
|
73 |
+
'decapoda-research/llama-7b-hf',
|
74 |
+
'decapoda-research/llama-13b-hf',
|
75 |
+
'decapoda-research/llama-30b-hf',
|
76 |
+
'facebook/mbart-large-50-many-to-many-mmt',
|
77 |
+
'philschmid/bart-large-cnn-samsum',
|
78 |
+
'philschmid/flan-t5-base-samsum',
|
79 |
+
'gpt2',
|
80 |
+
'distilgpt2',
|
81 |
+
],
|
82 |
+
'instruct': [],
|
83 |
+
'instruct_with_end': ['databricks/dolly-v2-12b'],
|
84 |
+
'quality': [],
|
85 |
+
'human_bot': [
|
86 |
+
'h2oai/h2ogpt-oig-oasst1-256-12b',
|
87 |
+
'h2oai/h2ogpt-oasst1-512-12b',
|
88 |
+
'h2oai/h2ogpt-oasst1-256-20b',
|
89 |
+
'h2oai/h2ogpt-oasst1-512-20b',
|
90 |
+
'h2oai/h2ogpt-oig-oasst1-256-6.9b',
|
91 |
+
],
|
92 |
+
'dai_faq': [],
|
93 |
+
'summarize': [],
|
94 |
+
'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
|
95 |
+
'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b'],
|
96 |
+
'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
|
97 |
+
}
|
98 |
+
|
99 |
+
inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
100 |
+
inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
101 |
+
|
102 |
+
human = '<human>:'
|
103 |
+
bot = "<bot>:"
|
104 |
+
|
105 |
+
prompt_types_strings = []
|
106 |
+
for p in PromptType:
|
107 |
+
prompt_types_strings.extend([p.name])
|
108 |
+
|
109 |
+
|
110 |
+
prompt_types = []
|
111 |
+
for p in PromptType:
|
112 |
+
prompt_types.extend([p.name, p.value, str(p.value)])
|
113 |
+
|
114 |
+
|
115 |
+
# supported by huggingface evaluate
|
116 |
+
supported_metrics = ['bleu', 'rouge', 'sacrebleu', 'meteor']
|
117 |
+
|
118 |
+
|
119 |
+
def train(
|
120 |
+
save_code: bool = False,
|
121 |
+
run_id: int = None,
|
122 |
+
|
123 |
+
base_model: str = 'EleutherAI/gpt-neox-20b',
|
124 |
+
# base_model: str = 'EleutherAI/pythia-12b-deduped',
|
125 |
+
# base_model: str = 'togethercomputer/GPT-NeoXT-Chat-Base-20B',
|
126 |
+
# base_model: str = 'decapoda-research/llama-7b-hf',
|
127 |
+
# base_model: str = 'decapoda-research/llama-13b-hf',
|
128 |
+
# base_model: str = 'decapoda-research/llama-30b-hf',
|
129 |
+
# base_model: str = 'EleutherAI/gpt-j-6B',
|
130 |
+
|
131 |
+
# only needed if base_model is self-exported HF state without tokenizer
|
132 |
+
tokenizer_base_model: str = None,
|
133 |
+
# tokenizer_base_model: str = 'EleutherAI/gpt-neox-20b',
|
134 |
+
|
135 |
+
data_path: str = None,
|
136 |
+
data_col_dict: dict = None,
|
137 |
+
# data_path: str = "./dai_docs.train.json",
|
138 |
+
prompt_type: Union[str, int] = "plain", # "plain", "instruct", "quality", "human_bot", "dai_faq"
|
139 |
+
|
140 |
+
valid_path: str = None,
|
141 |
+
# valid_path: str = "./dai_docs.valid.json",
|
142 |
+
|
143 |
+
# data_mix_in_path: str = "laion/OIG", # way too big, medium quality
|
144 |
+
data_mix_in_path: str = "0-hero/OIG-small-chip2", # high quality, 50 MB, good enough for now
|
145 |
+
data_mix_in_factor: float = 0.0, # >1: more mix-in data, <1: more of data_path data
|
146 |
+
data_mix_in_col_dict: dict = {'user': 'instruction', 'chip2': 'output'},
|
147 |
+
data_mix_in_prompt_type: str = "instruct", # just instruction->output, same as instruct
|
148 |
+
|
149 |
+
output_dir: str = None,
|
150 |
+
|
151 |
+
# LoRA checkpoint continuation
|
152 |
+
lora_weights: str = "",
|
153 |
+
|
154 |
+
# batching training hyperparams
|
155 |
+
batch_size: int = 128,
|
156 |
+
micro_batch_size: int = 4,
|
157 |
+
gradient_checkpointing=False, # unnecessary with gradient accumulation enabled
|
158 |
+
fp16=True,
|
159 |
+
|
160 |
+
# general training hyperparams
|
161 |
+
num_epochs: float = 1,
|
162 |
+
learning_rate: float = 3e-4,
|
163 |
+
|
164 |
+
# validation settings
|
165 |
+
val_set_size: int = None,
|
166 |
+
val_metrics: List[str] = [],
|
167 |
+
eval_steps: int = None, # to control eval steps via steps
|
168 |
+
eval_epochs: float = None, # to control eval steps via epochs
|
169 |
+
|
170 |
+
# lora hyperparams
|
171 |
+
lora_r: int = 8,
|
172 |
+
lora_alpha: int = 16,
|
173 |
+
lora_dropout: float = 0.05,
|
174 |
+
lora_target_modules: List[str] = None,
|
175 |
+
llama_type: bool = None,
|
176 |
+
|
177 |
+
# llm hyperparams
|
178 |
+
train_on_inputs: bool = True, # if False, masks out inputs in loss
|
179 |
+
group_by_length: bool = False, # if True, faster, but produces an odd training loss curve
|
180 |
+
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
|
181 |
+
cutoff_len: int = 1024, # Good default, especially when have high quality non-trivial data
|
182 |
+
|
183 |
+
# torch training params
|
184 |
+
ddp: bool = True, # set to False if OOM with True, for multi-GPU model parallelism
|
185 |
+
local_files_only: bool = False, # else will download new versions, normally unwanted
|
186 |
+
resume_download: bool = True,
|
187 |
+
use_auth_token: Union[str, bool] = False, # True requires CLI did huggingface-cli login before running
|
188 |
+
warmup_steps: int = 100,
|
189 |
+
logging_steps: int = 1,
|
190 |
+
save_steps: int = None, # must be round multiple of eval_steps
|
191 |
+
add_eos_token: bool = False,
|
192 |
+
):
|
193 |
+
# allow set token directly
|
194 |
+
use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
|
195 |
+
|
196 |
+
prompt_type = str(prompt_type) # migration from integers
|
197 |
+
assert prompt_type in prompt_types
|
198 |
+
|
199 |
+
world_size = int(os.getenv("WORLD_SIZE", 1))
|
200 |
+
local_rank = int(os.getenv("LOCAL_RANK", 0))
|
201 |
+
rank = int(os.getenv("RANK", 0))
|
202 |
+
print(f"local_rank: {local_rank}")
|
203 |
+
print(f"global rank: {rank}")
|
204 |
+
|
205 |
+
gpus = max(world_size, torch.cuda.device_count())
|
206 |
+
run_id = run_id or 0
|
207 |
+
if not data_path:
|
208 |
+
raise ValueError("No data_path provided")
|
209 |
+
if not output_dir:
|
210 |
+
output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}"
|
211 |
+
if os.path.exists(output_dir) and not resume_from_checkpoint:
|
212 |
+
raise FileExistsError(f"output_dir based on run_id {run_id} already exists. Please pick a different run_id.")
|
213 |
+
else:
|
214 |
+
if os.path.exists(output_dir) and not resume_from_checkpoint:
|
215 |
+
raise FileExistsError(f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.")
|
216 |
+
device_map = "auto"
|
217 |
+
|
218 |
+
if save_code:
|
219 |
+
copy_code(run_id)
|
220 |
+
if tokenizer_base_model is None:
|
221 |
+
tokenizer_base_model = base_model
|
222 |
+
if llama_type is None:
|
223 |
+
llama_type = "llama" in base_model.lower()
|
224 |
+
assert (
|
225 |
+
base_model
|
226 |
+
), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
|
227 |
+
gradient_accumulation_steps = batch_size // micro_batch_size
|
228 |
+
assert gradient_accumulation_steps >= world_size, "must increase batch_size for multi-GPU"
|
229 |
+
|
230 |
+
device_map = "auto"
|
231 |
+
|
232 |
+
locals_dict = locals()
|
233 |
+
locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
|
234 |
+
log(f"Training model with params:\n{locals_print}")
|
235 |
+
log("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()))
|
236 |
+
|
237 |
+
max_memory = None
|
238 |
+
if gpus > 1:
|
239 |
+
if ddp:
|
240 |
+
log("Distributed: data parallel")
|
241 |
+
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
|
242 |
+
gradient_accumulation_steps = gradient_accumulation_steps // world_size
|
243 |
+
else:
|
244 |
+
free_in_GB = int(min(torch.cuda.mem_get_info()) / 1024 ** 3)
|
245 |
+
max_memory = f"{free_in_GB - 2}GB"
|
246 |
+
max_memory = {i: max_memory for i in range(gpus)}
|
247 |
+
log("world_size: %d" % world_size)
|
248 |
+
log("num_gpus: %d" % gpus)
|
249 |
+
log("max mem: %s" % max_memory)
|
250 |
+
|
251 |
+
model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)
|
252 |
+
|
253 |
+
model = model_loader.from_pretrained(
|
254 |
+
base_model,
|
255 |
+
load_in_8bit=True,
|
256 |
+
device_map=device_map,
|
257 |
+
torch_dtype=torch.float16,
|
258 |
+
max_memory=max_memory,
|
259 |
+
local_files_only=local_files_only,
|
260 |
+
resume_download=resume_download,
|
261 |
+
use_auth_token=use_auth_token,
|
262 |
+
)
|
263 |
+
if gpus > 1:
|
264 |
+
if not ddp:
|
265 |
+
log("model parallel")
|
266 |
+
model.is_parallelizable = True
|
267 |
+
model.model_parallel = True
|
268 |
+
|
269 |
+
tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
|
270 |
+
local_files_only=local_files_only,
|
271 |
+
resume_download=resume_download,
|
272 |
+
use_auth_token=use_auth_token)
|
273 |
+
|
274 |
+
tokenizer.pad_token_id = 0 # different from the eos token
|
275 |
+
# when generating, we will use the logits of right-most token to predict the next token
|
276 |
+
# so the padding should be on the left,
|
277 |
+
# e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
|
278 |
+
tokenizer.padding_side = "left" # Allow batched inference
|
279 |
+
|
280 |
+
def tokenize(prompt, add_eos_token=True):
|
281 |
+
# there's probably a way to do this with the tokenizer settings
|
282 |
+
# but again, gotta move fast
|
283 |
+
result = tokenizer(
|
284 |
+
prompt,
|
285 |
+
truncation=True,
|
286 |
+
max_length=cutoff_len,
|
287 |
+
padding=False,
|
288 |
+
return_tensors=None,
|
289 |
+
)
|
290 |
+
if (
|
291 |
+
result["input_ids"][-1] != tokenizer.eos_token_id
|
292 |
+
and len(result["input_ids"]) < cutoff_len
|
293 |
+
and add_eos_token
|
294 |
+
):
|
295 |
+
result["input_ids"].append(tokenizer.eos_token_id)
|
296 |
+
result["attention_mask"].append(1)
|
297 |
+
|
298 |
+
result["labels"] = result["input_ids"].copy()
|
299 |
+
|
300 |
+
return result
|
301 |
+
|
302 |
+
def generate_and_tokenize_prompt(data_point, add_eos=add_eos_token):
|
303 |
+
full_prompt, _, _ = generate_prompt(data_point, prompt_type, False, False)
|
304 |
+
tokenized_full_prompt = tokenize(full_prompt)
|
305 |
+
if not train_on_inputs:
|
306 |
+
user_prompt, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
|
307 |
+
tokenized_user_prompt = tokenize(user_prompt, add_eos_token=add_eos)
|
308 |
+
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
309 |
+
if add_eos:
|
310 |
+
user_prompt_len -= 1
|
311 |
+
|
312 |
+
# ignore_index=-100 ensures torch/tf don't include padding token id in CrossEntropyLoss
|
313 |
+
tokenized_full_prompt["labels"] = [
|
314 |
+
-100
|
315 |
+
] * user_prompt_len + tokenized_full_prompt["labels"][
|
316 |
+
user_prompt_len:
|
317 |
+
] # could be sped up, probably
|
318 |
+
return tokenized_full_prompt
|
319 |
+
|
320 |
+
if "gpt-neox" not in base_model or True:
|
321 |
+
model = prepare_model_for_int8_training(model)
|
322 |
+
else:
|
323 |
+
model = prepare_model_for_int8_training(
|
324 |
+
model,
|
325 |
+
output_embedding_layer_name="embed_out", # keep output logits in float32
|
326 |
+
layer_norm_names=["layer_norm", "layernorm"], # keep all layer norms in higher precision
|
327 |
+
)
|
328 |
+
if lora_weights:
|
329 |
+
from peft import PeftModel
|
330 |
+
model = PeftModel.from_pretrained(
|
331 |
+
model,
|
332 |
+
lora_weights,
|
333 |
+
torch_dtype=torch.float16,
|
334 |
+
device_map=device_map,
|
335 |
+
local_files_only=local_files_only,
|
336 |
+
resume_download=resume_download,
|
337 |
+
use_auth_token=use_auth_token,
|
338 |
+
)
|
339 |
+
else:
|
340 |
+
if lora_target_modules is None:
|
341 |
+
base_model_lower = base_model.lower()
|
342 |
+
if base_model_lower in lora_mappings:
|
343 |
+
lora_target_modules_cand = [lora_mappings[base_model_lower]]
|
344 |
+
else:
|
345 |
+
lora_target_modules_cand = [["query_key_value"], ["q_proj", "v_proj"]]
|
346 |
+
else:
|
347 |
+
lora_target_modules_cand = [lora_target_modules]
|
348 |
+
|
349 |
+
for lora_target_modules in lora_target_modules_cand:
|
350 |
+
try:
|
351 |
+
config = LoraConfig(
|
352 |
+
r=lora_r,
|
353 |
+
lora_alpha=lora_alpha,
|
354 |
+
target_modules=lora_target_modules,
|
355 |
+
lora_dropout=lora_dropout,
|
356 |
+
bias="none",
|
357 |
+
task_type="CAUSAL_LM",
|
358 |
+
)
|
359 |
+
model = get_peft_model(model, config)
|
360 |
+
break
|
361 |
+
except ValueError as e:
|
362 |
+
if "Target modules" in str(e) and "not found" in str(e):
|
363 |
+
continue
|
364 |
+
else:
|
365 |
+
raise
|
366 |
+
from peft import PeftModel
|
367 |
+
assert isinstance(model, PeftModel), "LoRA failed. Please provide --lora_target_modules explicitly."
|
368 |
+
if resume_from_checkpoint:
|
369 |
+
# Check the available weights and load them
|
370 |
+
checkpoint_name = os.path.join(
|
371 |
+
resume_from_checkpoint, "pytorch_model.bin"
|
372 |
+
) # Full checkpoint
|
373 |
+
if not os.path.exists(checkpoint_name):
|
374 |
+
checkpoint_name = os.path.join(
|
375 |
+
resume_from_checkpoint, "adapter_model.bin"
|
376 |
+
) # only LoRA model - LoRA config above has to fit
|
377 |
+
resume_from_checkpoint = False # So the trainer won't try loading its state
|
378 |
+
# The two files above have a different name depending on how they were saved, but are actually the same.
|
379 |
+
if os.path.exists(checkpoint_name):
|
380 |
+
log(f"Restarting from {checkpoint_name}")
|
381 |
+
adapters_weights = torch.load(checkpoint_name)
|
382 |
+
model = set_peft_model_state_dict(model, adapters_weights)
|
383 |
+
else:
|
384 |
+
log(f"Checkpoint {checkpoint_name} not found")
|
385 |
+
|
386 |
+
print(model)
|
387 |
+
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
|
388 |
+
|
389 |
+
metrics = {}
|
390 |
+
for name in supported_metrics:
|
391 |
+
if name in val_metrics:
|
392 |
+
import evaluate # Causes hang for 'python generate.py' on dual 4090 if imported early, 100% reproducible
|
393 |
+
metrics[name] = evaluate.load(name)
|
394 |
+
log("Using Validation Metrics: %s" % str(list(metrics.keys())))
|
395 |
+
log("Supported Metrics: %s" % supported_metrics)
|
396 |
+
|
397 |
+
if val_set_size is None:
|
398 |
+
if len(metrics) == 0:
|
399 |
+
val_set_size = 1000
|
400 |
+
else:
|
401 |
+
val_set_size = 100
|
402 |
+
log("Auto set val_set_size %s" % val_set_size)
|
403 |
+
elif val_set_size < 1.0 and val_set_size != 0:
|
404 |
+
raise RuntimeError("Fractional validation size not supported.")
|
405 |
+
|
406 |
+
if valid_path:
|
407 |
+
data = load_dataset("json", data_files={"train": data_path, "valid": valid_path})
|
408 |
+
else:
|
409 |
+
if "json" in data_path:
|
410 |
+
data = load_dataset("json", data_files={"train": data_path})
|
411 |
+
else:
|
412 |
+
data = load_dataset(data_path)
|
413 |
+
data = data.rename_columns(data_col_dict or {})
|
414 |
+
|
415 |
+
valid_data = None
|
416 |
+
train_data_mix_in = None
|
417 |
+
valid_data_mix_in = None
|
418 |
+
|
419 |
+
if data_mix_in_path and data_mix_in_factor > 0:
|
420 |
+
# get mix-in training/validation data - to keep model "sane"
|
421 |
+
num_rows = data["train"].num_rows
|
422 |
+
log("Loading mix-in dataset: %s" % data_mix_in_path)
|
423 |
+
if "json" in data_mix_in_path:
|
424 |
+
data_mix_in = load_dataset("json", data_files={"train": data_mix_in_path})["train"]
|
425 |
+
else:
|
426 |
+
data_mix_in = load_dataset(data_mix_in_path)["train"] # can be large
|
427 |
+
data_mix_in = data_mix_in.rename_columns(data_mix_in_col_dict or {})
|
428 |
+
|
429 |
+
# only get as much as we need to balance
|
430 |
+
valid_size = min(data_mix_in.num_rows // 2, val_set_size or 0)
|
431 |
+
train_size = max(1, min(data_mix_in.num_rows - valid_size, int(num_rows * data_mix_in_factor)))
|
432 |
+
mixin_small = data_mix_in.train_test_split(
|
433 |
+
test_size=train_size + valid_size,
|
434 |
+
shuffle=True, seed=np.random.randint(10000),
|
435 |
+
)["test"]
|
436 |
+
if valid_size:
|
437 |
+
mixin_train_test = mixin_small.train_test_split(
|
438 |
+
test_size=valid_size, shuffle=False,
|
439 |
+
)
|
440 |
+
train_data_mix_in = mixin_train_test["train"]
|
441 |
+
valid_data_mix_in = mixin_train_test["test"]
|
442 |
+
else:
|
443 |
+
train_data_mix_in = mixin_small
|
444 |
+
|
445 |
+
if "prompt_type" not in train_data_mix_in.column_names:
|
446 |
+
train_data_mix_in = train_data_mix_in.add_column(
|
447 |
+
"prompt_type",
|
448 |
+
[data_mix_in_prompt_type] * train_data_mix_in.num_rows,
|
449 |
+
)
|
450 |
+
log("Added prompt type %s to mix-in training data" % data_mix_in_prompt_type)
|
451 |
+
if valid_data_mix_in and "prompt_type" not in valid_data_mix_in.column_names:
|
452 |
+
valid_data_mix_in = valid_data_mix_in.add_column(
|
453 |
+
"prompt_type",
|
454 |
+
[data_mix_in_prompt_type] * valid_data_mix_in.num_rows,
|
455 |
+
)
|
456 |
+
log("Added prompt type %s to mix-in validation data" % data_mix_in_prompt_type)
|
457 |
+
log("Created mix-in data:\nTrain %s\nValid %s" % (train_data_mix_in, valid_data_mix_in))
|
458 |
+
|
459 |
+
# get our own training/validation data - for fine-tuning
|
460 |
+
if val_set_size > 0 and not valid_path and not data_mix_in_path:
|
461 |
+
# create valid split from train
|
462 |
+
train_val = data["train"].train_test_split(
|
463 |
+
test_size=val_set_size, shuffle=True, seed=42
|
464 |
+
)
|
465 |
+
train_data = train_val["train"]
|
466 |
+
valid_data = train_val["test"]
|
467 |
+
else:
|
468 |
+
train_data = data["train"]
|
469 |
+
if valid_path:
|
470 |
+
# use given valid split, has priority over data_mix_in_path
|
471 |
+
valid_data = data["valid"]
|
472 |
+
if "prompt_type" not in train_data.column_names:
|
473 |
+
train_data = train_data.add_column(
|
474 |
+
"prompt_type",
|
475 |
+
[prompt_type] * train_data.num_rows,
|
476 |
+
)
|
477 |
+
log("Added prompt type %s to training data" % prompt_type)
|
478 |
+
if valid_data and "prompt_type" not in valid_data.column_names:
|
479 |
+
valid_data = valid_data.add_column(
|
480 |
+
"prompt_type",
|
481 |
+
[prompt_type] * valid_data.num_rows,
|
482 |
+
)
|
483 |
+
log("Added prompt type %s to validation data" % prompt_type)
|
484 |
+
|
485 |
+
assert train_data is not None
|
486 |
+
|
487 |
+
# shuffle and tokenize data
|
488 |
+
if train_data_mix_in:
|
489 |
+
train_data = concatenate_datasets([train_data, train_data_mix_in])
|
490 |
+
train_data = train_data.shuffle().map(generate_and_tokenize_prompt, num_proc=os.cpu_count() // torch.cuda.device_count())
|
491 |
+
train_set_size = len(train_data)
|
492 |
+
|
493 |
+
if valid_data and valid_data_mix_in:
|
494 |
+
valid_data = concatenate_datasets([valid_data, valid_data_mix_in])
|
495 |
+
elif valid_data_mix_in:
|
496 |
+
valid_data = valid_data_mix_in
|
497 |
+
|
498 |
+
if valid_data:
|
499 |
+
valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt, num_proc=os.cpu_count() // torch.cuda.device_count())
|
500 |
+
val_set_size = len(valid_data)
|
501 |
+
else:
|
502 |
+
val_set_size = 0
|
503 |
+
log("Final fine-tuning data:\nTrain %s\nValid %s" % (train_data, valid_data))
|
504 |
+
sample_row_dict = train_data[:1]
|
505 |
+
del sample_row_dict['input_ids']
|
506 |
+
del sample_row_dict['attention_mask']
|
507 |
+
del sample_row_dict['labels']
|
508 |
+
log("Sample input: %s" % sample_row_dict)
|
509 |
+
|
510 |
+
if neptune_run:
|
511 |
+
neptune_callback = NeptuneCallback(run=neptune_run)
|
512 |
+
callbacks = [neptune_callback]
|
513 |
+
else:
|
514 |
+
from transformers.integrations import TensorBoardCallback, is_tensorboard_available
|
515 |
+
if is_tensorboard_available:
|
516 |
+
# tensorboard --logdir=runs/
|
517 |
+
from torch.utils.tensorboard import SummaryWriter
|
518 |
+
tb_writer = SummaryWriter()
|
519 |
+
callbacks = [TensorBoardCallback(tb_writer=tb_writer)]
|
520 |
+
else:
|
521 |
+
callbacks = []
|
522 |
+
|
523 |
+
expected_steps = (train_set_size * num_epochs) // batch_size
|
524 |
+
if eval_steps is None and eval_epochs is None:
|
525 |
+
# 20 evaluations for a run
|
526 |
+
eval_steps = max(1, int(expected_steps / 20))
|
527 |
+
log("Auto set eval_steps to %s out of %s total training steps" % (eval_steps, expected_steps))
|
528 |
+
elif eval_steps is None and eval_epochs is not None:
|
529 |
+
eval_steps = max(1, int(expected_steps * eval_epochs / num_epochs))
|
530 |
+
log("Auto converted eval_epochs=%s to eval_steps %s"
|
531 |
+
" out of %s total training steps" % (eval_epochs, eval_steps, expected_steps))
|
532 |
+
if save_steps is None:
|
533 |
+
save_steps = eval_steps
|
534 |
+
log("Auto step save_steps to %s" % save_steps)
|
535 |
+
elif save_steps > eval_steps:
|
536 |
+
# save steps must be round multiple of eval_steps
|
537 |
+
save_steps0 = save_steps
|
538 |
+
save_steps = max(1, (save_steps//eval_steps)) * eval_steps
|
539 |
+
if save_steps0 != save_steps:
|
540 |
+
log("Auto converted save_steps from %s to %s" % (save_steps0, save_steps))
|
541 |
+
|
542 |
+
def compute_metrics(eval_preds):
|
543 |
+
# e.g. see: https://huggingface.co/docs/transformers/v4.25.1/en/tasks/translation#evaluate
|
544 |
+
inputs = eval_preds.inputs
|
545 |
+
label_ids = eval_preds.label_ids
|
546 |
+
predictions = eval_preds.predictions
|
547 |
+
|
548 |
+
#inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id)
|
549 |
+
#decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True)
|
550 |
+
#decoded_inputs = [pred.strip() for pred in decoded_inputs]
|
551 |
+
|
552 |
+
label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
|
553 |
+
# tokenizer behavior like generate time
|
554 |
+
decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True,
|
555 |
+
clean_up_tokenization_spaces=True)
|
556 |
+
decoded_labels = [pred.strip() for pred in decoded_labels]
|
557 |
+
|
558 |
+
predictions = np.argmax(predictions, -1)
|
559 |
+
predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
|
560 |
+
# tokenizer behavior like generate time
|
561 |
+
decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True,
|
562 |
+
clean_up_tokenization_spaces=True)
|
563 |
+
decoded_predictions = [pred.strip() for pred in decoded_predictions]
|
564 |
+
|
565 |
+
result = {}
|
566 |
+
for metric in metrics.values():
|
567 |
+
result1 = metric.compute(predictions=decoded_predictions, references=decoded_labels)
|
568 |
+
# get rid of lists, for precision etc., for now
|
569 |
+
numeric_results = {k: v for k, v in result1.items() if isinstance(v, (int, float))}
|
570 |
+
result.update(numeric_results)
|
571 |
+
return result
|
572 |
+
|
573 |
+
# the callback that computes metrics of interest
|
574 |
+
if val_metrics:
|
575 |
+
trainer_kwargs = dict(compute_metrics=compute_metrics)
|
576 |
+
else:
|
577 |
+
trainer_kwargs = dict()
|
578 |
+
|
579 |
+
trainer = transformers.Trainer(
|
580 |
+
model=model,
|
581 |
+
tokenizer=tokenizer,
|
582 |
+
train_dataset=train_data,
|
583 |
+
eval_dataset=valid_data,
|
584 |
+
# NOTE: CausalLM is not supporting Seq2SeqTrainingArguments arguments, but not incompatible
|
585 |
+
args=transformers.Seq2SeqTrainingArguments(
|
586 |
+
per_device_train_batch_size=micro_batch_size,
|
587 |
+
per_device_eval_batch_size=1,
|
588 |
+
eval_accumulation_steps=10,
|
589 |
+
# predict_with_generate=True, # SEQ2SEQ only
|
590 |
+
include_inputs_for_metrics=True,
|
591 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
592 |
+
warmup_steps=warmup_steps,
|
593 |
+
num_train_epochs=num_epochs,
|
594 |
+
learning_rate=learning_rate,
|
595 |
+
gradient_checkpointing=gradient_checkpointing,
|
596 |
+
fp16=fp16,
|
597 |
+
# cosnider 8-bit adam: https://huggingface.co/docs/transformers/v4.18.0/en/performance#8bit-adam
|
598 |
+
optim="adamw_torch", # consider "adafactor" to save memory
|
599 |
+
logging_steps=logging_steps,
|
600 |
+
logging_strategy="steps",
|
601 |
+
evaluation_strategy="steps" if val_set_size > 0 else "no",
|
602 |
+
save_strategy="steps",
|
603 |
+
eval_steps=eval_steps if val_set_size > 0 else None,
|
604 |
+
save_steps=save_steps,
|
605 |
+
output_dir=output_dir,
|
606 |
+
save_total_limit=3,
|
607 |
+
load_best_model_at_end=True if val_set_size > 0 else False,
|
608 |
+
ddp_find_unused_parameters=False if ddp else None,
|
609 |
+
group_by_length=group_by_length,
|
610 |
+
#fsdp="shard_grad_op auto_wrap" if gpus > 1 and not ddp else None,
|
611 |
+
#fsdp_min_num_params=20000 if gpus > 1 and not ddp else None,
|
612 |
+
report_to='tensorboard' if not neptune_run else 'neptune',
|
613 |
+
),
|
614 |
+
data_collator=transformers.DataCollatorForSeq2Seq(
|
615 |
+
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
|
616 |
+
),
|
617 |
+
callbacks=callbacks,
|
618 |
+
**trainer_kwargs,
|
619 |
+
)
|
620 |
+
model.config.use_cache = False
|
621 |
+
|
622 |
+
old_state_dict = model.state_dict
|
623 |
+
model.state_dict = (
|
624 |
+
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
625 |
+
).__get__(model, type(model))
|
626 |
+
|
627 |
+
if torch.__version__ >= "2" and sys.platform != "win32":
|
628 |
+
model = torch.compile(model)
|
629 |
+
# WIP (not generally replacing layers until pytorch 2.1)
|
630 |
+
torch.backends.cuda.enable_flash_sdp(True)
|
631 |
+
|
632 |
+
if gpus > 1 and not ddp:
|
633 |
+
assert trainer.is_model_parallel
|
634 |
+
else:
|
635 |
+
assert not trainer.is_model_parallel
|
636 |
+
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
637 |
+
|
638 |
+
model.save_pretrained(output_dir)
|
639 |
+
|
640 |
+
log("\n If there's a warning about missing keys above, please disregard :)")
|
641 |
+
|
642 |
+
|
643 |
+
def get_loaders(llama_type, model_name, reward_type):
|
644 |
+
# NOTE: Some models need specific new prompt_type
|
645 |
+
# E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
|
646 |
+
if llama_type:
|
647 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer
|
648 |
+
model_loader = LlamaForCausalLM
|
649 |
+
tokenizer_loader = LlamaTokenizer
|
650 |
+
elif 'gpt2' in model_name.lower():
|
651 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
652 |
+
return GPT2LMHeadModel, GPT2Tokenizer
|
653 |
+
elif 'mbart-' in model_name.lower():
|
654 |
+
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
|
655 |
+
return MBartForConditionalGeneration, MBart50TokenizerFast
|
656 |
+
elif 't5' == model_name.lower() or \
|
657 |
+
't5-' in model_name.lower() or \
|
658 |
+
'flan-' in model_name.lower():
|
659 |
+
from transformers import AutoTokenizer, T5ForConditionalGeneration
|
660 |
+
return T5ForConditionalGeneration, AutoTokenizer
|
661 |
+
elif 'bigbird' in model_name:
|
662 |
+
from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
|
663 |
+
return BigBirdPegasusForConditionalGeneration, AutoTokenizer
|
664 |
+
elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
|
665 |
+
from transformers import pipeline
|
666 |
+
return pipeline, "summarization"
|
667 |
+
elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
|
668 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
669 |
+
return AutoModelForSequenceClassification, AutoTokenizer
|
670 |
+
else:
|
671 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
672 |
+
model_loader = AutoModelForCausalLM
|
673 |
+
tokenizer_loader = AutoTokenizer
|
674 |
+
return model_loader, tokenizer_loader
|
675 |
+
|
676 |
+
|
677 |
+
def get_githash():
|
678 |
+
try:
|
679 |
+
githash = subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8')[0:-1]
|
680 |
+
except:
|
681 |
+
githash = ''
|
682 |
+
return githash
|
683 |
+
|
684 |
+
|
685 |
+
def copy_code(run_id):
|
686 |
+
"""
|
687 |
+
copy code to track changes
|
688 |
+
:param run_id:
|
689 |
+
:return:
|
690 |
+
"""
|
691 |
+
rnd_num = str(random.randint(0, 2 ** 31))
|
692 |
+
run_id = 'run_' + str(run_id)
|
693 |
+
os.makedirs(run_id, exist_ok=True)
|
694 |
+
me_full = os.path.join(pathlib.Path(__file__).parent.resolve(), __file__)
|
695 |
+
me_file = os.path.basename(__file__)
|
696 |
+
new_me = os.path.join(run_id, me_file + '_' + get_githash())
|
697 |
+
if os.path.isfile(new_me):
|
698 |
+
new_me = os.path.join(run_id, me_file + '_' + get_githash() + '_' + rnd_num)
|
699 |
+
shutil.copy(me_full, new_me)
|
700 |
+
else:
|
701 |
+
shutil.copy(me_full, new_me)
|
702 |
+
|
703 |
+
|
704 |
+
def get_prompt(prompt_type, chat, context, reduced):
|
705 |
+
if prompt_type in [-1, "-1", "plain"]:
|
706 |
+
promptA = promptB = PreInstruct = PreInput = PreResponse = ''
|
707 |
+
terminate_response = []
|
708 |
+
elif prompt_type == 'simple_instruct':
|
709 |
+
promptA = promptB = PreInstruct = PreInput = PreResponse = None
|
710 |
+
terminate_response = []
|
711 |
+
elif prompt_type in [0, "0", "instruct"] or prompt_type in [7, "7", "instruct_with_end"]:
|
712 |
+
promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
|
713 |
+
promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
|
714 |
+
|
715 |
+
PreInstruct = """
|
716 |
+
### Instruction:
|
717 |
+
"""
|
718 |
+
|
719 |
+
PreInput = """
|
720 |
+
### Input:
|
721 |
+
"""
|
722 |
+
|
723 |
+
PreResponse = """
|
724 |
+
### Response:
|
725 |
+
"""
|
726 |
+
if prompt_type in [7, "7", "instruct_with_end"]:
|
727 |
+
terminate_response = ['### End']
|
728 |
+
else:
|
729 |
+
terminate_response = None
|
730 |
+
elif prompt_type in [1, "1", "quality"]:
|
731 |
+
promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (chat and reduced) else ''
|
732 |
+
promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (chat and reduced) else ''
|
733 |
+
|
734 |
+
PreInstruct = """
|
735 |
+
### Instruction:
|
736 |
+
"""
|
737 |
+
|
738 |
+
PreInput = """
|
739 |
+
### Input:
|
740 |
+
"""
|
741 |
+
|
742 |
+
PreResponse = """
|
743 |
+
### Response:
|
744 |
+
"""
|
745 |
+
terminate_response = None
|
746 |
+
elif prompt_type in [2, "2", "human_bot", 9, "9", "human_bot_orig"]:
|
747 |
+
if reduced or context or prompt_type in [2, "2", "human_bot"]:
|
748 |
+
preprompt = ''
|
749 |
+
else:
|
750 |
+
cur_date = time.strftime('%Y-%m-%d')
|
751 |
+
cur_time = time.strftime('%H:%M:%S %p %Z')
|
752 |
+
|
753 |
+
PRE_PROMPT = """\
|
754 |
+
Current Date: {}
|
755 |
+
Current Time: {}
|
756 |
+
|
757 |
+
"""
|
758 |
+
preprompt = PRE_PROMPT.format(cur_date, cur_time)
|
759 |
+
start = human
|
760 |
+
promptB = promptA = '%s%s ' % (preprompt, start)
|
761 |
+
|
762 |
+
PreInstruct = ""
|
763 |
+
|
764 |
+
PreInput = None
|
765 |
+
|
766 |
+
PreResponse = bot
|
767 |
+
|
768 |
+
terminate_response = [start, PreResponse]
|
769 |
+
elif prompt_type in [3, "3", "dai_faq"]:
|
770 |
+
promptA = ''
|
771 |
+
promptB = 'Answer the following Driverless AI question.\n'
|
772 |
+
|
773 |
+
PreInstruct = """
|
774 |
+
### Driverless AI frequently asked question:
|
775 |
+
"""
|
776 |
+
|
777 |
+
PreInput = None
|
778 |
+
|
779 |
+
PreResponse = """
|
780 |
+
### Driverless AI documentation answer:
|
781 |
+
"""
|
782 |
+
terminate_response = ['\n\n']
|
783 |
+
elif prompt_type in [5, "5", "summarize"]:
|
784 |
+
promptA = promptB = PreInput = ''
|
785 |
+
PreInstruct = '## Main Text\n\n'
|
786 |
+
PreResponse = '\n\n## Summary\n\n'
|
787 |
+
terminate_response = None
|
788 |
+
elif prompt_type in [6, "6", "instruct_vicuna"]:
|
789 |
+
promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
|
790 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions." if not (chat and reduced) else ''
|
791 |
+
|
792 |
+
PreInstruct = """
|
793 |
+
### Human:
|
794 |
+
"""
|
795 |
+
|
796 |
+
PreInput = None
|
797 |
+
|
798 |
+
PreResponse = """
|
799 |
+
### Assistant:
|
800 |
+
"""
|
801 |
+
terminate_response = ['### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
802 |
+
else:
|
803 |
+
raise RuntimeError("No such prompt_type=%s" % prompt_type)
|
804 |
+
|
805 |
+
return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response
|
806 |
+
|
807 |
+
|
808 |
+
def generate_prompt(data_point, prompt_type, chat, reduced):
|
809 |
+
context = data_point.get('context') if chat else ''
|
810 |
+
if context is None:
|
811 |
+
context = ''
|
812 |
+
instruction = data_point.get('instruction')
|
813 |
+
input = data_point.get('input')
|
814 |
+
output = data_point.get('output')
|
815 |
+
prompt_type = data_point.get('prompt_type', prompt_type)
|
816 |
+
assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
|
817 |
+
promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response = get_prompt(prompt_type, chat, context, reduced)
|
818 |
+
|
819 |
+
prompt = context
|
820 |
+
|
821 |
+
if input and promptA:
|
822 |
+
prompt += f"""{promptA}"""
|
823 |
+
elif promptB:
|
824 |
+
prompt += f"""{promptB}"""
|
825 |
+
|
826 |
+
if instruction and PreInstruct is not None and input and PreInput is not None:
|
827 |
+
prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
|
828 |
+
prompt = inject_newline(prompt_type, prompt)
|
829 |
+
elif instruction and input and PreInstruct is None and PreInput is not None:
|
830 |
+
prompt += f"""{PreInput}{instruction}
|
831 |
+
{input}"""
|
832 |
+
prompt = inject_newline(prompt_type, prompt)
|
833 |
+
elif input and instruction and PreInput is None and PreInstruct is not None:
|
834 |
+
prompt += f"""{PreInstruct}{instruction}
|
835 |
+
{input}"""
|
836 |
+
prompt = inject_newline(prompt_type, prompt)
|
837 |
+
elif instruction and PreInstruct is not None:
|
838 |
+
prompt += f"""{PreInstruct}{instruction}"""
|
839 |
+
prompt = inject_newline(prompt_type, prompt)
|
840 |
+
elif input and PreInput is not None:
|
841 |
+
prompt += f"""{PreInput}{input}"""
|
842 |
+
prompt = inject_newline(prompt_type, prompt)
|
843 |
+
elif input and instruction and PreInput is not None:
|
844 |
+
prompt += f"""{PreInput}{instruction}{input}"""
|
845 |
+
prompt = inject_newline(prompt_type, prompt)
|
846 |
+
elif input and instruction and PreInstruct is not None:
|
847 |
+
prompt += f"""{PreInstruct}{instruction}{input}"""
|
848 |
+
prompt = inject_newline(prompt_type, prompt)
|
849 |
+
elif input and instruction:
|
850 |
+
# i.e. for simple_instruct
|
851 |
+
prompt += f"""{instruction}: {input}"""
|
852 |
+
prompt = inject_newline(prompt_type, prompt)
|
853 |
+
elif input:
|
854 |
+
prompt += f"""{input}"""
|
855 |
+
prompt = inject_newline(prompt_type, prompt)
|
856 |
+
elif instruction:
|
857 |
+
prompt += f"""{instruction}"""
|
858 |
+
prompt = inject_newline(prompt_type, prompt)
|
859 |
+
|
860 |
+
if PreResponse is not None:
|
861 |
+
prompt += f"""{PreResponse}"""
|
862 |
+
pre_response = PreResponse # Don't use strip
|
863 |
+
else:
|
864 |
+
pre_response = ''
|
865 |
+
|
866 |
+
if output:
|
867 |
+
prompt += f"""{output}"""
|
868 |
+
|
869 |
+
return prompt, pre_response, terminate_response
|
870 |
+
|
871 |
+
|
872 |
+
def inject_newline(prompt_type, prompt):
|
873 |
+
if prompt_type not in [-1, '-1', 'plain', 'simple_instruct']:
|
874 |
+
# only add new line if structured prompt, while 'plain' is just generation of next tokens from input
|
875 |
+
prompt += '\n'
|
876 |
+
return prompt
|
877 |
+
|
878 |
+
|
879 |
+
example_data_point0 = dict(instruction="Summarize",
|
880 |
+
input="Ducks eat seeds by the lake, then swim in the lake where fish eat small animals.",
|
881 |
+
output="Ducks eat and swim at the lake.")
|
882 |
+
|
883 |
+
example_data_point1 = dict(instruction="Who is smarter, Einstein or Newton?",
|
884 |
+
output="Einstein.")
|
885 |
+
|
886 |
+
example_data_point2 = dict(input="Who is smarter, Einstein or Newton?",
|
887 |
+
output="Einstein.")
|
888 |
+
|
889 |
+
example_data_points = [example_data_point0, example_data_point1, example_data_point2]
|
890 |
+
|
891 |
+
|
892 |
+
def test_train_prompt(prompt_type='instruct', data_point=0):
|
893 |
+
example_data_point = example_data_points[data_point]
|
894 |
+
return generate_prompt(example_data_point, prompt_type, False, False)
|
895 |
+
|
896 |
+
|
897 |
+
def test_debug():
|
898 |
+
fire.Fire(train)
|
899 |
+
|
900 |
+
|
901 |
+
if __name__ == "__main__":
|
902 |
+
CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1"
|
903 |
+
CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf"
|
904 |
+
log(f"""
|
905 |
+
Example runs on 4 GPUs:
|
906 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-7b-hf' --data_path=data/config.json --run_id=0 &> 0.log
|
907 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-30b-hf' --data_path=data/config.json --batch_size=16 --micro_batch_size=1 --run_id=1 --save_code=True &> 1.log
|
908 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-j-6B' --data_path=data/config.json --run_id=2 &> 2.log
|
909 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-neox-20b' --data_path=data/config.json --run_id=8 --batch_size=16 --micro_batch_size=4 &> 8.log
|
910 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --prompt_type='dai_faq' --run_id=13 --batch_size=16 --micro_batch_size=4 --num_epochs=100 --val_set_size=0 data_mix_in_path='' &> 13.log
|
911 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --run_id=28 --batch_size=16 --micro_batch_size=4 --num_epochs=8 --val_set_size=0 --data_mix_in_factor=0.1 --data_mix_in_prompt_type='human_bot' --save_code=True --cutoff_len=512 &> 28.log
|
912 |
+
|
913 |
+
All metrics:
|
914 |
+
CUDA_VISIBLE_DEVICES= finetune.py --data_mix_in_factor=0 --eval_steps=100 --warmup_steps=2 --val_set_size=100 --val_metrics="['bleu', 'rouge', 'sacrebleu', 'meteor']"
|
915 |
+
|
916 |
+
# Fine-tune 20B on 24GB GPUs across 3 nodes with 3+2+2 GPUs
|
917 |
+
rippa>
|
918 |
+
NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1,2" torchrun --node_rank 0 --nproc_per_node=3 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank0
|
919 |
+
ova>
|
920 |
+
NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 1 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank1
|
921 |
+
timemachine>
|
922 |
+
NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 2 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank2
|
923 |
+
|
924 |
+
""", flush=True)
|
925 |
+
|
926 |
+
if os.environ.get("LOCAL_RANK") is None:
|
927 |
+
# then not using torchrun, so can't do distributed, ensure CVD set
|
928 |
+
assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU"
|
929 |
+
|
930 |
+
fire.Fire(train)
|
h2o-logo.svg
ADDED
prompter.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from finetune import generate_prompt
|
2 |
+
|
3 |
+
|
4 |
+
class Prompter(object):
|
5 |
+
def __init__(self, prompt_type, debug=False, chat=False, stream_output=False, repeat_penalty=True,
|
6 |
+
allowed_repeat_line_length=10):
|
7 |
+
self.prompt_type = prompt_type
|
8 |
+
data_point = dict(instruction='', input='', output='')
|
9 |
+
_, self.pre_response, self.terminate_response = generate_prompt(data_point, prompt_type, chat, False)
|
10 |
+
self.debug = debug
|
11 |
+
self.chat = chat
|
12 |
+
self.stream_output = stream_output
|
13 |
+
self.repeat_penalty = repeat_penalty
|
14 |
+
self.allowed_repeat_line_length = allowed_repeat_line_length
|
15 |
+
|
16 |
+
def generate_prompt(self, data_point):
|
17 |
+
reduced = False
|
18 |
+
prompt, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
|
19 |
+
if self.debug:
|
20 |
+
print("prompt: ", prompt, flush=True)
|
21 |
+
self.prompt = prompt
|
22 |
+
return prompt
|
23 |
+
|
24 |
+
def get_response(self, outputs, prompt=None, sanitize_bot_response=True):
|
25 |
+
if isinstance(outputs, str):
|
26 |
+
outputs = [outputs]
|
27 |
+
if self.debug:
|
28 |
+
print("output: ", '\n\n'.join(outputs), flush=True)
|
29 |
+
if prompt is not None:
|
30 |
+
self.prompt = prompt
|
31 |
+
|
32 |
+
def clean_response(response):
|
33 |
+
meaningless_words = ['<pad>', '</s>', '<|endoftext|>', '”\n']
|
34 |
+
for word in meaningless_words:
|
35 |
+
response = response.replace(word, "")
|
36 |
+
if sanitize_bot_response:
|
37 |
+
from better_profanity import profanity
|
38 |
+
response = profanity.censor(response)
|
39 |
+
response = response.strip("\n")
|
40 |
+
return response
|
41 |
+
|
42 |
+
def clean_repeats(response):
|
43 |
+
lines = response.split('\n')
|
44 |
+
new_lines = []
|
45 |
+
[new_lines.append(line) for line in lines if
|
46 |
+
line not in new_lines or len(line) < self.allowed_repeat_line_length]
|
47 |
+
if self.debug and len(lines) != len(new_lines):
|
48 |
+
print("cleaned repeats: %s %s" % (len(lines), len(new_lines)), flush=True)
|
49 |
+
response = '\n'.join(new_lines)
|
50 |
+
return response
|
51 |
+
|
52 |
+
multi_output = len(outputs) > 1
|
53 |
+
|
54 |
+
for oi, output in enumerate(outputs):
|
55 |
+
if self.prompt_type in [0, '0', 'plain']:
|
56 |
+
output = clean_response(output)
|
57 |
+
else:
|
58 |
+
# find first instance of prereponse
|
59 |
+
# prompt sometimes has odd characters, that mutate length,
|
60 |
+
# so can't go by length alone
|
61 |
+
if self.pre_response:
|
62 |
+
outputi = output.find(prompt)
|
63 |
+
if outputi >= 0:
|
64 |
+
output = output[outputi + len(prompt):]
|
65 |
+
allow_terminate = True
|
66 |
+
else:
|
67 |
+
# subtraction is risky due to space offsets sometimes, so only do if necessary
|
68 |
+
output = output[len(prompt) - len(self.pre_response):]
|
69 |
+
# [1] to avoid repeated pre_response, just take first (after prompt - pre_response for chat)
|
70 |
+
if self.pre_response in output:
|
71 |
+
output = output.split(self.pre_response)[1]
|
72 |
+
allow_terminate = True
|
73 |
+
else:
|
74 |
+
print("Failure of parsing: %s" % output, flush=True)
|
75 |
+
allow_terminate = False
|
76 |
+
else:
|
77 |
+
allow_terminate = True
|
78 |
+
output = output[len(prompt):]
|
79 |
+
# clean after subtract prompt out, so correct removal of pre_response
|
80 |
+
output = clean_response(output).strip()
|
81 |
+
if self.repeat_penalty:
|
82 |
+
output = clean_repeats(output).strip()
|
83 |
+
if self.terminate_response and allow_terminate:
|
84 |
+
finds = []
|
85 |
+
for term in self.terminate_response:
|
86 |
+
finds.append(output.find(term))
|
87 |
+
finds = [x for x in finds if x >= 0]
|
88 |
+
if len(finds) > 0:
|
89 |
+
termi = finds[0]
|
90 |
+
output = output[:termi].strip()
|
91 |
+
else:
|
92 |
+
output = output.strip()
|
93 |
+
else:
|
94 |
+
output = output.strip()
|
95 |
+
if multi_output:
|
96 |
+
# prefix with output counter
|
97 |
+
output = "\n=========== Output %d\n\n" % (1 + oi) + output
|
98 |
+
if oi > 0:
|
99 |
+
# post fix outputs with seperator
|
100 |
+
output += '\n'
|
101 |
+
outputs[oi] = output
|
102 |
+
# join all outputs, only one extra new line between outputs
|
103 |
+
output = '\n'.join(outputs)
|
104 |
+
if self.debug:
|
105 |
+
print("outputclean: ", '\n\n'.join(outputs), flush=True)
|
106 |
+
return output
|
requirements.txt
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# for generate (gradio server) and finetune
|
2 |
+
datasets==2.10.1
|
3 |
+
sentencepiece==0.1.97
|
4 |
+
accelerate==0.18.0
|
5 |
+
gradio==3.27.0
|
6 |
+
huggingface_hub==0.13.4
|
7 |
+
appdirs==1.4.4
|
8 |
+
fire==0.5.0
|
9 |
+
docutils==0.19
|
10 |
+
torch==2.0.0
|
11 |
+
evaluate==0.4.0
|
12 |
+
rouge_score==0.1.2
|
13 |
+
sacrebleu==2.3.1
|
14 |
+
scikit-learn==1.2.2
|
15 |
+
alt-profanity-check==1.2.2
|
16 |
+
better-profanity==0.6.1
|
17 |
+
numpy==1.24.2
|
18 |
+
pandas==1.5.3
|
19 |
+
matplotlib==3.7.1
|
20 |
+
loralib==0.1.1
|
21 |
+
bitsandbytes==0.38.1
|
22 |
+
git+https://github.com/huggingface/peft.git@098962fa6515f2e4fe83a757f5995d3ffbb1c373
|
23 |
+
transformers==4.28.1
|
24 |
+
tokenizers==0.13.3
|
25 |
+
|
26 |
+
# optional for finetune
|
27 |
+
tensorboard==2.12.1
|
28 |
+
neptune==1.1.1
|
29 |
+
|
30 |
+
# for gradio client
|
31 |
+
gradio_client==0.1.3
|
32 |
+
beautifulsoup4==4.12.2
|
33 |
+
markdown==3.4.1
|
34 |
+
|
35 |
+
# data and testing
|
36 |
+
pytest==7.2.2
|
37 |
+
pytest-xdist==3.2.1
|
38 |
+
nltk==3.8.1
|
39 |
+
textstat==0.7.3
|
40 |
+
pandoc==2.3
|
41 |
+
pypandoc==1.11
|
42 |
+
openpyxl==3.1.2
|
43 |
+
lm_dataformat==0.0.20
|
44 |
+
bioc==2.0
|
stopping.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import traceback
|
2 |
+
from queue import Queue
|
3 |
+
from threading import Thread
|
4 |
+
import collections.abc
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from transformers import StoppingCriteria
|
8 |
+
|
9 |
+
|
10 |
+
class StoppingCriteriaSub(StoppingCriteria):
|
11 |
+
|
12 |
+
def __init__(self, stops=[], encounters=[]):
|
13 |
+
super().__init__()
|
14 |
+
assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
|
15 |
+
self.encounters = encounters
|
16 |
+
self.stops = [stop.to("cuda") for stop in stops]
|
17 |
+
self.num_stops = [0] * len(stops)
|
18 |
+
|
19 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
20 |
+
for stopi, stop in enumerate(self.stops):
|
21 |
+
if torch.all((stop == input_ids[0][-len(stop):])).item():
|
22 |
+
self.num_stops[stopi] += 1
|
23 |
+
if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
|
24 |
+
return True
|
25 |
+
# print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
|
26 |
+
# print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
|
27 |
+
return False
|
28 |
+
|
29 |
+
|
30 |
+
class Stream(StoppingCriteria):
|
31 |
+
"""
|
32 |
+
This class can be used to callback during generation. Keep
|
33 |
+
in mind for decoder-only type of transformers, this will include the initial prompted tokens.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
func (`callable`):
|
37 |
+
A callable function to apply on first input in list every iteration of generation
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self, func=None):
|
41 |
+
self.func = func
|
42 |
+
|
43 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
44 |
+
if self.func is not None:
|
45 |
+
# only consume first of multiple responses
|
46 |
+
self.func(input_ids[0])
|
47 |
+
return False
|
48 |
+
|
49 |
+
|
50 |
+
class CallbackToGenerator(collections.abc.Generator):
|
51 |
+
"""
|
52 |
+
A generator wrapper for a function that invokes a callback multiple times.
|
53 |
+
|
54 |
+
Calling `send` on the generator emits a value from one callback, and returns
|
55 |
+
the next.
|
56 |
+
|
57 |
+
Note this starts a background thread
|
58 |
+
"""
|
59 |
+
|
60 |
+
def __init__(self, func, *args, callback=None, **kwargs):
|
61 |
+
self.func = func
|
62 |
+
self.args = args
|
63 |
+
self.kwargs = kwargs
|
64 |
+
self.callback = callback
|
65 |
+
|
66 |
+
self._ready_queue = Queue(1)
|
67 |
+
self._done_queue = Queue(1)
|
68 |
+
self._done_holder = [False]
|
69 |
+
|
70 |
+
# local to avoid reference cycles
|
71 |
+
ready_queue = self._ready_queue
|
72 |
+
done_queue = self._done_queue
|
73 |
+
done_holder = self._done_holder
|
74 |
+
|
75 |
+
def val_callback(value):
|
76 |
+
done_queue.put((False, value))
|
77 |
+
cmd, val = ready_queue.get()
|
78 |
+
if cmd == 'send':
|
79 |
+
return val
|
80 |
+
elif cmd == 'throw':
|
81 |
+
raise val
|
82 |
+
else:
|
83 |
+
assert False # pragma: no cover
|
84 |
+
|
85 |
+
def thread_func():
|
86 |
+
while True:
|
87 |
+
cmd, val = ready_queue.get()
|
88 |
+
if cmd == 'send' and val is not None:
|
89 |
+
done_queue.put((True, TypeError("can't send non-None value to a just-started generator")))
|
90 |
+
continue
|
91 |
+
break
|
92 |
+
try:
|
93 |
+
if cmd == 'throw':
|
94 |
+
raise val
|
95 |
+
ret = func(callback=val_callback, **self.kwargs)
|
96 |
+
raise StopIteration(ret) if ret is not None else StopIteration
|
97 |
+
except BaseException as e:
|
98 |
+
done_holder[0] = True
|
99 |
+
done_queue.put((True, e))
|
100 |
+
|
101 |
+
self._thread = Thread(target=thread_func)
|
102 |
+
self._thread.start()
|
103 |
+
|
104 |
+
def _put(self, *args):
|
105 |
+
if self._done_holder[0]:
|
106 |
+
raise StopIteration
|
107 |
+
self._ready_queue.put(args)
|
108 |
+
is_exception, val = self._done_queue.get()
|
109 |
+
if is_exception:
|
110 |
+
try:
|
111 |
+
raise val
|
112 |
+
finally:
|
113 |
+
# prevent val's traceback containing a reference cycle
|
114 |
+
del val
|
115 |
+
else:
|
116 |
+
return val
|
117 |
+
|
118 |
+
def send(self, value):
|
119 |
+
return self._put('send', value)
|
120 |
+
|
121 |
+
def throw(self, exc):
|
122 |
+
return self._put('throw', exc)
|
123 |
+
|
124 |
+
def close(self):
|
125 |
+
try:
|
126 |
+
self.throw(GeneratorExit)
|
127 |
+
except StopIteration:
|
128 |
+
self._thread.join()
|
129 |
+
except GeneratorExit:
|
130 |
+
self._thread.join()
|
131 |
+
except BaseException:
|
132 |
+
self._thread.join()
|
133 |
+
raise
|
134 |
+
else:
|
135 |
+
# yielded again, can't clean up the thread
|
136 |
+
raise RuntimeError('Task with callback ignored GeneratorExit')
|
137 |
+
|
138 |
+
def __del__(self):
|
139 |
+
self.close()
|
utils.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gc
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
def set_seed(seed: int):
|
9 |
+
"""
|
10 |
+
Sets the seed of the entire notebook so results are the same every time we run.
|
11 |
+
This is for REPRODUCIBILITY.
|
12 |
+
"""
|
13 |
+
np.random.seed(seed)
|
14 |
+
random_state = np.random.RandomState(seed)
|
15 |
+
random.seed(seed)
|
16 |
+
torch.manual_seed(seed)
|
17 |
+
torch.cuda.manual_seed(seed)
|
18 |
+
torch.backends.cudnn.deterministic = True
|
19 |
+
torch.backends.cudnn.benchmark = False
|
20 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
21 |
+
return random_state
|
22 |
+
|
23 |
+
|
24 |
+
def flatten_list(lis):
|
25 |
+
"""Given a list, possibly nested to any level, return it flattened."""
|
26 |
+
new_lis = []
|
27 |
+
for item in lis:
|
28 |
+
if type(item) == type([]):
|
29 |
+
new_lis.extend(flatten_list(item))
|
30 |
+
else:
|
31 |
+
new_lis.append(item)
|
32 |
+
return new_lis
|
33 |
+
|
34 |
+
|
35 |
+
def clear_torch_cache():
|
36 |
+
if torch.cuda.is_available:
|
37 |
+
torch.cuda.empty_cache()
|
38 |
+
torch.cuda.ipc_collect()
|
39 |
+
gc.collect()
|