# tool_classifier/train.py
# Author: edesaras -- commit 1bba5ba ("First Commit")
import fastai
from fastai.vision.all import *
import timm
from PIL import Image
from pathlib import Path
from os import path
from tqdm.auto import tqdm
from urllib.error import HTTPError, URLError
def search_images_ddg(term, max_images=200):
    """Search DuckDuckGo for `term` and return about `max_images` unique image URLs.

    Args:
        term: Query string sent to DuckDuckGo image search.
        max_images: Upper bound on the number of URLs returned; must be < 1000.

    Returns:
        An `L` of unique image URLs, at most `max_images` long. Fewer are
        returned when the result pages run out or when requests keep failing.

    Raises:
        AssertionError: If `max_images` is not below 1000, or if no `vqd`
            token can be found in the landing-page response.
    """
    assert max_images < 1000, "max_images must be below 1000"
    url = 'https://duckduckgo.com/'
    # The `vqd` value scraped from the landing page is sent with every image request.
    res = urlread(url, data={'q': term})
    searchObj = re.search(r'vqd=([\d-]+)\&', res)
    assert searchObj, "could not find the vqd token in the DuckDuckGo response"
    requestUrl = url + 'i.js'
    params = dict(l='us-en', o='json', q=term, vqd=searchObj.group(1), f=',,,', p='1', v7exp='a')
    # `data` starts with a dummy 'next' entry so the loop body runs at least once.
    urls, data = set(), {'next': 1}
    headers = dict(referer='https://duckduckgo.com/')
    # Give up after this many back-to-back network errors instead of retrying
    # the same request forever.
    max_failures = 5
    failures = 0
    while len(urls) < max_images and 'next' in data:
        try:
            res = urlread(requestUrl, data=params, headers=headers)
        except URLError:  # HTTPError is a subclass of URLError
            failures += 1
            if failures >= max_failures:
                break
        else:
            failures = 0
            data = json.loads(res) if res else {}
            # An empty response or the final page may lack 'results'/'next';
            # treat that as "no more images" rather than raising KeyError.
            urls.update(r['image'] for r in data.get('results', []) if 'image' in r)
            if 'next' in data:
                requestUrl = url + data['next']
        # Pause between requests (and before a retry).
        time.sleep(1)
    return L(urls)[:max_images]
# Class labels: each entry is used both as the search query and as the name of
# the folder the downloaded images are stored in.
tool_names = "resistor", "bipolar transistor", "mosfet", "capacitor", "inductor", "wire", "led", "diode", "thermistor", "switch", "battery", "hammer", "screwdriver", "scissors", "wrench", "mallet", "axe"

# NOTE: this rebinding shadows the `path` name imported from `os` above.
path = Path("data", "tools")
# Create the dataset root (and any missing parents); a no-op if it already exists.
path.mkdir(parents=True, exist_ok=True)

# Download up to 20 search results per class into data/tools/<class name>/.
for o in tqdm(tool_names):
    dest = path/o
    dest.mkdir(exist_ok=True)
    results = search_images_ddg(o, max_images=20)
    download_images(dest, urls=results, n_workers=2)

# Delete every downloaded file that `verify_images` reports as failed.
fns = get_image_files(path)
failed = verify_images(fns)
failed.map(Path.unlink)
# Base data pipeline: images labelled by their parent folder name, split
# randomly with a fixed seed, each item resized to 224.
data_config = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(seed=42),
    get_y=parent_label,
    item_tfms=Resize(224),
)

# Training uses a copy of the base block with random-resized crops and batch
# augmentations. Dataloaders are built only from this variant; building them
# from `data_config` first would just be overwritten.
connectors = data_config.new(
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
    batch_tfms=aug_transforms(),
)
dls = connectors.dataloaders(path)

# Fine-tune the named pretrained ConvNeXt model: 1 frozen epoch, then 7 more.
learn = vision_learner(dls, 'convnext_small.fb_in22k_ft_in1k', metrics=error_rate)
learn.fine_tune(7, freeze_epochs=1)

# Point the learner at the current directory so the export is written there.
learn.path = Path('.')
learn.export()