Paulus Michael Leang committed on
Commit
2bd5d74
1 Parent(s): e3e5833

Change it to Ollama

Files changed (2)
  1. app.py +5 -32
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,53 +1,26 @@
 from fastapi import FastAPI
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from pydantic import BaseModel
-import torch
+from langchain_ollama import OllamaLLM
 import uvicorn
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-model_name = "Qwen/Qwen2.5-Coder-14B-Instruct"
-
-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    torch_dtype="auto",
-    device_map="auto" if device == "cuda" else None
-)
-tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 class ChatRequest(BaseModel):
     prompt: str
 
 app = FastAPI()
 
+llmModel = OllamaLLM(model='paullmich28/manlingua-ai-sim')
+
 @app.get("/")
 def greet_json():
     return {"Hello": "World!"}
 
 @app.post("/generate_chat")
 def generateAi(request: ChatRequest):
-    messages = [
-        {"role": "system", "content": "You are a Mandarin language learning assistant that only answers in Mandarin."},
-        {"role": "user", "content": request.prompt}
-    ]
-    text = tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
-    model_inputs = tokenizer([text], return_tensors="pt").to(device)
-
-    generated_ids = model.generate(
-        model_inputs.input_ids,
-        max_new_tokens=512
-    )
-    generated_ids = [
-        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-    ]
-
-    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    result = llmModel.invoke(input="Hello, how are you?")
 
-    return {"answer": response}
+    return {"answer": result}
 
 if __name__ == "__main__":
     uvicorn.run(app, host='0.0.0.0', port='8000')
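
Note that, as committed, the handler invokes the model with the hardcoded string "Hello, how are you?" and never reads `request.prompt`, and the Mandarin system prompt from the old transformers path is dropped. A minimal sketch of wiring the request body through, assuming the same `langchain_ollama` setup and an Ollama server reachable at its default address (http://localhost:11434) with the `paullmich28/manlingua-ai-sim` model already pulled:

from fastapi import FastAPI
from pydantic import BaseModel
from langchain_ollama import OllamaLLM
import uvicorn

class ChatRequest(BaseModel):
    prompt: str

app = FastAPI()

# Assumes a local Ollama server with the model already pulled.
llmModel = OllamaLLM(model='paullmich28/manlingua-ai-sim')

@app.post("/generate_chat")
def generateAi(request: ChatRequest):
    # Forward the caller's prompt instead of a fixed test string.
    result = llmModel.invoke(request.prompt)
    return {"answer": result}

if __name__ == "__main__":
    # uvicorn expects an integer port; the committed code passes the string '8000'.
    uvicorn.run(app, host='0.0.0.0', port=8000)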
requirements.txt CHANGED
@@ -4,4 +4,5 @@ transformers
 pydantic
 accelerate
 torch
-torchvision
+torchvision
+langchain_ollama
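
For a quick smoke test once the app is running, a hypothetical client call against the new endpoint (the `requests` package, the local URL, and the example prompt are assumptions, not part of this commit):

import requests

# POST a prompt to the /generate_chat endpoint defined in app.py.
resp = requests.post(
    "http://localhost:8000/generate_chat",
    json={"prompt": "你好"},
)
print(resp.json()["answer"])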