edesaras commited on
Commit
1bba5ba
1 Parent(s): 61d4c6d

First Commit

Browse files
Files changed (4) hide show
  1. app.py +26 -0
  2. model.pkl +3 -0
  3. requirements.txt +6 -0
  4. train.py +63 -0
app.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ import gradio as gr
3
+ from fastai.vision.all import *
4
+
5
+ # to solve the mismatch between windows (local) system and huggingspace (linux)
6
+ plt = platform.system()
7
+ if plt == 'Linux': pathlib.WindowsPath = pathlib.PosixPath
8
+
9
+ learn = load_learner('export.pkl')
10
+ labels = learn.dls.vocab
11
+
12
+ def predict(img):
13
+ img = PILImage.create(img)
14
+ pred,pred_idx,probs = learn.predict(img)
15
+ return {labels[i]: float(probs[i]) for i in range(len(labels))}
16
+
17
+ iface = gr.Interface(
18
+ fn=predict,
19
+ inputs=gr.components.Image(shape=(512, 512)),
20
+ outputs=gr.components.Label(num_top_classes=3),
21
+ description="Pet Classifier",
22
+ article="<p style='text-align: center'><a href='google.com' target='_blank'>Blog post</a></p>",
23
+ live=True,
24
+ )
25
+
26
+ iface.launch(enable_queue=True)
model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a700e65e0aaa01d0b1882387185378a6e2f614443e1d587b9a4a8a0ab753e273
3
+ size 201292933
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastai
2
+ torch
3
+ gradio
4
+ numpy
5
+ pandas
6
+ timm
train.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fastai
2
+ from fastai.vision.all import *
3
+ import timm
4
+ from PIL import Image
5
+ from pathlib import Path
6
+ from os import path
7
+ from tqdm.auto import tqdm
8
+ from urllib.error import HTTPError, URLError
9
+
10
+ def search_images_ddg(term, max_images=200):
11
+ "Search for `term` with DuckDuckGo and return a unique urls of about `max_images` images"
12
+ assert max_images<1000
13
+ url = 'https://duckduckgo.com/'
14
+ res = urlread(url,data={'q':term})
15
+ searchObj = re.search(r'vqd=([\d-]+)\&', res)
16
+ assert searchObj
17
+ requestUrl = url + 'i.js'
18
+ params = dict(l='us-en', o='json', q=term, vqd=searchObj.group(1), f=',,,', p='1', v7exp='a')
19
+ urls,data = set(),{'next':1}
20
+ headers = dict(referer='https://duckduckgo.com/')
21
+ while len(urls)<max_images and 'next' in data:
22
+ try:
23
+ res = urlread(requestUrl, data=params, headers=headers)
24
+ data = json.loads(res) if res else {}
25
+ urls.update(L(data['results']).itemgot('image'))
26
+ requestUrl = url + data['next']
27
+ except (URLError,HTTPError): pass
28
+ time.sleep(1)
29
+ return L(urls)[:max_images]
30
+
31
+ tool_names = "resistor", "bipolar transistor", "mosfet", "capacitor", "inductor", "wire", "led", "diode", "thermistor", "switch", "battery", "hammer", "screwdriver", "scissors", "wrench", "mallet", "axe"
32
+ path = Path("data", "tools")
33
+ path.absolute()
34
+
35
+
36
+ if not path.exists():
37
+ path.mkdir(parents=True)
38
+ for o in tqdm(tool_names):
39
+ dest = (path/o)
40
+ dest.mkdir(exist_ok=True)
41
+ results = search_images_ddg(f'{o}', max_images=20)
42
+ download_images(dest, urls=results, n_workers=2)
43
+
44
+ fns = get_image_files(path)
45
+ failed = verify_images(fns)
46
+ failed.map(Path.unlink)
47
+
48
+ data_config = DataBlock(
49
+ blocks=(ImageBlock, CategoryBlock),
50
+ get_items=get_image_files,
51
+ splitter=RandomSplitter(seed=42),
52
+ get_y=parent_label,
53
+ item_tfms=Resize(224)
54
+ )
55
+
56
+ dls = data_config.dataloaders(path)
57
+ connectors = data_config.new(item_tfms=RandomResizedCrop(224, min_scale=0.5), batch_tfms=aug_transforms())
58
+ dls = connectors.dataloaders(path)
59
+
60
+ learn = vision_learner(dls, 'convnext_small.fb_in22k_ft_in1k', metrics=error_rate)
61
+ learn.fine_tune(7, freeze_epochs=1)
62
+ learn.path = Path('.')
63
+ learn.export()