alakxender commited on
Commit
3d6100c
1 Parent(s): 92c7715

initial commit

Browse files
app.py CHANGED
@@ -1,7 +1,63 @@
 
1
  import gradio as gr
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
  demo.launch()
 
1
+ import spaces
2
  import gradio as gr
3
+ from transformers import MT5ForConditionalGeneration, MT5Tokenizer,T5ForConditionalGeneration, T5Tokenizer
4
 
5
+ models = {"finetuned mt5-base":"alakxender/mt5-base-dv-en", "madlad400-3b":"google/madlad400-3b-mt"}
6
+
7
+ def tranlate(text:str,model_name:str):
8
+ if (len(text)>2000):
9
+ raise gr.Error(f"Try smaller text, yours is {len(text)}. try to fit to 2000 chars.")
10
+
11
+ if (model_name is None):
12
+ raise gr.Error("huh! not sure what to do without a model. select a model.")
13
+
14
+ if model_name =='finetuned mt5-base':
15
+ return mt5_translate(text,model_name)
16
+ else:
17
+ return t5_tranlaste(text,model_name)
18
+
19
+ @spaces.GPU(duration=120)
20
+ def t5_tranlaste(text:str,model_name:str):
21
+
22
+ model = T5ForConditionalGeneration.from_pretrained(models[model_name], device_map="auto")
23
+ tokenizer = T5Tokenizer.from_pretrained(models[model_name])
24
+
25
+ text = f"<2en> {text}"
26
+ input_ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
27
+ outputs = model.generate(input_ids=input_ids, max_new_tokens=1024*2)
28
+
29
+ translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
30
+
31
+ return translated_text
32
+
33
+ def mt5_translate(text:str, model_name:str):
34
+
35
+ model = MT5ForConditionalGeneration.from_pretrained(models[model_name])
36
+ tokenizer = MT5Tokenizer.from_pretrained(models[model_name])
37
+ inputs = tokenizer(text, return_tensors="pt")
38
+ result = model.generate(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], max_new_tokens=1024*2)
39
+ translated_text = tokenizer.decode(result[0], skip_special_tokens=True)
40
+ return translated_text
41
+
42
+ css = """
43
+ .textbox1 textarea {
44
+ font-size: 18px !important;
45
+ font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma' !important;
46
+ line-height: 1.8 !important;
47
+ }
48
+ """
49
+
50
+ demo = gr.Interface(
51
+ fn=tranlate,
52
+ inputs= [
53
+ gr.Textbox(lines=5, label="Enter Dhivehi Text", rtl=True, elem_classes="textbox1"),
54
+ gr.Dropdown(choices=list(models.keys()), label="Select a model", value="finetuned mt5-base"),
55
+ ],
56
+ css=css,
57
+ outputs=gr.Textbox(label="English Translation"),
58
+ title="Dhivehi to English Translation",
59
+ description="Translate Dhivehi text to English",
60
+ examples=[["މާލޭގައި ފެންބޮޑުވާ މަގުތައް މަރާމާތު ކުރަން ފަށައިފި"]]
61
+ )
62
 
 
63
  demo.launch()
