LinkangZhan commited on
Commit
05b7bb9
1 Parent(s): b64802f

update slider

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ test.py
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/.name ADDED
@@ -0,0 +1 @@
 
 
1
+ app.py
.idea/Genshin-World-Model.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
5
+ <option name="ignoredPackages">
6
+ <value>
7
+ <list size="11">
8
+ <item index="0" class="java.lang.String" itemvalue="tiktoken" />
9
+ <item index="1" class="java.lang.String" itemvalue="scipy" />
10
+ <item index="2" class="java.lang.String" itemvalue="matplotlib" />
11
+ <item index="3" class="java.lang.String" itemvalue="whisper" />
12
+ <item index="4" class="java.lang.String" itemvalue="torch" />
13
+ <item index="5" class="java.lang.String" itemvalue="numpy" />
14
+ <item index="6" class="java.lang.String" itemvalue="requests" />
15
+ <item index="7" class="java.lang.String" itemvalue="torchvision" />
16
+ <item index="8" class="java.lang.String" itemvalue="torchaudio" />
17
+ <item index="9" class="java.lang.String" itemvalue="Pillow" />
18
+ <item index="10" class="java.lang.String" itemvalue="Requests" />
19
+ </list>
20
+ </value>
21
+ </option>
22
+ </inspection_tool>
23
+ </profile>
24
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectRootManager" version="2" project-jdk-name="pytorch" project-jdk-type="Python SDK" />
4
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/Genshin-World-Model.iml" filepath="$PROJECT_DIR$/.idea/Genshin-World-Model.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
app.py CHANGED
@@ -47,18 +47,18 @@ def respond(role_name, character_name, msg, textbox, temp, rep, max_len, top_p,
47
  yield ["", textbox]
48
  if character_name != '':
49
  textbox += ('\n' if textbox[-1] != '\n' else '') + character_name + ':'
50
- input_ids = tokenizer.encode(textbox)[-4096:]
51
  input_ids = torch.LongTensor([input_ids]).to(device)
52
  generation_config = model.generation_config
53
  stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
54
  gen_kwargs = {}
55
  gen_kwargs.update(dict(
56
  input_ids=input_ids,
57
- temperature=1.5,
58
- top_p=0.7,
59
- top_k=50,
60
- repetition_penalty=1.0,
61
- max_new_tokens=256,
62
  do_sample=True,
63
  ))
64
  outputs = []
 
47
  yield ["", textbox]
48
  if character_name != '':
49
  textbox += ('\n' if textbox[-1] != '\n' else '') + character_name + ':'
50
+ input_ids = tokenizer.encode(textbox)[-3200:]
51
  input_ids = torch.LongTensor([input_ids]).to(device)
52
  generation_config = model.generation_config
53
  stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
54
  gen_kwargs = {}
55
  gen_kwargs.update(dict(
56
  input_ids=input_ids,
57
+ temperature=temp,
58
+ top_p=top_p,
59
+ top_k=top_k,
60
+ repetition_penalty=rep,
61
+ max_new_tokens=max_len,
62
  do_sample=True,
63
  ))
64
  outputs = []