Spaces:
Running
Running
style: use isort
Browse files- .github/workflows/black.yml +0 -14
- app/gradio/app_gradio.py +6 -14
- app/streamlit/app.py +2 -1
- app/streamlit/backend.py +3 -2
- dalle_mini/data.py +4 -2
- dalle_mini/model.py +6 -8
- dalle_mini/text.py +5 -3
- tools/train/train.py +8 -14
.github/workflows/black.yml
DELETED
@@ -1,14 +0,0 @@
|
|
1 |
-
name: Lint
|
2 |
-
|
3 |
-
on:
|
4 |
-
push:
|
5 |
-
branches: [main]
|
6 |
-
pull_request:
|
7 |
-
branches: [main]
|
8 |
-
|
9 |
-
jobs:
|
10 |
-
lint:
|
11 |
-
runs-on: ubuntu-latest
|
12 |
-
steps:
|
13 |
-
- uses: actions/checkout@v2
|
14 |
-
- uses: psf/black@stable
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/gradio/app_gradio.py
CHANGED
@@ -7,26 +7,18 @@
|
|
7 |
|
8 |
import random
|
9 |
|
|
|
10 |
import jax
|
11 |
-
import
|
12 |
-
from flax.training.common_utils import shard
|
13 |
from flax.jax_utils import replicate
|
14 |
-
|
15 |
-
from transformers import BartTokenizer
|
16 |
-
|
17 |
from PIL import Image, ImageDraw, ImageFont
|
18 |
-
import numpy as np
|
19 |
-
|
20 |
-
from vqgan_jax.modeling_flax_vqgan import VQModel
|
21 |
-
from dalle_mini.model import CustomFlaxBartForConditionalGeneration
|
22 |
|
23 |
# ## CLIP Scoring
|
24 |
-
from transformers import CLIPProcessor, FlaxCLIPModel
|
25 |
-
|
26 |
-
import gradio as gr
|
27 |
-
|
28 |
-
from PIL import Image, ImageDraw, ImageFont
|
29 |
|
|
|
30 |
|
31 |
DALLE_REPO = "flax-community/dalle-mini"
|
32 |
DALLE_COMMIT_ID = "4d34126d0df8bc4a692ae933e3b902a1fa8b6114"
|
|
|
7 |
|
8 |
import random
|
9 |
|
10 |
+
import gradio as gr
|
11 |
import jax
|
12 |
+
import numpy as np
|
|
|
13 |
from flax.jax_utils import replicate
|
14 |
+
from flax.training.common_utils import shard
|
|
|
|
|
15 |
from PIL import Image, ImageDraw, ImageFont
|
|
|
|
|
|
|
|
|
16 |
|
17 |
# ## CLIP Scoring
|
18 |
+
from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel
|
19 |
+
from vqgan_jax.modeling_flax_vqgan import VQModel
|
|
|
|
|
|
|
20 |
|
21 |
+
from dalle_mini.model import CustomFlaxBartForConditionalGeneration
|
22 |
|
23 |
DALLE_REPO = "flax-community/dalle-mini"
|
24 |
DALLE_COMMIT_ID = "4d34126d0df8bc4a692ae933e3b902a1fa8b6114"
|
app/streamlit/app.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
#!/usr/bin/env python
|
2 |
# coding: utf-8
|
3 |
|
4 |
-
from .backend import ServiceError, get_images_from_backend
|
5 |
import streamlit as st
|
6 |
|
|
|
|
|
7 |
st.sidebar.markdown(
|
8 |
"""
|
9 |
<style>
|
|
|
1 |
#!/usr/bin/env python
|
2 |
# coding: utf-8
|
3 |
|
|
|
4 |
import streamlit as st
|
5 |
|
6 |
+
from .backend import ServiceError, get_images_from_backend
|
7 |
+
|
8 |
st.sidebar.markdown(
|
9 |
"""
|
10 |
<style>
|
app/streamlit/backend.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
-
import requests
|
2 |
-
from io import BytesIO
|
3 |
import base64
|
|
|
|
|
|
|
4 |
from PIL import Image
|
5 |
|
6 |
|
|
|
|
|
|
|
1 |
import base64
|
2 |
+
from io import BytesIO
|
3 |
+
|
4 |
+
import requests
|
5 |
from PIL import Image
|
6 |
|
7 |
|
dalle_mini/data.py
CHANGED
@@ -1,10 +1,12 @@
|
|
1 |
from dataclasses import dataclass, field
|
2 |
-
from datasets import load_dataset, Dataset
|
3 |
from functools import partial
|
4 |
-
|
5 |
import jax
|
6 |
import jax.numpy as jnp
|
|
|
|
|
7 |
from flax.training.common_utils import shard
|
|
|
8 |
from .text import TextNormalizer
|
9 |
|
10 |
|
|
|
1 |
from dataclasses import dataclass, field
|
|
|
2 |
from functools import partial
|
3 |
+
|
4 |
import jax
|
5 |
import jax.numpy as jnp
|
6 |
+
import numpy as np
|
7 |
+
from datasets import Dataset, load_dataset
|
8 |
from flax.training.common_utils import shard
|
9 |
+
|
10 |
from .text import TextNormalizer
|
11 |
|
12 |
|
dalle_mini/model.py
CHANGED
@@ -1,16 +1,14 @@
|
|
1 |
-
import jax
|
2 |
import flax.linen as nn
|
3 |
-
|
|
|
4 |
from transformers.models.bart.modeling_flax_bart import (
|
5 |
-
FlaxBartModule,
|
6 |
-
FlaxBartForConditionalGenerationModule,
|
7 |
-
FlaxBartForConditionalGeneration,
|
8 |
-
FlaxBartEncoder,
|
9 |
FlaxBartDecoder,
|
|
|
|
|
|
|
|
|
10 |
)
|
11 |
|
12 |
-
from transformers import BartConfig
|
13 |
-
|
14 |
|
15 |
class CustomFlaxBartModule(FlaxBartModule):
|
16 |
def setup(self):
|
|
|
|
|
1 |
import flax.linen as nn
|
2 |
+
import jax
|
3 |
+
from transformers import BartConfig
|
4 |
from transformers.models.bart.modeling_flax_bart import (
|
|
|
|
|
|
|
|
|
5 |
FlaxBartDecoder,
|
6 |
+
FlaxBartEncoder,
|
7 |
+
FlaxBartForConditionalGeneration,
|
8 |
+
FlaxBartForConditionalGenerationModule,
|
9 |
+
FlaxBartModule,
|
10 |
)
|
11 |
|
|
|
|
|
12 |
|
13 |
class CustomFlaxBartModule(FlaxBartModule):
|
14 |
def setup(self):
|
dalle_mini/text.py
CHANGED
@@ -2,13 +2,15 @@
|
|
2 |
Utilities for processing text.
|
3 |
"""
|
4 |
|
|
|
|
|
|
|
|
|
5 |
from pathlib import Path
|
6 |
-
from unidecode import unidecode
|
7 |
|
8 |
-
import re, math, random, html
|
9 |
import ftfy
|
10 |
-
|
11 |
from huggingface_hub import hf_hub_download
|
|
|
12 |
|
13 |
# based on wiki word occurence
|
14 |
person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
|
|
|
2 |
Utilities for processing text.
|
3 |
"""
|
4 |
|
5 |
+
import html
|
6 |
+
import math
|
7 |
+
import random
|
8 |
+
import re
|
9 |
from pathlib import Path
|
|
|
10 |
|
|
|
11 |
import ftfy
|
|
|
12 |
from huggingface_hub import hf_hub_download
|
13 |
+
from unidecode import unidecode
|
14 |
|
15 |
# based on wiki word occurence
|
16 |
person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
|
tools/train/train.py
CHANGED
@@ -18,37 +18,31 @@ Fine-tuning the library models for seq2seq, text to image.
|
|
18 |
Script adapted from run_summarization_flax.py
|
19 |
"""
|
20 |
|
21 |
-
import
|
22 |
import logging
|
|
|
23 |
import sys
|
24 |
import time
|
25 |
-
from dataclasses import dataclass, field
|
26 |
from pathlib import Path
|
27 |
from typing import Callable, Optional
|
28 |
-
import json
|
29 |
|
30 |
import datasets
|
31 |
-
from datasets import Dataset
|
32 |
-
from tqdm import tqdm
|
33 |
-
from dataclasses import asdict
|
34 |
-
|
35 |
import jax
|
36 |
import jax.numpy as jnp
|
37 |
import optax
|
38 |
import transformers
|
|
|
|
|
39 |
from flax import jax_utils, traverse_util
|
40 |
-
from flax.serialization import from_bytes, to_bytes
|
41 |
from flax.jax_utils import unreplicate
|
|
|
42 |
from flax.training import train_state
|
43 |
from flax.training.common_utils import get_metrics, onehot, shard_prng_key
|
44 |
-
from
|
45 |
-
|
46 |
-
HfArgumentParser,
|
47 |
-
)
|
48 |
from transformers.models.bart.modeling_flax_bart import BartConfig
|
49 |
|
50 |
-
import wandb
|
51 |
-
|
52 |
from dalle_mini.data import Dataset
|
53 |
from dalle_mini.model import CustomFlaxBartForConditionalGeneration
|
54 |
|
|
|
18 |
Script adapted from run_summarization_flax.py
|
19 |
"""
|
20 |
|
21 |
+
import json
|
22 |
import logging
|
23 |
+
import os
|
24 |
import sys
|
25 |
import time
|
26 |
+
from dataclasses import asdict, dataclass, field
|
27 |
from pathlib import Path
|
28 |
from typing import Callable, Optional
|
|
|
29 |
|
30 |
import datasets
|
|
|
|
|
|
|
|
|
31 |
import jax
|
32 |
import jax.numpy as jnp
|
33 |
import optax
|
34 |
import transformers
|
35 |
+
import wandb
|
36 |
+
from datasets import Dataset
|
37 |
from flax import jax_utils, traverse_util
|
|
|
38 |
from flax.jax_utils import unreplicate
|
39 |
+
from flax.serialization import from_bytes, to_bytes
|
40 |
from flax.training import train_state
|
41 |
from flax.training.common_utils import get_metrics, onehot, shard_prng_key
|
42 |
+
from tqdm import tqdm
|
43 |
+
from transformers import AutoTokenizer, HfArgumentParser
|
|
|
|
|
44 |
from transformers.models.bart.modeling_flax_bart import BartConfig
|
45 |
|
|
|
|
|
46 |
from dalle_mini.data import Dataset
|
47 |
from dalle_mini.model import CustomFlaxBartForConditionalGeneration
|
48 |
|