---
license: apache-2.0
language:
- ko
tags:
- construction
- interior
- defective
- finished materials
---
This model was trained on the publicly available domain dataset of Hansol Deco Co., Ltd., which was tokenized before training.

Base model: mistralai/Mistral-7B-v0.1

Dataset: Hansol Deco domain dataset

## Training parameters

```
num_train_epochs=3
per_device_train_batch_size=1
gradient_accumulation_steps=4
gradient_checkpointing=True
learning_rate=5e-5
lr_scheduler_type="linear"
max_steps=200
save_strategy="no"
logging_steps=1
output_dir=new_model
optim="paged_adamw_32bit"
warmup_steps=100
fp16=True
```
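
The arithmetic implied by these settings can be sketched in plain Python. This is a minimal illustration, not the training script: it assumes the values are Hugging Face `TrainingArguments` fields (where a positive `max_steps` overrides `num_train_epochs`), that `lr_scheduler_type="linear"` means linear warmup followed by linear decay to zero, and that training ran on a single device. `NUM_DEVICES` and `linear_lr` are hypothetical names introduced here.

```python
# Values copied from the training parameters listed above.
LEARNING_RATE = 5e-5
WARMUP_STEPS = 100
MAX_STEPS = 200
PER_DEVICE_BATCH = 1
GRAD_ACCUM = 4
NUM_DEVICES = 1  # assumption: the card does not state the device count


def linear_lr(step: int) -> float:
    """Learning rate at an optimizer step under linear warmup, then linear decay."""
    if step < WARMUP_STEPS:
        # Warmup: ramp from 0 up to the peak learning rate.
        return LEARNING_RATE * step / WARMUP_STEPS
    # Decay: ramp from the peak back down to 0 at MAX_STEPS.
    remaining = max(0, MAX_STEPS - step)
    return LEARNING_RATE * remaining / (MAX_STEPS - WARMUP_STEPS)


# Sequences consumed per optimizer step, and over the whole run.
effective_batch = PER_DEVICE_BATCH * GRAD_ACCUM * NUM_DEVICES
sequences_seen = effective_batch * MAX_STEPS

print("effective batch size:", effective_batch)
print("sequences seen in total:", sequences_seen)
for step in (0, 50, 100, 150, 200):
    print(step, linear_lr(step))
```

Under these assumptions, half of the 200 steps are spent in warmup, so the peak learning rate of 5e-5 is reached only at step 100 and decays for the remaining 100 steps.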

## Usage example

```
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TextStreamer, GenerationConfig

model_name='sosoai/hansoldeco-mistral-dpo-v1'
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextStreamer(tokenizer)

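def gen(x):
    # Sampling settings; early_stopping only takes effect with beam search.
    generation_config = GenerationConfig(
        temperature=0.1,
        top_p=0.8,
        top_k=100,
        max_new_tokens=256,
        early_stopping=True,
        do_sample=True,
        repetition_penalty=1.2,
    )
    # Wrap the question in the [INST] ... [/INST] instruction format.
    q = f"[INST]{x} [/INST]"
    gened = model.generate(
        **tokenizer(
            q,
            return_tensors='pt',
            return_token_type_ids=False
        ).to(model.device),  # place the inputs on the same device as the model
        generation_config=generation_config,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    result_str = tokenizer.decode(gened[0])

    # If the output contains a "### Response: " tag, keep only the text after it;
    # otherwise the full decoded string, prompt included, is returned.
    start_tag = "\n\n### Response: "
    start_index = result_str.find(start_tag)

    if start_index != -1:
        result_str = result_str[start_index + len(start_tag):].strip()
    return result_str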

# Korean prompt: "What types of finishing defects are there?"
print(gen('마감하자는 어떤 종류가 있나요?'))
```
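
Because the prompt uses the `[INST] ... [/INST]` format, the `### Response: ` tag searched for in `gen` will usually be absent, in which case the returned string still contains the prompt and any special tokens. One possible way to keep only the answer is sketched below. `extract_response` is a hypothetical helper, not part of the original example, and it assumes the decoded text uses `[/INST]` as the prompt terminator and `</s>` as the end-of-sequence token.

```python
def extract_response(decoded: str, inst_close: str = "[/INST]", eos: str = "</s>") -> str:
    """Return the text after the last [/INST] marker, without a trailing EOS token."""
    # rpartition yields the whole string as the tail when the marker is absent.
    text = decoded.rpartition(inst_close)[2]
    stripped = text.rstrip()
    if stripped.endswith(eos):
        stripped = stripped[: -len(eos)]
    return stripped.strip()


# Hypothetical decoded output, for illustration only.
sample = "<s> [INST]What is a finishing defect? [/INST] An example answer.</s>"
print(extract_response(sample))
```

In `gen`, this could be applied to `result_str` in place of the `### Response: ` lookup; passing `skip_special_tokens=True` to `tokenizer.decode` is another way to drop the special tokens.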