license: apache-2.0
base_model: google/vit-base-patch16-224
tags:
  - Image Regression
datasets:
  - tonyassi/clothing-sales-ds
metrics:
  - accuracy
model-index:
  - name: sales-prediction
    results: []

sales-prediction

Image Regression Model

This model was trained with the Image Regression Model Trainer. It takes an image as input and outputs a single float value: the predicted value of the dataset's 'sales' column.

from ImageRegression import predict
predict(repo_id='tonyassi/sales-prediction', image_path='image.jpg')
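Conceptually, an image regression model replaces the classifier of its base model with a head that emits one number. The sketch below illustrates that general technique under stated assumptions and is not the ImageRegression implementation: it applies a hypothetical linear layer (`weights`, `bias`) to a pooled image embedding and scores predictions with mean squared error, the usual loss for regression.

```python
from typing import Sequence


def regression_head(embedding: Sequence[float], weights: Sequence[float], bias: float) -> float:
    """Map a pooled image embedding to one float: y = w . h + b (hypothetical linear head)."""
    if len(embedding) != len(weights):
        raise ValueError("embedding and weights must have the same length")
    return sum(h * w for h, w in zip(embedding, weights)) + bias


def mse_loss(predictions: Sequence[float], targets: Sequence[float]) -> float:
    """Mean squared error between predicted and true values (e.g. sales)."""
    if len(predictions) != len(targets) or not predictions:
        raise ValueError("need equal-length, non-empty sequences")
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(predictions)


# Toy 3-dimensional embedding; a real ViT-Base embedding has 768 dimensions.
h = [0.5, -1.0, 2.0]
w = [2.0, 1.0, 0.5]
pred = regression_head(h, w, bias=0.25)  # 1.0 - 1.0 + 1.0 + 0.25 = 1.25
print(pred, mse_loss([pred], [1.0]))     # 1.25 0.0625
```

During training, the head's parameters (and usually the base model's) are adjusted to lower this loss on the training split.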

Dataset

Dataset: tonyassi/clothing-sales-ds
Value Column: 'sales'
Train/Test Split: 0.2 (20% of the data held out for testing)
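A train/test split of 0.2 holds out 20% of the dataset's rows for evaluation and trains on the remaining 80%. The pure-Python sketch below illustrates the idea; the function name and seed are hypothetical, and the trainer may use a different routine (for example, 🤗 `datasets` provides `Dataset.train_test_split(test_size=...)`).

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def train_test_split(rows: Sequence[T], test_split: float, seed: int = 42) -> Tuple[List[T], List[T]]:
    """Shuffle rows and hold out a `test_split` fraction of them for testing."""
    if not 0.0 < test_split < 1.0:
        raise ValueError("test_split must be between 0 and 1")
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)  # seeded so the split is reproducible
    n_test = round(len(shuffled) * test_split)
    return shuffled[n_test:], shuffled[:n_test]  # (train rows, test rows)


train, test = train_test_split(list(range(100)), test_split=0.2)
print(len(train), len(test))  # 80 20
```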


Training

Base Model: google/vit-base-patch16-224
Epochs: 10
Learning Rate: 0.0001
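The number of epochs determines how many optimizer steps a run takes, and therefore the step number in checkpoint folder names such as `checkpoint-940`. The sketch below assumes the trainer is built on the 🤗 `Trainer`, which names checkpoint folders `checkpoint-<global step>`, and that every batch is one optimizer step (no gradient accumulation). The dataset size and batch size are hypothetical, since this card does not state them.

```python
import math


def total_training_steps(num_train_examples: int, batch_size: int, num_train_epochs: int) -> int:
    """Optimizer steps for a run, assuming one step per batch and a partial final batch."""
    steps_per_epoch = math.ceil(num_train_examples / batch_size)
    return steps_per_epoch * num_train_epochs


def checkpoint_name(step: int) -> str:
    """Folder name in the `checkpoint-<step>` convention used by the 🤗 Trainer."""
    return f"checkpoint-{step}"


# Hypothetical numbers: 1,000 training images with a batch size of 32.
steps = total_training_steps(num_train_examples=1000, batch_size=32, num_train_epochs=10)
print(steps, checkpoint_name(steps))  # 320 checkpoint-320
```

Read in reverse, if `checkpoint-940` is the final checkpoint of a 10-epoch run, the run took 940 / 10 = 94 optimizer steps per epoch under the same assumptions.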


Usage

Download

git clone https://github.com/TonyAssi/ImageRegression.git
cd ImageRegression

Installation

pip install -r requirements.txt

Import

from ImageRegression import train_model, upload_model, predict

Inference (Prediction)

  • repo_id: 🤗 Hub repo ID of the model (e.g. 'tonyassi/sales-prediction')
  • image_path: path to the input image
predict(repo_id='tonyassi/sales-prediction',
        image_path='image.jpg')

The first time this function is called, it downloads the model weights (a safetensors file). Subsequent calls run faster.
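To score many product images at once, you can call the prediction function in a loop and sort the results. The helper below is a hypothetical sketch: `rank_images` and `IMAGE_SUFFIXES` are not part of ImageRegression, and it takes the prediction function as a parameter, assuming `predict` returns the predicted float.

```python
from pathlib import Path
from typing import Callable, List, Tuple

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp"}


def rank_images(image_dir: str, predict_fn: Callable[[str], float]) -> List[Tuple[str, float]]:
    """Predict a value for every image file in `image_dir`, highest prediction first."""
    scores = []
    for path in sorted(Path(image_dir).iterdir()):
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES:
            scores.append((path.name, float(predict_fn(str(path)))))
    return sorted(scores, key=lambda item: item[1], reverse=True)


# Example usage (assumes `predict` returns the predicted float):
# from ImageRegression import predict
# ranking = rank_images(
#     "product_photos",  # hypothetical folder of images
#     lambda p: predict(repo_id='tonyassi/sales-prediction', image_path=p),
# )
```

Passing the prediction function in, rather than importing it inside the helper, keeps the helper testable with a stub and lets it work with any model that maps an image path to a number.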

Train Model

  • dataset_id: 🤗 Hub dataset ID
  • value_column_name: name of the dataset column that holds the values to predict
  • test_split: fraction of the dataset held out for testing
  • output_dir: directory where training checkpoints are saved
  • num_train_epochs: number of training epochs
  • learning_rate: learning rate
train_model(dataset_id='tonyassi/clothing-sales-ds',
            value_column_name='sales',
            test_split=0.2,
            output_dir='./results',
            num_train_epochs=10,
            learning_rate=0.0001)

The trainer saves checkpoints in output_dir. The model.safetensors file in a checkpoint folder contains the trained weights you'll use for inference (prediction).
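Checkpoint folders should be compared by their numeric step, not alphabetically ("checkpoint-94" sorts after "checkpoint-200" as text). A minimal sketch, assuming the `checkpoint-<step>` naming seen in this card, finds the most recent one; 🤗 Transformers ships a similar utility, `transformers.trainer_utils.get_last_checkpoint`. Note that the most recent checkpoint is not necessarily the one with the lowest test error.

```python
import re
from pathlib import Path
from typing import Optional

CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the `checkpoint-<step>` folder with the highest step, or None if there is none."""
    best_step, best_path = -1, None
    for path in Path(output_dir).iterdir():
        match = CHECKPOINT_RE.match(path.name)
        if path.is_dir() and match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), str(path)
    return best_path


# Example: latest_checkpoint('./results') might return './results/checkpoint-940'.
```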

Upload Model

This function will upload your model to the 🤗 Hub.

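  • model_id: name of the model repo on the 🤗 Hub
  • token: 🤗 access token with write permission (create one in your Hugging Face account settings)
  • checkpoint_dir: checkpoint folder to upload (a checkpoint-<step> folder inside output_dir)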
upload_model(model_id='sales-prediction',
             token='YOUR_HF_TOKEN',
             checkpoint_dir='./results/checkpoint-940')
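Writing the token directly into code, as in `token='YOUR_HF_TOKEN'`, risks committing the secret to version control. A minimal alternative, assuming you store the token in the `HF_TOKEN` environment variable (the name the `huggingface_hub` library itself reads), is to load it at runtime; `get_hf_token` is a hypothetical helper, not part of ImageRegression.

```python
import os


def get_hf_token(env_var: str = "HF_TOKEN") -> str:
    """Read a 🤗 access token from the environment instead of hard-coding it."""
    token = os.environ.get(env_var, "").strip()
    if not token:
        raise RuntimeError(f"Set the {env_var} environment variable to a 🤗 token with write access")
    return token


# Example usage (hypothetical: token exported beforehand, e.g. `export HF_TOKEN=...`):
# upload_model(model_id='sales-prediction',
#              token=get_hf_token(),
#              checkpoint_dir='./results/checkpoint-940')
```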