import os
from huggingface_hub import snapshot_download, delete_repo, metadata_update
import uuid
import json
import yaml
import subprocess

# The Hub token and the id of the dataset repo to train on come from the environment.
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_DATASET = os.environ.get("DATA_PATH")


def download_dataset(hf_dataset_path: str):
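    # Download a snapshot of the dataset repo into a unique /tmp directory and return its path.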
    random_id = str(uuid.uuid4())
    snapshot_download(
        repo_id=hf_dataset_path,
        token=HF_TOKEN,
        local_dir=f"/tmp/{random_id}",
        repo_type="dataset",
    )
    return f"/tmp/{random_id}"


def process_dataset(dataset_dir: str):
    # The dataset dir contains images, a config.yaml, and an optional metadata.jsonl
    # whose entries have the fields: file_name, prompt.
    # For each entry, write a .txt caption file (same basename as the image) containing
    # the prompt, then remove metadata.jsonl and return the path to the processed dataset.
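    #
    # Example metadata.jsonl line (illustrative values only):
    #   {"file_name": "img_0001.png", "prompt": "photo of a red vintage car"}
    # which yields img_0001.txt containing:
    #   photo of a red vintage car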

    # check if config.yaml exists
    if not os.path.exists(os.path.join(dataset_dir, "config.yaml")):
        raise ValueError("config.yaml does not exist")

    # check if metadata.jsonl exists
    if os.path.exists(os.path.join(dataset_dir, "metadata.jsonl")):
        metadata = []
        with open(os.path.join(dataset_dir, "metadata.jsonl"), "r") as f:
            for line in f:
                if len(line.strip()) > 0:
                    metadata.append(json.loads(line))
        for item in metadata:
            txt_path = os.path.join(dataset_dir, item["file_name"])
            txt_path = txt_path.rsplit(".", 1)[0] + ".txt"
            with open(txt_path, "w") as f:
                f.write(item["prompt"])

        # remove metadata.jsonl
        os.remove(os.path.join(dataset_dir, "metadata.jsonl"))

    with open(os.path.join(dataset_dir, "config.yaml"), "r") as f:
        config = yaml.safe_load(f)
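    # The parts of the ai-toolkit config touched here look roughly like this
    # (illustrative sketch, not the full schema):
    #
    #   config:
    #     process:
    #       - datasets:
    #           - folder_path: /path/to/local/dataset
    #         save:
    #           hf_repo_id: username/my-flux-lora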

    # point the training config's dataset folder_path at the locally downloaded data
    config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_dir

    with open(os.path.join(dataset_dir, "config.yaml"), "w") as f:
        yaml.dump(config, f)

    return dataset_dir


def run_training(hf_dataset_path: str):

    dataset_dir = download_dataset(hf_dataset_path)
    dataset_dir = process_dataset(dataset_dir)

    # run training
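    # Clone ai-toolkit at a pinned commit (with submodules) so the training code is reproducible.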
    commands = "git clone https://github.com/ostris/ai-toolkit.git ai-toolkit && cd ai-toolkit && git checkout bc693488eb3cf48ded8bc2af845059d80f4cf7d0 && git submodule update --init --recursive"
    subprocess.run(commands, shell=True, check=True)

    # Launch training asynchronously from inside the ai-toolkit checkout;
    # the caller waits on the returned process handle.
    commands = f"python run.py {os.path.join(dataset_dir, 'config.yaml')}"
    process = subprocess.Popen(commands, shell=True, cwd="ai-toolkit", env=os.environ)

    return process, dataset_dir


if __name__ == "__main__":
    process, dataset_dir = run_training(HF_DATASET)
    process.wait()  # Wait for the training process to finish

    # ai-toolkit is expected to push the trained LoRA to the model repo named in the
    # config's save.hf_repo_id; that repo is tagged below.
    with open(os.path.join(dataset_dir, "config.yaml"), "r") as f:
        config = yaml.safe_load(f)
    repo_id = config["config"]["process"][0]["save"]["hf_repo_id"]

    metadata = {
        "tags": [
            "autotrain",
            "spacerunner",
            "text-to-image",
            "flux",
            "lora",
            "diffusers",
            "template:sd-lora",
        ]
    }
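    # Write the tags into the model card metadata of the pushed repo.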
    metadata_update(repo_id, metadata, token=HF_TOKEN, repo_type="model", overwrite=True)
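    # The uploaded dataset repo was only a staging area; remove it once training is done.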
    delete_repo(HF_DATASET, token=HF_TOKEN, repo_type="dataset", missing_ok=True)