gradio_cached_examples/16/indices.csv ADDED
@@ -0,0 +1 @@
 
 
1
+ 0
gradio_cached_examples/16/log.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ English Translation,flag,username,timestamp
2
+ flooding roads in the city,,,2024-06-20 17:50:47.570666
requirements.txt ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.31.0
2
+ aiofiles==23.2.1
3
+ aiohttp==3.9.5
4
+ aiosignal==1.3.1
5
+ altair==5.3.0
6
+ annotated-types==0.7.0
7
+ anyio==4.4.0
8
+ asttokens==2.4.1
9
+ async-timeout==4.0.3
10
+ attrs==23.2.0
11
+ Authlib==1.3.1
12
+ certifi==2024.6.2
13
+ cffi==1.16.0
14
+ charset-normalizer==3.3.2
15
+ click==8.0.4
16
+ contourpy==1.2.1
17
+ cryptography==42.0.8
18
+ cycler==0.12.1
19
+ datasets==2.19.2
20
+ decorator==5.1.1
21
+ dill==0.3.8
22
+ dnspython==2.6.1
23
+ email_validator==2.1.1
24
+ exceptiongroup==1.2.1
25
+ executing==2.0.1
26
+ fastapi==0.111.0
27
+ fastapi-cli==0.0.4
28
+ ffmpy==0.3.2
29
+ filelock==3.14.0
30
+ fonttools==4.53.0
31
+ frozenlist==1.4.1
32
+ fsspec==2024.3.1
33
+ gradio==4.36.1
34
+ gradio_client==1.0.1
35
+ h11==0.14.0
36
+ hf_transfer==0.1.6
37
+ httpcore==1.0.5
38
+ httptools==0.6.1
39
+ httpx==0.27.0
40
+ huggingface-hub==0.23.3
41
+ idna==3.7
42
+ importlib_resources==6.4.0
43
+ ipython==8.25.0
44
+ itsdangerous==2.2.0
45
+ jedi==0.19.1
46
+ Jinja2==3.1.4
47
+ jsonschema==4.22.0
48
+ jsonschema-specifications==2023.12.1
49
+ kiwisolver==1.4.5
50
+ markdown-it-py==3.0.0
51
+ MarkupSafe==2.1.5
52
+ matplotlib==3.9.0
53
+ matplotlib-inline==0.1.7
54
+ mdurl==0.1.2
55
+ mpmath==1.3.0
56
+ multidict==6.0.5
57
+ multiprocess==0.70.16
58
+ networkx==3.3
59
+ numpy==1.26.4
60
+ nvidia-cublas-cu12==12.1.3.1
61
+ nvidia-cuda-cupti-cu12==12.1.105
62
+ nvidia-cuda-nvrtc-cu12==12.1.105
63
+ nvidia-cuda-runtime-cu12==12.1.105
64
+ nvidia-cudnn-cu12==8.9.2.26
65
+ nvidia-cufft-cu12==11.0.2.54
66
+ nvidia-curand-cu12==10.3.2.106
67
+ nvidia-cusolver-cu12==11.4.5.107
68
+ nvidia-cusparse-cu12==12.1.0.106
69
+ nvidia-nccl-cu12==2.19.3
70
+ nvidia-nvjitlink-cu12==12.5.40
71
+ nvidia-nvtx-cu12==12.1.105
72
+ orjson==3.10.3
73
+ packaging==24.0
74
+ pandas==2.2.2
75
+ parso==0.8.4
76
+ pexpect==4.9.0
77
+ pillow==10.3.0
78
+ prompt_toolkit==3.0.47
79
+ protobuf==3.20.3
80
+ psutil==5.9.8
81
+ ptyprocess==0.7.0
82
+ pure-eval==0.2.2
83
+ pyarrow==16.1.0
84
+ pyarrow-hotfix==0.6
85
+ pycparser==2.22
86
+ pydantic==2.7.3
87
+ pydantic_core==2.18.4
88
+ pydub==0.25.1
89
+ Pygments==2.18.0
90
+ pyparsing==3.1.2
91
+ python-dateutil==2.9.0.post0
92
+ python-dotenv==1.0.1
93
+ python-multipart==0.0.9
94
+ pytz==2024.1
95
+ PyYAML==6.0.1
96
+ referencing==0.35.1
97
+ regex==2024.5.15
98
+ requests==2.32.3
99
+ rich==13.7.1
100
+ rpds-py==0.18.1
101
+ ruff==0.4.8
102
+ safetensors==0.4.3
103
+ semantic-version==2.10.0
104
+ sentencepiece==0.2.0
105
+ shellingham==1.5.4
106
+ six==1.16.0
107
+ sniffio==1.3.1
108
+ spaces==0.28.3
109
+ stack-data==0.6.3
110
+ starlette==0.37.2
111
+ sympy==1.12.1
112
+ tokenizers==0.19.1
113
+ tomlkit==0.12.0
114
+ toolz==0.12.1
115
+ torch==2.2.0
116
+ tqdm==4.66.4
117
+ traitlets==5.14.3
118
+ transformers==4.41.2
119
+ triton==2.2.0
120
+ typer==0.12.3
121
+ typing_extensions==4.12.1
122
+ tzdata==2024.1
123
+ ujson==5.10.0
124
+ urllib3==2.2.1
125
+ uvicorn==0.30.1
126
+ uvloop==0.19.0
127
+ watchfiles==0.22.0
128
+ wcwidth==0.2.13
129
+ websockets==11.0.3
130
+ xxhash==3.4.1
131
+ yarl==1.9.4