Tuchuanhuhuhu committed on
Commit
75dddd5
1 Parent(s): 2342c7b

Add MOSS support

modules/models/MOSS.py ADDED
@@ -0,0 +1,33 @@
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+import time
+import numpy as np
+from torch.nn import functional as F
+import os
+from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+from transformers import MossForCausalLM, MossConfig
+
+from .base_model import BaseLLMModel
+
+MOSS_MODEL = None
+MOSS_TOKENIZER = None
+
+class MOSS_Client(BaseLLMModel):
+    def __init__(self, model_name) -> None:
+        super().__init__(model_name=model_name)
+        global MOSS_MODEL, MOSS_TOKENIZER
+        config = MossConfig.from_pretrained("fnlp/moss-16B-sft")
+        print("MOSS Model Parallelism Devices: ", torch.cuda.device_count())
+        with init_empty_weights():
+            raw_model = MossForCausalLM._from_config(config, torch_dtype=torch.float16)
+        raw_model.tie_weights()
+        MOSS_MODEL = load_checkpoint_and_dispatch(
+            raw_model,
+            "fnlp/moss-16B-sft",
+            device_map="auto",
+            no_split_module_classes=["MossBlock"],
+            dtype=torch.float16
+        )
+
+if __name__ == "__main__":
+    model = MOSS_Client("MOSS")
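
The constructor above uses the standard accelerate big-model recipe: the module tree is first built with empty (meta) weights, the shared embeddings are retied, and load_checkpoint_and_dispatch() then places the fp16 shards across the visible devices (a 16B-parameter model in fp16 is roughly 32 GB of weights, more than a single consumer GPU holds). Note that load_checkpoint_and_dispatch() takes a local checkpoint path, so "fnlp/moss-16B-sft" is presumably expected to resolve to a directory on disk. The sketch below restates the pattern in isolation; AutoConfig/AutoModelForCausalLM, CHECKPOINT_DIR, and the block class name are illustrative assumptions, not part of this commit.

# Minimal sketch of the accelerate "empty init + dispatch" recipe used in MOSS_Client.__init__.
# CHECKPOINT_DIR is a hypothetical local snapshot of a causal-LM checkpoint.
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

CHECKPOINT_DIR = "/path/to/local/checkpoint"  # hypothetical path, not from the commit

config = AutoConfig.from_pretrained(CHECKPOINT_DIR)

# 1. Instantiate the architecture on the "meta" device: no weight memory is allocated yet.
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)
model.tie_weights()  # retie shared embeddings before real tensors are attached

# 2. Load the checkpoint shards and assign each layer to a device.
#    device_map="auto" spreads layers over the available GPUs (then CPU/disk),
#    and no_split_module_classes keeps each transformer block on a single device.
model = load_checkpoint_and_dispatch(
    model,
    CHECKPOINT_DIR,
    device_map="auto",
    no_split_module_classes=["GPTJBlock"],  # assumption: the block class depends on the architecture
    dtype=torch.float16,
)
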
modules/models/base_model.py CHANGED
@@ -31,6 +31,7 @@ class ModelType(Enum):
     LLaMA = 2
     XMChat = 3
     StableLM = 4
+    MOSS = 5
 
     @classmethod
     def get_type(cls, model_name: str):
@@ -46,6 +47,8 @@ class ModelType(Enum):
             model_type = ModelType.XMChat
         elif "stablelm" in model_name_lower:
             model_type = ModelType.StableLM
+        elif "moss" in model_name_lower:
+            model_type = ModelType.MOSS
         else:
             model_type = ModelType.Unknown
         return model_type
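
The dispatch in get_type() is a case-insensitive substring match, so any configured model name containing "moss" now resolves to ModelType.MOSS. A small illustrative check (the import path mirrors the file location and is an assumption; this snippet is not part of the commit):

# Illustration of the substring dispatch added above.
from modules.models.base_model import ModelType  # assumed import path

assert ModelType.get_type("MOSS") is ModelType.MOSS
assert ModelType.get_type("moss-16B-sft") is ModelType.MOSS
assert ModelType.get_type("StableLM") is ModelType.StableLM
assert ModelType.get_type("definitely-not-listed") is ModelType.Unknown
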
modules/presets.py CHANGED
@@ -74,11 +74,12 @@ LOCAL_MODELS = [
     "chatglm-6b",
     "chatglm-6b-int4",
     "chatglm-6b-int4-qe",
+    "StableLM",
+    "MOSS",
     "llama-7b-hf",
     "llama-13b-hf",
     "llama-30b-hf",
     "llama-65b-hf",
-    "StableLM"
 ]
 
 if os.environ.get('HIDE_LOCAL_MODELS', 'false') == 'true':
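
Adding "MOSS" to LOCAL_MODELS is what makes the new client selectable alongside the other locally hosted models; the existing "StableLM" entry is moved next to it. The HIDE_LOCAL_MODELS check at the end of the hunk gates the whole list, as in this hedged sketch (the body of the if lies outside the diff, so clearing the list is an assumption about its effect):

# Hedged sketch of the HIDE_LOCAL_MODELS guard visible at the end of the hunk.
import os

LOCAL_MODELS = ["chatglm-6b", "StableLM", "MOSS", "llama-7b-hf"]  # abridged

if os.environ.get("HIDE_LOCAL_MODELS", "false") == "true":
    LOCAL_MODELS = []  # assumption: the UI then offers no local models
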