File size: 2,221 Bytes
1bba5ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import fastai
from fastai.vision.all import *
import timm
from PIL import Image
from pathlib import Path
from os import path
from tqdm.auto import tqdm
from urllib.error import HTTPError, URLError

def search_images_ddg(term, max_images=200):
    """Search for `term` with DuckDuckGo and return unique URLs of about `max_images` images.

    The DuckDuckGo landing page is fetched first to extract the `vqd` token,
    which is then sent along with every request to the `i.js` image endpoint.
    Result pages are followed through their `next` link until enough URLs are
    collected, the results run out, or the endpoint keeps failing.

    Args:
        term: Search query string.
        max_images: Upper bound on the number of URLs returned; must be below 1000.

    Returns:
        An `L` with at most `max_images` distinct image URLs. It may hold fewer
        (or none) when the search runs dry or requests keep failing.

    Raises:
        AssertionError: If `max_images` is not below 1000, or the `vqd` token
            cannot be found in the landing page.
    """
    assert max_images<1000, "max_images must be below 1000"
    url = 'https://duckduckgo.com/'
    res = urlread(url,data={'q':term})
    searchObj = re.search(r'vqd=([\d-]+)\&', res)
    assert searchObj, "could not find the `vqd` token in the DuckDuckGo response"
    requestUrl = url + 'i.js'
    params = dict(l='us-en', o='json', q=term, vqd=searchObj.group(1), f=',,,', p='1', v7exp='a')
    headers = dict(referer='https://duckduckgo.com/')
    urls = set()
    # Give up after this many consecutive network errors instead of retrying forever.
    max_consecutive_failures = 5
    failures = 0
    while len(urls)<max_images:
        try:
            res = urlread(requestUrl, data=params, headers=headers)
        except (URLError,HTTPError):
            failures += 1
            if failures >= max_consecutive_failures: break
            time.sleep(1)
            continue
        failures = 0
        data = json.loads(res) if res else {}
        # An empty response or a page without results contributes nothing.
        urls.update(L(data.get('results', [])).itemgot('image'))
        # The last page carries no `next` link: stop instead of raising KeyError.
        if 'next' not in data: break
        requestUrl = url + data['next']
        # Pause between consecutive page requests.
        time.sleep(1)
    return L(urls)[:max_images]

# Search terms; each one also becomes the class label through its folder name.
tool_names = "resistor", "bipolar transistor", "mosfet", "capacitor", "inductor", "wire", "led", "diode", "thermistor", "switch", "battery", "hammer", "screwdriver", "scissors", "wrench", "mallet", "axe"
# Root folder of the dataset: one sub-folder per entry of `tool_names`.
path = Path("data", "tools")

# Download the images only when the dataset folder does not exist yet.
if not path.exists():
    path.mkdir(parents=True)
    for o in tqdm(tool_names):
        dest = (path/o)
        dest.mkdir(exist_ok=True)
        results = search_images_ddg(o, max_images=20)
        download_images(dest, urls=results, n_workers=2)

# Delete every file that `verify_images` reports, so only the remaining images are used.
fns = get_image_files(path)
failed = verify_images(fns)
failed.map(Path.unlink)

# Base data pipeline: images labelled by their parent folder name,
# split with a fixed seed, resized to 224.
data_config = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(seed=42),
    get_y=parent_label,
    item_tfms=Resize(224)
)

# Training pipeline: same block with random resized crops and batch augmentations.
# The dataloaders are built once, from this augmented variant.
connectors = data_config.new(item_tfms=RandomResizedCrop(224, min_scale=0.5), batch_tfms=aug_transforms())
dls = connectors.dataloaders(path)

# Fine-tune the timm ConvNeXt model named below: 1 frozen epoch, then 7 epochs.
learn = vision_learner(dls, 'convnext_small.fb_in22k_ft_in1k', metrics=error_rate)
learn.fine_tune(7, freeze_epochs=1)
# Point the learner at the current directory before exporting it.
learn.path = Path('.')
learn.export()