m-vai commited on
Commit
3984963
1 Parent(s): 89eb8b1

modify tester to handle context-based instruction mode.

Browse files
Files changed (1) hide show
  1. tester.py +49 -4
tester.py CHANGED
@@ -3,6 +3,9 @@ import pathlib
3
  import torch
4
  from transformers import pipeline
5
 
 
 
 
6
  def getText():
7
  s = '''
8
  A US climber has died on his way to scale Mount Everest on Monday, according to an expedition organizer.
@@ -22,13 +25,13 @@ Following Sugarman’s death, the Embassy of the United States issued a statemen
22
 
23
  return s
24
 
25
- def main():
26
- print("In main...")
27
  # print(pathlib.Path().resolve())
28
  # generate_text_pipline = pipeline(model="databricks/dolly-v2-3b", torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
29
  # generate_text_pipline = pipeline(model="verseAI/databricks-dolly-v2-3b", torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
30
  workingDir = pathlib.Path().resolve()
31
- generate_text_pipline = pipeline(model=workingDir, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
32
  inputText = getText()
33
  resp = generate_text_pipline(inputText)
34
  respStr = resp[0]["generated_text"]
@@ -38,5 +41,47 @@ def main():
38
 
39
  print("All Done!")
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  if __name__ == "__main__":
42
- main()
 
3
  import torch
4
  from transformers import pipeline
5
 
6
+ from langchain import PromptTemplate, LLMChain
7
+ from langchain.llms import HuggingFacePipeline
8
+
9
  def getText():
10
  s = '''
11
  A US climber has died on his way to scale Mount Everest on Monday, according to an expedition organizer.
 
25
 
26
  return s
27
 
28
+ def mainSimple():
29
+ print("In main simple...")
30
  # print(pathlib.Path().resolve())
31
  # generate_text_pipline = pipeline(model="databricks/dolly-v2-3b", torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
32
  # generate_text_pipline = pipeline(model="verseAI/databricks-dolly-v2-3b", torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
33
  workingDir = pathlib.Path().resolve()
34
+ generate_text_pipline = pipeline(model=workingDir, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto", return_full_text=True)
35
  inputText = getText()
36
  resp = generate_text_pipline(inputText)
37
  respStr = resp[0]["generated_text"]
 
41
 
42
  print("All Done!")
43
 
44
+ def getInstrAndContext():
45
+ context = '''George Washington (February 22, 1732[b] - December 14, 1799) was an American military officer, statesman,
46
+ and Founding Father who served as the first president of the United States from 1789 to 1797.'''
47
+ # instr = '''When was George Washington president?'''
48
+ instr = '''What do you think?'''
49
+
50
+ return instr, context
51
+
52
+ def mainContext():
53
+ print("In main context...")
54
+ workingDir = pathlib.Path().resolve()
55
+ generate_text_pipline = pipeline(model=workingDir, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto", return_full_text=True)
56
+
57
+ # template for an instrution with no input
58
+ prompt = PromptTemplate(
59
+ input_variables=["instruction"],
60
+ template="{instruction}")
61
+
62
+ # template for an instruction with input
63
+ prompt_with_context = PromptTemplate(
64
+ input_variables=["instruction", "context"],
65
+ template="{instruction}\n\nInput:\n{context}")
66
+
67
+ hf_pipeline = HuggingFacePipeline(pipeline=generate_text_pipline)
68
+
69
+ llm_chain = LLMChain(llm=hf_pipeline, prompt=prompt)
70
+ llm_context_chain = LLMChain(llm=hf_pipeline, prompt=prompt_with_context)
71
+
72
+ instr, context = getInstrAndContext()
73
+ resp = ''
74
+
75
+ if(context and not context.isspace()):
76
+ resp = llm_context_chain.predict(instruction=instr, context=context).lstrip()
77
+ else:
78
+ resp = llm_chain.predict(instruction=instr).lstrip()
79
+
80
+ print(f'Input-Instr: {instr}')
81
+ print(f'Input-Context: {context}')
82
+ print(f'Response: {resp}')
83
+
84
+ print("All Done!")
85
+
86
  if __name__ == "__main__":
87
+ mainContext()