# disticleaner / pipeline.py
# Author: Ben Burtenshaw
# Commit: "first commit" (39e6ae5)
# Hub file-viewer residue: raw | history | blame | 2.67 kB
import os
import json
import sys
from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import (
LoadDataFromDicts,
TextGenerationToArgilla,
ExpandColumns,
)
from distilabel.steps.tasks import SelfInstruct
from huggingface_hub import hf_hub_download
def _load_hub_json(repo_id, filename):
    """Download ``filename`` from the dataset repo ``repo_id`` and parse it as JSON."""
    local_path = hf_hub_download(
        repo_id=repo_id, filename=filename, repo_type="dataset"
    )
    # Decode explicitly as UTF-8 instead of relying on the platform locale.
    with open(local_path, "r", encoding="utf-8") as f:
        return json.load(f)


def run(repo_id):
    """Build and run a self-instruct pipeline configured from a Hub dataset repo.

    Two JSON files are downloaded from the dataset repository ``repo_id``:

    * ``pipeline_params.json`` -- supplies ``self_instruct_base_url`` and the
      optional ``self_instruct_num_generations`` (default 2),
      ``self_instruct_temperature`` (default 0.9) and
      ``self_instruct_max_new_tokens`` (default 1024).
    * ``seed_data.json`` -- supplies ``application_instruction``, ``domain``
      (used as the pipeline name) and ``seed_terms`` (one pipeline input row
      is created per term).

    The API key handed to the LLM is read from the ``HF_TOKEN`` environment
    variable; it is passed through as ``None`` when the variable is unset.

    Args:
        repo_id: Identifier of the Hub dataset repository holding both files.

    Raises:
        ValueError: If ``seed_data.json`` has no ``seed_terms`` entry.
    """
    hub_token = os.environ.get("HF_TOKEN")

    # Only the self-instruct settings are consumed here; any domain-expert
    # settings present in the files are not used by this function.
    params = _load_hub_json(repo_id, "pipeline_params.json")
    self_instruct_base_url = params.get("self_instruct_base_url")
    self_instruct_num_generations = params.get("self_instruct_num_generations", 2)
    self_instruct_temperature = params.get("self_instruct_temperature", 0.9)
    self_instruct_max_new_tokens = params.get("self_instruct_max_new_tokens", 1024)

    seed_data = _load_hub_json(repo_id, "seed_data.json")
    application_instruction = seed_data.get("application_instruction")
    domain_name = seed_data.get("domain")
    terms = seed_data.get("seed_terms")

    # Fail early with a clear message rather than a TypeError from iterating
    # over None when the input rows are built below.
    if terms is None:
        raise ValueError(
            f"seed_data.json in dataset repo {repo_id!r} has no 'seed_terms' entry"
        )

    with Pipeline(domain_name) as pipeline:
        load_data = LoadDataFromDicts(
            name="load_data",
            batch_size=64,
            data=[{"input": term} for term in terms],
        )
        self_instruct = SelfInstruct(
            name="self_instruct",
            num_instructions=self_instruct_num_generations,
            input_batch_size=8,
            llm=InferenceEndpointsLLM(
                api_key=hub_token,
                base_url=self_instruct_base_url,
            ),
            application_description=application_instruction,
        )

        # Connect up the pipeline
        load_data.connect(self_instruct)

    # Run the pipeline.
    # NOTE(review): the reviewed copy had lost its indentation; the run call is
    # placed after the `with` block per distilabel's documented usage -- confirm.
    pipeline.run(
        use_cache=False,
        parameters={
            "self_instruct": {
                "llm": {
                    "generation_kwargs": {
                        "max_new_tokens": self_instruct_max_new_tokens,
                        "temperature": self_instruct_temperature,
                    },
                }
            },
        },
    )
if __name__ == "__main__":
    # The first positional argument is the Hub dataset repo id passed to run().
    # Exit with a usage message instead of an IndexError traceback when it is
    # missing; any additional arguments are ignored, as before.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <repo_id>")
    run(sys.argv[1])