{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "eleventh-doctor-beta.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"a49e5fd0d85444a3aa9f786455ca8770": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_73e8d052a86647919649a367aa773c8e",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_40760124752846209e61177280a005bd",
"IPY_MODEL_eae3f41495884830818311e51920c956",
"IPY_MODEL_5d1a116b987549d780ee25723f83d45a"
]
}
},
"73e8d052a86647919649a367aa773c8e": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"40760124752846209e61177280a005bd": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_46f7e33281354ef488945f5f1cfe4c06",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": "Epoch: 100%",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_e987dfed8c624717b5ae2054cce74f05"
}
},
"eae3f41495884830818311e51920c956": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_c49e73bdfe544de0ae62034cef7eb0da",
"_dom_classes": [],
"description": "",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 4,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 4,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_52e599d90ccd44938d982310fb7e4341"
}
},
"5d1a116b987549d780ee25723f83d45a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_c01768976adf465ebfad5c3eedfe1d58",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 4/4 [00:11<00:00, 2.87s/it]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_27fb7c7e261b4a3b9656a37b1fcde71a"
}
},
"46f7e33281354ef488945f5f1cfe4c06": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"e987dfed8c624717b5ae2054cce74f05": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"c49e73bdfe544de0ae62034cef7eb0da": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"52e599d90ccd44938d982310fb7e4341": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"c01768976adf465ebfad5c3eedfe1d58": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"27fb7c7e261b4a3b9656a37b1fcde71a": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"1c82670ef31346eb97dff63429fd522f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_a8c2fda5e0be4c638919b4ca1007dea3",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_5e60bbde81ed452fa0c8d7094d98b052",
"IPY_MODEL_3ac055433ca94c2ebe9f8b44e38be5e0",
"IPY_MODEL_67e70ffbe152488fb036968be105a368"
]
}
},
"a8c2fda5e0be4c638919b4ca1007dea3": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"5e60bbde81ed452fa0c8d7094d98b052": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_bbeb9a01f5bb4aebba239db555f4b16b",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": "Iteration: 100%",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_771602cc4d9444e7ab0d20438639cddd"
}
},
"3ac055433ca94c2ebe9f8b44e38be5e0": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_f154ee2be8a044b3aeeb0e904411ffbd",
"_dom_classes": [],
"description": "",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 5,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 5,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_a3a6841089054f1cbc31f638424674b3"
}
},
"67e70ffbe152488fb036968be105a368": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_41061818a9c94956a7d1cd129028d805",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 5/5 [00:02<00:00, 1.89it/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_bab64ef3864248018e9476bc8c4018f4"
}
},
"bbeb9a01f5bb4aebba239db555f4b16b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"771602cc4d9444e7ab0d20438639cddd": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"f154ee2be8a044b3aeeb0e904411ffbd": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"a3a6841089054f1cbc31f638424674b3": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"41061818a9c94956a7d1cd129028d805": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"bab64ef3864248018e9476bc8c4018f4": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"f3fa20cd1c40453bb17b2f109607e1bf": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_97b3a5270a014515bbc712b44dba38a0",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_afe3a3438fb145d4a015fdb0709e3156",
"IPY_MODEL_6d6078316fe54c9e83a3c3a35a1169fc",
"IPY_MODEL_4700a281e7d347db8a58c6f181706b54"
]
}
},
"97b3a5270a014515bbc712b44dba38a0": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"afe3a3438fb145d4a015fdb0709e3156": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_84f8bdfeb6bf4bb7ba4585eba47a7092",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": "Iteration: 100%",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_9262f880b64a4abb80013f6997901bcb"
}
},
"6d6078316fe54c9e83a3c3a35a1169fc": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_386289f0bf56453484a6637d3263da4c",
"_dom_classes": [],
"description": "",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 5,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 5,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_8366904443f649928aa9cfd915cd938a"
}
},
"4700a281e7d347db8a58c6f181706b54": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_80f977590ae94733a8a8552241c12e3b",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 5/5 [00:02<00:00, 1.79it/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_3ff5971c055144a3b81190d199ffe3de"
}
},
"84f8bdfeb6bf4bb7ba4585eba47a7092": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"9262f880b64a4abb80013f6997901bcb": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"386289f0bf56453484a6637d3263da4c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"8366904443f649928aa9cfd915cd938a": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"80f977590ae94733a8a8552241c12e3b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"3ff5971c055144a3b81190d199ffe3de": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"be5e0fa21fea43e8bf003ae954c29d03": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_7535306bf05847629946e333021e0ef5",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_671f7a7556b1412cbd48237293431c0d",
"IPY_MODEL_8957849d6dbf44cfafa965d71de78255",
"IPY_MODEL_a55488797f71453e917032008f198b9c"
]
}
},
"7535306bf05847629946e333021e0ef5": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"671f7a7556b1412cbd48237293431c0d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_38ccc42c4ea14f0e83fca1cb9452bfad",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": "Iteration: 100%",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_ef2ddbabd7b042c0821cf999a9265867"
}
},
"8957849d6dbf44cfafa965d71de78255": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_0222497184b2446e90f801d84af22b82",
"_dom_classes": [],
"description": "",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 5,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 5,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_6a70d59d1c834989a54deda9e776bf41"
}
},
"a55488797f71453e917032008f198b9c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_62e05dc21f5e438cb5e94e400071a39b",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 5/5 [00:02<00:00, 1.86it/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_803884fef29b433db14e61df7fae1ee7"
}
},
"38ccc42c4ea14f0e83fca1cb9452bfad": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"ef2ddbabd7b042c0821cf999a9265867": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"0222497184b2446e90f801d84af22b82": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"6a70d59d1c834989a54deda9e776bf41": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"62e05dc21f5e438cb5e94e400071a39b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"803884fef29b433db14e61df7fae1ee7": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"83414a06fd504f71aa212d9fce15ffb5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_e2c8785b7c51448296c8cf54331f4a68",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_3f9881738b644024aea6371982320d97",
"IPY_MODEL_143666ff0c7b4e6491779b64f6212818",
"IPY_MODEL_bf699df8e3d844f68a68491c00e8f0bc"
]
}
},
"e2c8785b7c51448296c8cf54331f4a68": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"3f9881738b644024aea6371982320d97": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_9e955ce77097447a8c085ea592ae8a5e",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": "Iteration: 100%",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_2f2aecb73861473ba553b4ebebd52e0b"
}
},
"143666ff0c7b4e6491779b64f6212818": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_db21f601fd6342a5815cea18c417aa99",
"_dom_classes": [],
"description": "",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 5,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 5,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_67bf9d4f1a8e431bb580337be0e67f82"
}
},
"bf699df8e3d844f68a68491c00e8f0bc": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_ddcba8098de6433a8584045d52cb1f3b",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 5/5 [00:02<00:00, 1.89it/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_60062f7860944d0d85c4ef1773c151c3"
}
},
"9e955ce77097447a8c085ea592ae8a5e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"2f2aecb73861473ba553b4ebebd52e0b": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"db21f601fd6342a5815cea18c417aa99": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"67bf9d4f1a8e431bb580337be0e67f82": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"ddcba8098de6433a8584045d52cb1f3b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"60062f7860944d0d85c4ef1773c151c3": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"cc13e655b33d4fa390960d1fa40a0e1f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_2e418d21ae4f4123a9d7b13cbc368605",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_ba1f476b8bdc4fce8a703a0220bd4770",
"IPY_MODEL_26560e743afa4a3fad0eb1e0ed567a64",
"IPY_MODEL_27e7a38811a94548b8dc1980e1c83acd"
]
}
},
"2e418d21ae4f4123a9d7b13cbc368605": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"ba1f476b8bdc4fce8a703a0220bd4770": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_c5ceabb016a74435be6659f1116c9945",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": "Evaluating: ",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_4ca6ee080dad41aa8abbbea5a96e3922"
}
},
"26560e743afa4a3fad0eb1e0ed567a64": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_c0fd4025a3e84d0ca8c360887d7126ba",
"_dom_classes": [],
"description": "",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 1,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 0,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_4a374856a56c4dc6a163ee5779d6b666"
}
},
"27e7a38811a94548b8dc1980e1c83acd": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_ef067c6b95ac48c58428f62ecef22e33",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 0/0 [00:00<?, ?it/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_88521418122646b0b6c7d41be73e747a"
}
},
"c5ceabb016a74435be6659f1116c9945": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"4ca6ee080dad41aa8abbbea5a96e3922": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"c0fd4025a3e84d0ca8c360887d7126ba": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"4a374856a56c4dc6a163ee5779d6b666": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": "20px",
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"ef067c6b95ac48c58428f62ecef22e33": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"88521418122646b0b6c7d41be73e747a": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
}
}
}
},
"cells": [
{
"cell_type": "code",
"execution_count": 57,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "tSIO1yDEJbxI",
"outputId": "43bc1501-529c-48bc-d825-08c242d5de04"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
]
}
],
"source": [
"from google.colab import drive\n",
"drive.mount(\"/content/drive\")"
]
},
{
"cell_type": "code",
"source": [
"!pip -q install transformers"
],
"metadata": {
"id": "LwrtmgMvMSey"
},
"execution_count": 58,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import os\n",
"os.chdir(\"/content/drive/My Drive/Colab Notebooks\")"
],
"metadata": {
"id": "Mp864lxgIbJE"
},
"execution_count": 59,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# libraries\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from __future__ import division\n",
"\n",
"import random\n",
"import glob\n",
"import logging\n",
"import os\n",
"import pickle\n",
"import re\n",
"import shutil\n",
"from typing import List, Dict, Tuple\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler\n",
"from torch.utils.data.distributed import DistributedSampler\n",
"from tqdm.notebook import tqdm, trange\n",
"\n",
"from pathlib import Path\n",
"\n",
"from transformers import (\n",
" MODEL_WITH_LM_HEAD_MAPPING,\n",
" WEIGHTS_NAME,\n",
" AdamW,\n",
" AutoConfig,\n",
" PreTrainedModel,\n",
" PreTrainedTokenizer,\n",
" get_linear_schedule_with_warmup,\n",
")\n",
"\n",
"try:\n",
" from torch.utils.tensorboard import SummaryWriter\n",
"except ImportError:\n",
" from tensorboardX import SummaryWriter"
],
"metadata": {
"id": "ujmUewQ5NVoO"
},
"execution_count": 60,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# visualize raw data\n",
"# sep=\"delimiter\" is a multi-character separator, which pandas interprets as a\n",
"# regex and only supports with the python engine. Requesting that engine\n",
"# explicitly avoids the ParserWarning about falling back from the C engine.\n",
"d = pd.read_csv(\"/content/drive/MyDrive/final-all-scripts.csv\", sep=\"delimiter\", header=None, engine=\"python\")\n",
"d.head()"
],
"metadata": {
"id": "8gMOER_tVuIr",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 261
},
"outputId": "f8275436-770a-424d-fb41-a296ddc45045"
},
"execution_count": 61,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:2: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.\n",
" \n"
]
},
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
"
\n",
"
\n",
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" 0 | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" [Class room] | \n",
"
\n",
" \n",
" 1 | \n",
" (A city is flying through space, stuck on the ... | \n",
"
\n",
" \n",
" 2 | \n",
" COMPUTER: Well done, Mabel. Well done, Alfie. ... | \n",
"
\n",
" \n",
" 3 | \n",
" (It is the little boy's turn.) | \n",
"
\n",
" \n",
" 4 | \n",
" COMPUTER: Bad boy, Timmy. | \n",
"
\n",
" \n",
"
\n",
"
\n",
"
\n",
" \n",
" \n",
"\n",
" \n",
"
\n",
"
\n",
" "
],
"text/plain": [
" 0\n",
"0 [Class room]\n",
"1 (A city is flying through space, stuck on the ...\n",
"2 COMPUTER: Well done, Mabel. Well done, Alfie. ...\n",
"3 (It is the little boy's turn.)\n",
"4 COMPUTER: Bad boy, Timmy."
]
},
"metadata": {},
"execution_count": 61
}
]
},
{
"cell_type": "markdown",
"source": [
"## Data Preprocessing"
],
"metadata": {
"id": "Vr2Y_QbooJUM"
}
},
{
"cell_type": "code",
"source": [
"print(f\"Data type of file: {type(d)}\",\n",
" f\"\\nShape of file: {d.shape}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_mKAYMdxj2-l",
"outputId": "5bb10015-4046-41c8-e774-ed42e87ccfc7"
},
"execution_count": 62,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Data type of file: \n",
"Shape of file: (27597, 1)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(f\"Type of first element: {type(d.iloc[0])}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YiVX_L8P0P1K",
"outputId": "e280fb35-b11d-42d8-8ba3-50843120cea9"
},
"execution_count": 63,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Type of first element: \n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Keep only the lines that do not open with \"(\" or \"[\".\n",
"dd = [line for line in d[0] if not line.startswith((\"(\", \"[\"))]\n",
"\n",
"dd[:10]"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "y019y7pT02GL",
"outputId": "30c9a13b-3a42-4ac2-871b-9f830d6aa5a4"
},
"execution_count": 64,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['COMPUTER: Well done, Mabel. Well done, Alfie. Good girl, Tabitha. Very well done, Ranjit. Good girl, Chloe. Well done, Ben. Well done, Mandy.',\n",
" 'COMPUTER: Bad boy, Timmy.',\n",
" 'COMPUTER: Zero.',\n",
" \"MANDY: You got a zero, didn't you?\",\n",
" 'TIMMY: Yeah? So?',\n",
" \"MANDY: You'll have to walk home then.\",\n",
" \"TIMMY: Walk to London? That's twenty decks!\",\n",
" \"MANDY: You can't ride a Vator with a zero. You know what happens. You'll get sent below.\",\n",
" \"MANDY: I'll wait for you.\",\n",
" \"SMILER: Welcome to Vator Verse, sponsored by McLintock's Candy Burgers. TIMMY: London, please.\"]"
]
},
"metadata": {},
"execution_count": 64
}
]
},
{
"cell_type": "code",
"source": [
"# person-text split\n",
"#dd[1].split(\":\")\n",
"\n",
"# each dialogue\n",
"#dialogues[0][1]"
],
"metadata": {
"id": "fsT-q0762yJ8"
},
"execution_count": 65,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Split each line into [speaker, text] on the FIRST colon only, so that colons\n",
"# occurring later in the line stay part of the text instead of cutting it short.\n",
"# A line that contains no colon at all yields a one-element list.\n",
"dialogues = [line.split(\":\", 1) for line in dd]\n",
"len(dialogues)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "lcA-TTbR64_q",
"outputId": "9b59a4f0-b3fb-4ff5-c2a3-059183309548"
},
"execution_count": 66,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"20594"
]
},
"metadata": {},
"execution_count": 66
}
]
},
{
"cell_type": "code",
"source": [
"chars = []\n",
"txts = []\n",
"\n",
"# A line without a colon splits into a single element, so indexing [1] on it\n",
"# raised IndexError and aborted the loop early, leaving chars one entry longer\n",
"# than txts and dropping every later line. Skip such lines so the two lists\n",
"# stay aligned and the whole script is processed.\n",
"for parts in dialogues:\n",
"    if len(parts) < 2:\n",
"        continue\n",
"    chars.append(parts[0])\n",
"    txts.append(parts[1])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 200
},
"id": "FtSYUoGO7XLk",
"outputId": "78760c97-5c54-4b9b-b263-a9ee0809ca1f"
},
"execution_count": 67,
"outputs": [
{
"output_type": "error",
"ename": "IndexError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdialogues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mchars\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdialogues\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mtxts\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdialogues\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mIndexError\u001b[0m: list index out of range"
]
}
]
},
{
"cell_type": "code",
"source": [
"len(chars)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TWxj098VzBXF",
"outputId": "198df600-0cbf-431b-eede-58a40b62f108"
},
"execution_count": 68,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"104"
]
},
"metadata": {},
"execution_count": 68
}
]
},
{
"cell_type": "code",
"source": [
"dialogues[len(dialogues)-1][1]"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "l0PCKQI5LkuR",
"outputId": "fb6f3000-7378-4f30-bd3a-e193bce34ba8"
},
"execution_count": 69,
"outputs": [
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"' So, dresses, then.'"
]
},
"metadata": {},
"execution_count": 69
}
]
},
{
"cell_type": "code",
"source": [
"#dialogues[len(dialogues)-1][1] == dialogues[-1][1]"
],
"metadata": {
"id": "kFVPluE3-ojX"
},
"execution_count": 70,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#pd.isnull(dialogues).sum()"
],
"metadata": {
"id": "oemthiCWInCq"
},
"execution_count": 71,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# separate person and text, then convert to a dataframe\n",
"speaker_text_pairs = list(zip(chars, txts))\n",
"df = pd.DataFrame(speaker_text_pairs, columns=[\"Character\", \"Text\"])\n",
"df.head()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"id": "UqgM35hBrQp6",
"outputId": "346eab19-132c-433b-9a99-3e1012b83eac"
},
"execution_count": 72,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
" \n",
"
\n",
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Character | \n",
" Text | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" COMPUTER | \n",
" Well done, Mabel. Well done, Alfie. Good girl... | \n",
"
\n",
" \n",
" 1 | \n",
" COMPUTER | \n",
" Bad boy, Timmy. | \n",
"
\n",
" \n",
" 2 | \n",
" COMPUTER | \n",
" Zero. | \n",
"
\n",
" \n",
" 3 | \n",
" MANDY | \n",
" You got a zero, didn't you? | \n",
"
\n",
" \n",
" 4 | \n",
" TIMMY | \n",
" Yeah? So? | \n",
"
\n",
" \n",
"
\n",
"
\n",
"
\n",
" \n",
" \n",
"\n",
" \n",
"
\n",
"
\n",
" "
],
"text/plain": [
" Character Text\n",
"0 COMPUTER Well done, Mabel. Well done, Alfie. Good girl...\n",
"1 COMPUTER Bad boy, Timmy.\n",
"2 COMPUTER Zero.\n",
"3 MANDY You got a zero, didn't you?\n",
"4 TIMMY Yeah? So?"
]
},
"metadata": {},
"execution_count": 72
}
]
},
{
"cell_type": "code",
"source": [
"CHARACTER_NAME = \"DOCTOR\""
],
"metadata": {
"id": "FY639-Fi7WfF"
},
"execution_count": 73,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# context window of size 7\n",
"n = 7\n",
"\n",
"# Index labels of the rows spoken by CHARACTER_NAME.\n",
"speaker_rows = df.index[df.Character == CHARACTER_NAME]\n",
"\n",
"# Each record runs newest-to-oldest: the character's own line first (the\n",
"# response), followed by the n lines that precede it. Range stops at idx - n - 1\n",
"# (exclusive), so a record holds n + 1 texts. Lines with fewer than n\n",
"# predecessors are skipped.\n",
"contexted = [\n",
"    [df.Text[pos] for pos in range(idx, idx - n - 1, -1)]\n",
"    for idx in speaker_rows\n",
"    if idx >= n\n",
"]\n",
"\n",
"columns = ['response', 'context'] + ['context/' + str(k) for k in range(n - 1)]\n",
"\n",
"df = pd.DataFrame.from_records(contexted, columns=columns)"
],
"metadata": {
"id": "vSqHrtAOz_1j"
},
"execution_count": 74,
"outputs": []
},
{
"cell_type": "code",
"source": [
"df.sample(6)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 490
},
"id": "GEVvAnKN0xad",
"outputId": "cdd5a2d2-f465-4d2a-947a-7269447ede2e"
},
"execution_count": 75,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
" \n",
"
\n",
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" response | \n",
" context | \n",
" context/0 | \n",
" context/1 | \n",
" context/2 | \n",
" context/3 | \n",
" context/4 | \n",
" context/5 | \n",
"
\n",
" \n",
" \n",
" \n",
" 8 | \n",
" An important thing. In fact, Thing One. We ar... | \n",
" A thing? | \n",
" Course we can. But first, there's a thing. | \n",
" Can we go out and see? | \n",
" Well, come on. I've found us a spaceship. Thi... | \n",
" Doctor! | \n",
" Isn't that amazing? | \n",
" Doctor? | \n",
"
\n",
" \n",
" 16 | \n",
" Don't know. I think a lot. It's hard to keep ... | \n",
" Why did you just do that with the water? | \n",
" Sorry. Checking all the water in this area. T... | \n",
" What are you doing? | \n",
" Life on a giant starship. Back to basics. Bic... | \n",
" London Market is a crime-free zone. | \n",
" Now, come on, look around you. Actually look. | \n",
" Oh my God, I'm in my nightie. | \n",
"
\n",
" \n",
" 0 | \n",
" Come on, Pond. | \n",
" My name is Amy Pond. When I was seven, I had ... | \n",
" Help! Help me! | \n",
" Though the man above might say hello, expect ... | \n",
" A horse and a man, above, below. One has a pl... | \n",
" Welcome to Vator Verse, sponsored by McLintoc... | \n",
" I'll wait for you. | \n",
" You can't ride a Vator with a zero. You know ... | \n",
"
\n",
" \n",
" 23 | \n",
" What I always do. Stay out of trouble. Badly. | \n",
" What are you going to do? | \n",
" It's this or Leadworth. What do you think? Le... | \n",
" No, hang on. What do I do? I don't know what ... | \n",
" They're clean. Everything else here is all ba... | \n",
" But they're just things. | \n",
" Deck two oh seven. Apple Sesame block, dwelli... | \n",
" Where'd she go? | \n",
"
\n",
" \n",
" 11 | \n",
" Come on, use your eyes. Notice everything. Wh... | \n",
" What's wrong? | \n",
" Oh, lovely. You're a cheery one. Never mind d... | \n",
" I'm in the future. Like hundreds of years in ... | \n",
" Welcome to London Market. You are being monit... | \n",
" Doctor? | \n",
" So we're like a wildlife documentary, yeah? B... | \n",
" Ooo, that's interesting. | \n",
"
\n",
" \n",
" 9 | \n",
" Ooo, that's interesting. | \n",
" An important thing. In fact, Thing One. We ar... | \n",
" A thing? | \n",
" Course we can. But first, there's a thing. | \n",
" Can we go out and see? | \n",
" Well, come on. I've found us a spaceship. Thi... | \n",
" Doctor! | \n",
" Isn't that amazing? | \n",
"
\n",
" \n",
"
\n",
"
\n",
"
\n",
" \n",
" \n",
"\n",
" \n",
"
\n",
"
\n",
" "
],
"text/plain": [
" response ... context/5\n",
"8 An important thing. In fact, Thing One. We ar... ... Doctor?\n",
"16 Don't know. I think a lot. It's hard to keep ... ... Oh my God, I'm in my nightie.\n",
"0 Come on, Pond. ... You can't ride a Vator with a zero. You know ...\n",
"23 What I always do. Stay out of trouble. Badly. ... Where'd she go?\n",
"11 Come on, use your eyes. Notice everything. Wh... ... Ooo, that's interesting.\n",
"9 Ooo, that's interesting. ... Isn't that amazing?\n",
"\n",
"[6 rows x 8 columns]"
]
},
"metadata": {},
"execution_count": 75
}
]
},
{
"cell_type": "code",
"source": [
"trn_df, val_df = train_test_split(df, test_size=0.1)"
],
"metadata": {
"id": "nYM_4zKirQ5A"
},
"execution_count": 76,
"outputs": []
},
{
"cell_type": "code",
"source": [
"trn_df.head()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 424
},
"id": "RKF8dGVxS61X",
"outputId": "46d99699-9659-4d69-fd54-16b290ec2491"
},
"execution_count": 77,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
" \n",
"
\n",
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" response | \n",
" context | \n",
" context/0 | \n",
" context/1 | \n",
" context/2 | \n",
" context/3 | \n",
" context/4 | \n",
" context/5 | \n",
"
\n",
" \n",
" \n",
" \n",
" 7 | \n",
" Course we can. But first, there's a thing. | \n",
" Can we go out and see? | \n",
" Well, come on. I've found us a spaceship. Thi... | \n",
" Doctor! | \n",
" Isn't that amazing? | \n",
" Doctor? | \n",
" Migrating to the stars. | \n",
" Doctor? | \n",
"
\n",
" \n",
" 17 | \n",
" There. | \n",
" Where? | \n",
" Don't know. I think a lot. It's hard to keep ... | \n",
" Why did you just do that with the water? | \n",
" Sorry. Checking all the water in this area. T... | \n",
" What are you doing? | \n",
" Life on a giant starship. Back to basics. Bic... | \n",
" London Market is a crime-free zone. | \n",
"
\n",
" \n",
" 1 | \n",
" Now do you believe me? | \n",
" And my imaginary friend came back. | \n",
" Come on, Pond. | \n",
" My name is Amy Pond. When I was seven, I had ... | \n",
" Help! Help me! | \n",
" Though the man above might say hello, expect ... | \n",
" A horse and a man, above, below. One has a pl... | \n",
" Welcome to Vator Verse, sponsored by McLintoc... | \n",
"
\n",
" \n",
" 20 | \n",
" Deck two oh seven. Apple Sesame block, dwelli... | \n",
" Where'd she go? | \n",
" Hundreds of parents walking past who spot her... | \n",
" Are you a parent? | \n",
" Crying silently. I mean, children cry because... | \n",
" One little girl crying. So? | \n",
" I'll have a look on the monitors. | \n",
" Apparently. | \n",
"
\n",
" \n",
" 0 | \n",
" Come on, Pond. | \n",
" My name is Amy Pond. When I was seven, I had ... | \n",
" Help! Help me! | \n",
" Though the man above might say hello, expect ... | \n",
" A horse and a man, above, below. One has a pl... | \n",
" Welcome to Vator Verse, sponsored by McLintoc... | \n",
" I'll wait for you. | \n",
" You can't ride a Vator with a zero. You know ... | \n",
"
\n",
" \n",
"
\n",
"
\n",
"
\n",
" \n",
" \n",
"\n",
" \n",
"
\n",
"
\n",
" "
],
"text/plain": [
" response ... context/5\n",
"7 Course we can. But first, there's a thing. ... Doctor?\n",
"17 There. ... London Market is a crime-free zone.\n",
"1 Now do you believe me? ... Welcome to Vator Verse, sponsored by McLintoc...\n",
"20 Deck two oh seven. Apple Sesame block, dwelli... ... Apparently.\n",
"0 Come on, Pond. ... You can't ride a Vator with a zero. You know ...\n",
"\n",
"[5 rows x 8 columns]"
]
},
"metadata": {},
"execution_count": 77
}
]
},
{
"cell_type": "code",
"source": [
"# create dataset suitable for our model\n",
"def construct_conv(row, tokenizer, eos = True):\n",
" flatten = lambda l: [item for sublist in l for item in sublist]\n",
" conv = list(reversed([tokenizer.encode(x) + [tokenizer.eos_token_id] for x in row]))\n",
" conv = flatten(conv)\n",
" return conv\n",
"\n",
"class ConversationDataset(Dataset):\n",
" def __init__(self, tokenizer: PreTrainedTokenizer, args, df, block_size=512):\n",
"\n",
" block_size = block_size - (tokenizer.model_max_length - tokenizer.max_len_single_sentence)\n",
"\n",
" directory = args.cache_dir\n",
" cached_features_file = os.path.join(\n",
" directory, args.model_type + \"_cached_lm_\" + str(block_size)\n",
" )\n",
"\n",
" if os.path.exists(cached_features_file) and not args.overwrite_cache:\n",
" logger.info(\"Loading features from cached file %s\", cached_features_file)\n",
" with open(cached_features_file, \"rb\") as handle:\n",
" self.examples = pickle.load(handle)\n",
" else:\n",
" logger.info(\"Creating features from dataset file at %s\", directory)\n",
"\n",
" self.examples = []\n",
" for _, row in df.iterrows():\n",
" conv = construct_conv(row, tokenizer)\n",
" self.examples.append(conv)\n",
"\n",
" logger.info(\"Saving features into cached file %s\", cached_features_file)\n",
" with open(cached_features_file, \"wb\") as handle:\n",
" pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
"\n",
" def __len__(self):\n",
" return len(self.examples)\n",
"\n",
" def __getitem__(self, item):\n",
" return torch.tensor(self.examples[item], dtype=torch.long)"
],
"metadata": {
"id": "va9Olm-DoR9w"
},
"execution_count": 78,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Cacheing and storing of data/checkpoints\n",
"\n",
"def load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False):\n",
" return ConversationDataset(tokenizer, args, df_val if evaluate else df_trn)\n",
"\n",
"\n",
"def set_seed(args):\n",
" random.seed(args.seed)\n",
" np.random.seed(args.seed)\n",
" torch.manual_seed(args.seed)\n",
" if args.n_gpu > 0:\n",
" torch.cuda.manual_seed_all(args.seed)\n",
"\n",
"\n",
"def _sorted_checkpoints(args, checkpoint_prefix=\"checkpoint\", use_mtime=False) -> List[str]:\n",
" ordering_and_checkpoint_path = []\n",
"\n",
" glob_checkpoints = glob.glob(os.path.join(args.output_dir, \"{}-*\".format(checkpoint_prefix)))\n",
"\n",
" for path in glob_checkpoints:\n",
" if use_mtime:\n",
" ordering_and_checkpoint_path.append((os.path.getmtime(path), path))\n",
" else:\n",
" regex_match = re.match(\".*{}-([0-9]+)\".format(checkpoint_prefix), path)\n",
" if regex_match and regex_match.groups():\n",
" ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))\n",
"\n",
" checkpoints_sorted = sorted(ordering_and_checkpoint_path)\n",
" checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]\n",
" return checkpoints_sorted\n",
"\n",
"\n",
"def _rotate_checkpoints(args, checkpoint_prefix=\"checkpoint\", use_mtime=False) -> None:\n",
" if not args.save_total_limit:\n",
" return\n",
" if args.save_total_limit <= 0:\n",
" return\n",
"\n",
" # Check if we should delete older checkpoint(s)\n",
" checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)\n",
" if len(checkpoints_sorted) <= args.save_total_limit:\n",
" return\n",
"\n",
" number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)\n",
" checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]\n",
" for checkpoint in checkpoints_to_be_deleted:\n",
" logger.info(\"Deleting older checkpoint [{}] due to args.save_total_limit\".format(checkpoint))\n",
" shutil.rmtree(checkpoint)"
],
"metadata": {
"id": "wj1yakcqTCNx"
},
"execution_count": 79,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Build Model"
],
"metadata": {
"id": "PNcEUFjSoML0"
}
},
{
"cell_type": "code",
"source": [
"from transformers import AutoModelWithLMHead, AutoModelForCausalLM, AutoTokenizer\n",
"import torch\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"microsoft/DialoGPT-small\")\n",
"model = AutoModelWithLMHead.from_pretrained(\"microsoft/DialoGPT-small\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UTI0g8PPg628",
"outputId": "8b8c6e19-fa17-42be-c9e4-25e5ae4c1819"
},
"execution_count": 80,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/transformers/models/auto/modeling_auto.py:787: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n",
" FutureWarning,\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# configs\n",
"\n",
"logger = logging.getLogger(__name__)\n",
"\n",
"MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())\n",
"MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)"
],
"metadata": {
"id": "lzDAg6-eg7Fj"
},
"execution_count": 81,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Args to allow for easy convertion of python script to notebook\n",
"class Args():\n",
" def __init__(self):\n",
" self.output_dir = 'output-small'\n",
" self.model_type = 'gpt2'\n",
" self.model_name_or_path = 'microsoft/DialoGPT-small'\n",
" self.config_name = 'microsoft/DialoGPT-small'\n",
" self.tokenizer_name = 'microsoft/DialoGPT-small'\n",
" self.cache_dir = 'cached'\n",
" self.block_size = 512\n",
" self.do_train = True\n",
" self.do_eval = True\n",
" self.evaluate_during_training = False\n",
" self.per_gpu_train_batch_size = 4\n",
" self.per_gpu_eval_batch_size = 4\n",
" self.gradient_accumulation_steps = 1\n",
" self.learning_rate = 5e-5\n",
" self.weight_decay = 0.0\n",
" self.adam_epsilon = 1e-8\n",
" self.max_grad_norm = 1.0\n",
" self.num_train_epochs = 4\n",
" self.max_steps = -1\n",
" self.warmup_steps = 0\n",
" self.logging_steps = 1000\n",
" self.save_steps = 3500\n",
" self.save_total_limit = None\n",
" self.eval_all_checkpoints = False\n",
" self.no_cuda = False\n",
" self.overwrite_output_dir = True\n",
" self.overwrite_cache = True\n",
" self.should_continue = False\n",
" self.seed = 42\n",
" self.local_rank = -1\n",
" self.fp16 = False\n",
" self.fp16_opt_level = 'O1'\n",
"\n",
"args = Args()"
],
"metadata": {
"id": "8b20p10Xg7M-"
},
"execution_count": 82,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Train and Evaluate"
],
"metadata": {
"id": "9QaybLujoTg-"
}
},
{
"cell_type": "code",
"source": [
"def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:\n",
" \"\"\" Train the model \"\"\"\n",
" if args.local_rank in [-1, 0]:\n",
" tb_writer = SummaryWriter()\n",
"\n",
" args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)\n",
"\n",
" def collate(examples: List[torch.Tensor]):\n",
" if tokenizer._pad_token is None:\n",
" return pad_sequence(examples, batch_first=True)\n",
" return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)\n",
"\n",
" train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)\n",
" train_dataloader = DataLoader(\n",
" train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate, drop_last = True\n",
" )\n",
"\n",
" if args.max_steps > 0:\n",
" t_total = args.max_steps\n",
" args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1\n",
" else:\n",
" t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs\n",
"\n",
" model = model.module if hasattr(model, \"module\") else model # Take care of distributed/parallel training\n",
" model.resize_token_embeddings(len(tokenizer))\n",
" # add_special_tokens_(model, tokenizer)\n",
"\n",
"\n",
" # Prepare optimizer and schedule (linear warmup and decay)\n",
" no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
" optimizer_grouped_parameters = [\n",
" {\n",
" \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n",
" \"weight_decay\": args.weight_decay,\n",
" },\n",
" {\"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], \"weight_decay\": 0.0},\n",
" ]\n",
" optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)\n",
" scheduler = get_linear_schedule_with_warmup(\n",
" optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total\n",
" )\n",
"\n",
" # Check if saved optimizer or scheduler states exist\n",
" if (\n",
" args.model_name_or_path\n",
" and os.path.isfile(os.path.join(args.model_name_or_path, \"optimizer.pt\"))\n",
" and os.path.isfile(os.path.join(args.model_name_or_path, \"scheduler.pt\"))\n",
" ):\n",
" # Load in optimizer and scheduler states\n",
" optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, \"optimizer.pt\")))\n",
" scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, \"scheduler.pt\")))\n",
"\n",
" if args.fp16:\n",
" try:\n",
" from apex import amp\n",
" except ImportError:\n",
" raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.\")\n",
" model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)\n",
"\n",
" # multi-gpu training (should be after apex fp16 initialization)\n",
" if args.n_gpu > 1:\n",
" model = torch.nn.DataParallel(model)\n",
"\n",
" # Distributed training (should be after apex fp16 initialization)\n",
" if args.local_rank != -1:\n",
" model = torch.nn.parallel.DistributedDataParallel(\n",
" model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True\n",
" )\n",
"\n",
" # Train!\n",
" logger.info(\"***** Running training *****\")\n",
" logger.info(\" Num examples = %d\", len(train_dataset))\n",
" logger.info(\" Num Epochs = %d\", args.num_train_epochs)\n",
" logger.info(\" Instantaneous batch size per GPU = %d\", args.per_gpu_train_batch_size)\n",
" logger.info(\n",
" \" Total train batch size (w. parallel, distributed & accumulation) = %d\",\n",
" args.train_batch_size\n",
" * args.gradient_accumulation_steps\n",
" * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),\n",
" )\n",
" logger.info(\" Gradient Accumulation steps = %d\", args.gradient_accumulation_steps)\n",
" logger.info(\" Total optimization steps = %d\", t_total)\n",
"\n",
" global_step = 0\n",
" epochs_trained = 0\n",
" steps_trained_in_current_epoch = 0\n",
" # Check if continuing training from a checkpoint\n",
" if args.model_name_or_path and os.path.exists(args.model_name_or_path):\n",
" try:\n",
" # set global_step to gobal_step of last saved checkpoint from model path\n",
" checkpoint_suffix = args.model_name_or_path.split(\"-\")[-1].split(\"/\")[0]\n",
" global_step = int(checkpoint_suffix)\n",
" epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)\n",
" steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)\n",
"\n",
" logger.info(\" Continuing training from checkpoint, will skip to saved global_step\")\n",
" logger.info(\" Continuing training from epoch %d\", epochs_trained)\n",
" logger.info(\" Continuing training from global step %d\", global_step)\n",
" logger.info(\" Will skip the first %d steps in the first epoch\", steps_trained_in_current_epoch)\n",
" except ValueError:\n",
" logger.info(\" Starting fine-tuning.\")\n",
"\n",
" tr_loss, logging_loss = 0.0, 0.0\n",
"\n",
" model.zero_grad()\n",
" train_iterator = trange(\n",
" epochs_trained, int(args.num_train_epochs), desc=\"Epoch\", disable=args.local_rank not in [-1, 0]\n",
" )\n",
" set_seed(args) # Added here for reproducibility\n",
" for _ in train_iterator:\n",
" epoch_iterator = tqdm(train_dataloader, desc=\"Iteration\", disable=args.local_rank not in [-1, 0])\n",
" for step, batch in enumerate(epoch_iterator):\n",
"\n",
" # Skip past any already trained steps if resuming training\n",
" if steps_trained_in_current_epoch > 0:\n",
" steps_trained_in_current_epoch -= 1\n",
" continue\n",
"\n",
" inputs, labels = (batch, batch)\n",
" if inputs.shape[1] > 1024: continue\n",
" inputs = inputs.to(args.device)\n",
" labels = labels.to(args.device)\n",
" model.train()\n",
" outputs = model(inputs, labels=labels)\n",
" loss = outputs[0] # model outputs are always tuple in transformers (see doc)\n",
"\n",
" if args.n_gpu > 1:\n",
" loss = loss.mean() # mean() to average on multi-gpu parallel training\n",
" if args.gradient_accumulation_steps > 1:\n",
" loss = loss / args.gradient_accumulation_steps\n",
"\n",
" if args.fp16:\n",
" with amp.scale_loss(loss, optimizer) as scaled_loss:\n",
" scaled_loss.backward()\n",
" else:\n",
" loss.backward()\n",
"\n",
" tr_loss += loss.item()\n",
" if (step + 1) % args.gradient_accumulation_steps == 0:\n",
" if args.fp16:\n",
" torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)\n",
" else:\n",
" torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)\n",
" optimizer.step()\n",
" scheduler.step() # Update learning rate schedule\n",
" model.zero_grad()\n",
" global_step += 1\n",
"\n",
" if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:\n",
" # Log metrics\n",
" if (\n",
" args.local_rank == -1 and args.evaluate_during_training\n",
" ): # Only evaluate when single GPU otherwise metrics may not average well\n",
" results = evaluate(args, model, tokenizer)\n",
" for key, value in results.items():\n",
" tb_writer.add_scalar(\"eval_{}\".format(key), value, global_step)\n",
" tb_writer.add_scalar(\"lr\", scheduler.get_lr()[0], global_step)\n",
" tb_writer.add_scalar(\"loss\", (tr_loss - logging_loss) / args.logging_steps, global_step)\n",
" logging_loss = tr_loss\n",
"\n",
" if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:\n",
" checkpoint_prefix = \"checkpoint\"\n",
" # Save model checkpoint\n",
" output_dir = os.path.join(args.output_dir, \"{}-{}\".format(checkpoint_prefix, global_step))\n",
" os.makedirs(output_dir, exist_ok=True)\n",
" model_to_save = (\n",
" model.module if hasattr(model, \"module\") else model\n",
" ) # Take care of distributed/parallel training\n",
" model_to_save.save_pretrained(output_dir)\n",
" tokenizer.save_pretrained(output_dir)\n",
"\n",
" torch.save(args, os.path.join(output_dir, \"training_args.bin\"))\n",
" logger.info(\"Saving model checkpoint to %s\", output_dir)\n",
"\n",
" _rotate_checkpoints(args, checkpoint_prefix)\n",
"\n",
" torch.save(optimizer.state_dict(), os.path.join(output_dir, \"optimizer.pt\"))\n",
" torch.save(scheduler.state_dict(), os.path.join(output_dir, \"scheduler.pt\"))\n",
" logger.info(\"Saving optimizer and scheduler states to %s\", output_dir)\n",
"\n",
" if args.max_steps > 0 and global_step > args.max_steps:\n",
" epoch_iterator.close()\n",
" break\n",
" if args.max_steps > 0 and global_step > args.max_steps:\n",
" train_iterator.close()\n",
" break\n",
"\n",
" if args.local_rank in [-1, 0]:\n",
" tb_writer.close()\n",
"\n",
" return global_step, tr_loss / global_step\n",
"\n",
"# Evaluation of some model\n",
"\n",
"def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, df_trn, df_val, prefix=\"\") -> Dict:\n",
" # Loop to handle MNLI double evaluation (matched, mis-matched)\n",
" eval_output_dir = args.output_dir\n",
"\n",
" eval_dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=True)\n",
" os.makedirs(eval_output_dir, exist_ok=True)\n",
" args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)\n",
" # Note that DistributedSampler samples randomly\n",
"\n",
" def collate(examples: List[torch.Tensor]):\n",
" if tokenizer._pad_token is None:\n",
" return pad_sequence(examples, batch_first=True)\n",
" return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)\n",
"\n",
" eval_sampler = SequentialSampler(eval_dataset)\n",
" eval_dataloader = DataLoader(\n",
" eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate, drop_last = True\n",
" )\n",
"\n",
" # multi-gpu evaluate\n",
" if args.n_gpu > 1:\n",
" model = torch.nn.DataParallel(model)\n",
"\n",
" # Eval!\n",
" logger.info(\"***** Running evaluation {} *****\".format(prefix))\n",
" logger.info(\" Num examples = %d\", len(eval_dataset))\n",
" logger.info(\" Batch size = %d\", args.eval_batch_size)\n",
" eval_loss = 0.0\n",
" nb_eval_steps = 0\n",
" model.eval()\n",
"\n",
" for batch in tqdm(eval_dataloader, desc=\"Evaluating\"):\n",
" inputs, labels = (batch, batch)\n",
" inputs = inputs.to(args.device)\n",
" labels = labels.to(args.device)\n",
"\n",
" with torch.no_grad():\n",
" outputs = model(inputs, labels=labels)\n",
" lm_loss = outputs[0]\n",
" eval_loss += lm_loss.mean().item()\n",
" nb_eval_steps += 1\n",
"\n",
" eval_loss = eval_loss / (nb_eval_steps + 0.00001)\n",
" perplexity = torch.exp(torch.tensor(eval_loss))\n",
"\n",
" result = {\"perplexity\": perplexity}\n",
"\n",
" output_eval_file = os.path.join(eval_output_dir, prefix, \"eval_results.txt\")\n",
" with open(output_eval_file, \"w\") as writer:\n",
" logger.info(\"***** Eval results {} *****\".format(prefix))\n",
" for key in sorted(result.keys()):\n",
" logger.info(\" %s = %s\", key, str(result[key]))\n",
" writer.write(\"%s = %s\\n\" % (key, str(result[key])))\n",
"\n",
" return result"
],
"metadata": {
"id": "Yd7cAl8-oVSR"
},
"execution_count": 83,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Main runner\n",
"\n",
"def main(df_trn, df_val):\n",
" args = Args()\n",
" \n",
" if args.should_continue:\n",
" sorted_checkpoints = _sorted_checkpoints(args)\n",
" if len(sorted_checkpoints) == 0:\n",
" raise ValueError(\"Used --should_continue but no checkpoint was found in --output_dir.\")\n",
" else:\n",
" args.model_name_or_path = sorted_checkpoints[-1]\n",
"\n",
" if (\n",
" os.path.exists(args.output_dir)\n",
" and os.listdir(args.output_dir)\n",
" and args.do_train\n",
" and not args.overwrite_output_dir\n",
" and not args.should_continue\n",
" ):\n",
" raise ValueError(\n",
" \"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.\".format(\n",
" args.output_dir\n",
" )\n",
" )\n",
"\n",
" # Setup CUDA, GPU & distributed training\n",
" device = torch.device(\"cuda\")\n",
" args.n_gpu = torch.cuda.device_count()\n",
" args.device = device\n",
"\n",
" # Setup logging\n",
" logging.basicConfig(\n",
" format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n",
" datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
" level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,\n",
" )\n",
" logger.warning(\n",
" \"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s\",\n",
" args.local_rank,\n",
" device,\n",
" args.n_gpu,\n",
" bool(args.local_rank != -1),\n",
" args.fp16,\n",
" )\n",
"\n",
" # Set seed\n",
" set_seed(args)\n",
"\n",
" config = AutoConfig.from_pretrained(args.config_name, cache_dir=args.cache_dir)\n",
" tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)\n",
" model = AutoModelWithLMHead.from_pretrained(\n",
" args.model_name_or_path,\n",
" from_tf=False,\n",
" config=config,\n",
" cache_dir=args.cache_dir,\n",
" )\n",
" model.to(args.device)\n",
" \n",
" logger.info(\"Training/evaluation parameters %s\", args)\n",
"\n",
" # Training\n",
" if args.do_train:\n",
" train_dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False)\n",
"\n",
" global_step, tr_loss = train(args, train_dataset, model, tokenizer)\n",
" logger.info(\" global_step = %s, average loss = %s\", global_step, tr_loss)\n",
"\n",
" # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()\n",
" if args.do_train:\n",
" # Create output directory if needed\n",
" os.makedirs(args.output_dir, exist_ok=True)\n",
"\n",
" logger.info(\"Saving model checkpoint to %s\", args.output_dir)\n",
" # Save a trained model, configuration and tokenizer using `save_pretrained()`.\n",
" # They can then be reloaded using `from_pretrained()`\n",
" model_to_save = (\n",
" model.module if hasattr(model, \"module\") else model\n",
" ) # Take care of distributed/parallel training\n",
" model_to_save.save_pretrained(args.output_dir)\n",
" tokenizer.save_pretrained(args.output_dir)\n",
"\n",
" # Good practice: save your training arguments together with the trained model\n",
" torch.save(args, os.path.join(args.output_dir, \"training_args.bin\"))\n",
"\n",
" # Load a trained model and vocabulary that you have fine-tuned\n",
" model = AutoModelWithLMHead.from_pretrained(args.output_dir)\n",
" tokenizer = AutoTokenizer.from_pretrained(args.output_dir)\n",
" model.to(args.device)\n",
"\n",
" # Evaluation\n",
" results = {}\n",
" if args.do_eval and args.local_rank in [-1, 0]:\n",
" checkpoints = [args.output_dir]\n",
" if args.eval_all_checkpoints:\n",
" checkpoints = list(\n",
" os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + \"/**/\" + WEIGHTS_NAME, recursive=True))\n",
" )\n",
" logging.getLogger(\"transformers.modeling_utils\").setLevel(logging.WARN) # Reduce logging\n",
" logger.info(\"Evaluate the following checkpoints: %s\", checkpoints)\n",
" for checkpoint in checkpoints:\n",
" global_step = checkpoint.split(\"-\")[-1] if len(checkpoints) > 1 else \"\"\n",
" prefix = checkpoint.split(\"/\")[-1] if checkpoint.find(\"checkpoint\") != -1 else \"\"\n",
"\n",
" model = AutoModelWithLMHead.from_pretrained(checkpoint)\n",
" model.to(args.device)\n",
" result = evaluate(args, model, tokenizer, df_trn, df_val, prefix=prefix)\n",
" result = dict((k + \"_{}\".format(global_step), v) for k, v in result.items())\n",
" results.update(result)\n",
"\n",
" return results"
],
"metadata": {
"id": "M93fjuFwiu-T"
},
"execution_count": 84,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Run The Main Function"
],
"metadata": {
"id": "jdBULkbmoX6E"
}
},
{
"cell_type": "code",
"source": [
"main(trn_df, val_df)"
],
"metadata": {
"id": "IhQ-I1_Vobx0",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 646,
"referenced_widgets": [
"a49e5fd0d85444a3aa9f786455ca8770",
"73e8d052a86647919649a367aa773c8e",
"40760124752846209e61177280a005bd",
"eae3f41495884830818311e51920c956",
"5d1a116b987549d780ee25723f83d45a",
"46f7e33281354ef488945f5f1cfe4c06",
"e987dfed8c624717b5ae2054cce74f05",
"c49e73bdfe544de0ae62034cef7eb0da",
"52e599d90ccd44938d982310fb7e4341",
"c01768976adf465ebfad5c3eedfe1d58",
"27fb7c7e261b4a3b9656a37b1fcde71a",
"1c82670ef31346eb97dff63429fd522f",
"a8c2fda5e0be4c638919b4ca1007dea3",
"5e60bbde81ed452fa0c8d7094d98b052",
"3ac055433ca94c2ebe9f8b44e38be5e0",
"67e70ffbe152488fb036968be105a368",
"bbeb9a01f5bb4aebba239db555f4b16b",
"771602cc4d9444e7ab0d20438639cddd",
"f154ee2be8a044b3aeeb0e904411ffbd",
"a3a6841089054f1cbc31f638424674b3",
"41061818a9c94956a7d1cd129028d805",
"bab64ef3864248018e9476bc8c4018f4",
"f3fa20cd1c40453bb17b2f109607e1bf",
"97b3a5270a014515bbc712b44dba38a0",
"afe3a3438fb145d4a015fdb0709e3156",
"6d6078316fe54c9e83a3c3a35a1169fc",
"4700a281e7d347db8a58c6f181706b54",
"84f8bdfeb6bf4bb7ba4585eba47a7092",
"9262f880b64a4abb80013f6997901bcb",
"386289f0bf56453484a6637d3263da4c",
"8366904443f649928aa9cfd915cd938a",
"80f977590ae94733a8a8552241c12e3b",
"3ff5971c055144a3b81190d199ffe3de",
"be5e0fa21fea43e8bf003ae954c29d03",
"7535306bf05847629946e333021e0ef5",
"671f7a7556b1412cbd48237293431c0d",
"8957849d6dbf44cfafa965d71de78255",
"a55488797f71453e917032008f198b9c",
"38ccc42c4ea14f0e83fca1cb9452bfad",
"ef2ddbabd7b042c0821cf999a9265867",
"0222497184b2446e90f801d84af22b82",
"6a70d59d1c834989a54deda9e776bf41",
"62e05dc21f5e438cb5e94e400071a39b",
"803884fef29b433db14e61df7fae1ee7",
"83414a06fd504f71aa212d9fce15ffb5",
"e2c8785b7c51448296c8cf54331f4a68",
"3f9881738b644024aea6371982320d97",
"143666ff0c7b4e6491779b64f6212818",
"bf699df8e3d844f68a68491c00e8f0bc",
"9e955ce77097447a8c085ea592ae8a5e",
"2f2aecb73861473ba553b4ebebd52e0b",
"db21f601fd6342a5815cea18c417aa99",
"67bf9d4f1a8e431bb580337be0e67f82",
"ddcba8098de6433a8584045d52cb1f3b",
"60062f7860944d0d85c4ef1773c151c3",
"cc13e655b33d4fa390960d1fa40a0e1f",
"2e418d21ae4f4123a9d7b13cbc368605",
"ba1f476b8bdc4fce8a703a0220bd4770",
"26560e743afa4a3fad0eb1e0ed567a64",
"27e7a38811a94548b8dc1980e1c83acd",
"c5ceabb016a74435be6659f1116c9945",
"4ca6ee080dad41aa8abbbea5a96e3922",
"c0fd4025a3e84d0ca8c360887d7126ba",
"4a374856a56c4dc6a163ee5779d6b666",
"ef067c6b95ac48c58428f62ecef22e33",
"88521418122646b0b6c7d41be73e747a"
]
},
"outputId": "612cca73-0fa4-41c0-f2a3-ba9df82d4b4c"
},
"execution_count": 85,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"01/24/2022 12:02:55 - WARNING - __main__ - Process rank: -1, device: cuda, n_gpu: 1, distributed training: False, 16-bits training: False\n",
"/usr/local/lib/python3.7/dist-packages/transformers/models/auto/modeling_auto.py:787: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n",
" FutureWarning,\n",
"01/24/2022 12:03:00 - INFO - __main__ - Training/evaluation parameters <__main__.Args object at 0x7f0555d60550>\n",
"01/24/2022 12:03:00 - INFO - __main__ - Creating features from dataset file at cached\n",
"01/24/2022 12:03:00 - INFO - __main__ - Saving features into cached file cached/gpt2_cached_lm_512\n",
"01/24/2022 12:03:00 - INFO - __main__ - ***** Running training *****\n",
"01/24/2022 12:03:00 - INFO - __main__ - Num examples = 22\n",
"01/24/2022 12:03:00 - INFO - __main__ - Num Epochs = 4\n",
"01/24/2022 12:03:00 - INFO - __main__ - Instantaneous batch size per GPU = 4\n",
"01/24/2022 12:03:00 - INFO - __main__ - Total train batch size (w. parallel, distributed & accumulation) = 4\n",
"01/24/2022 12:03:00 - INFO - __main__ - Gradient Accumulation steps = 1\n",
"01/24/2022 12:03:00 - INFO - __main__ - Total optimization steps = 20\n"
]
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a49e5fd0d85444a3aa9f786455ca8770",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"Epoch: 0%| | 0/4 [00:00, ?it/s]"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1c82670ef31346eb97dff63429fd522f",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"Iteration: 0%| | 0/5 [00:00, ?it/s]"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f3fa20cd1c40453bb17b2f109607e1bf",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"Iteration: 0%| | 0/5 [00:00, ?it/s]"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "be5e0fa21fea43e8bf003ae954c29d03",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"Iteration: 0%| | 0/5 [00:00, ?it/s]"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "83414a06fd504f71aa212d9fce15ffb5",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"Iteration: 0%| | 0/5 [00:00, ?it/s]"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"01/24/2022 12:03:12 - INFO - __main__ - global_step = 20, average loss = 4.869351613521576\n",
"01/24/2022 12:03:12 - INFO - __main__ - Saving model checkpoint to output-small\n",
"01/24/2022 12:03:17 - INFO - __main__ - Evaluate the following checkpoints: ['output-small']\n",
"01/24/2022 12:03:20 - INFO - __main__ - Creating features from dataset file at cached\n",
"01/24/2022 12:03:20 - INFO - __main__ - Saving features into cached file cached/gpt2_cached_lm_512\n",
"01/24/2022 12:03:20 - INFO - __main__ - ***** Running evaluation *****\n",
"01/24/2022 12:03:20 - INFO - __main__ - Num examples = 3\n",
"01/24/2022 12:03:20 - INFO - __main__ - Batch size = 4\n"
]
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cc13e655b33d4fa390960d1fa40a0e1f",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"Evaluating: 0it [00:00, ?it/s]"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"01/24/2022 12:03:20 - INFO - __main__ - ***** Eval results *****\n",
"01/24/2022 12:03:20 - INFO - __main__ - perplexity = tensor(1.)\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'perplexity_': tensor(1.)}"
]
},
"metadata": {},
"execution_count": 85
}
]
},
{
"cell_type": "markdown",
"source": [
"## Load The Trained Model"
],
"metadata": {
"id": "F_xqB94xocSg"
}
},
{
"cell_type": "code",
"source": [
"tokenizer = AutoTokenizer.from_pretrained('microsoft/DialoGPT-small')\n",
"model = AutoModelWithLMHead.from_pretrained('output-small')"
],
"metadata": {
"id": "9AGnm3bmofkU",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "c6949113-68cb-4373-f390-983ea871c7ed"
},
"execution_count": 86,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/transformers/models/auto/modeling_auto.py:787: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n",
" FutureWarning,\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Let's chat for 4 lines\n",
"for step in range(4):\n",
" # encode the new user input, add the eos_token and return a tensor in Pytorch\n",
" new_user_input_ids = tokenizer.encode(input(\">> User:\") + tokenizer.eos_token, return_tensors='pt')\n",
" # print(new_user_input_ids)\n",
"\n",
" # append the new user input tokens to the chat history\n",
" bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids\n",
"\n",
" # generated a response while limiting the total chat history to 1000 tokens, \n",
" chat_history_ids = model.generate(\n",
" bot_input_ids, max_length=200,\n",
" pad_token_id=tokenizer.eos_token_id, \n",
" no_repeat_ngram_size=3, \n",
" do_sample=True, \n",
" top_k=100, \n",
" top_p=0.7,\n",
" temperature=0.8\n",
" )\n",
" \n",
" # pretty print last ouput tokens from bot\n",
" print(\"EleventhDoctorBot: {}\".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))"
],
"metadata": {
"id": "2-CQOEWPrmE7",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "fb711e82-492c-45a3-8aab-7e888bff12cb"
},
"execution_count": 90,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
">> User:why are you not working\n",
"EleventhDoctorBot: I'm not a lawyer, but I'm going to go out on a limb and say that the defendant's name was already on the back of the police's desk.\n",
">> User:i wish you were here\n",
"EleventhDoctorBot: I'll be back in a few hours to talk about this case and the police report.\n",
">> User:what police\n",
"EleventhDoctorBot: I just got a new phone and it's my new phone so I'm not going to talk to anyone on it\n",
">> User:haha\n",
"EleventhDoctorBot: I know, I just had a phone call and it was a phonecall that was being sent to the police and they were on it when they came back to\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Push Model to HuggingFace"
],
"metadata": {
"id": "KQBRSKDcoiJ4"
}
},
{
"cell_type": "code",
"source": [
"#model.push_to_hub(MY_MODEL_NAME, use_auth_token=HUGGINGFACE_API_KEY)\n",
"#tokenizer.push_to_hub(MY_MODEL_NAME, use_auth_token=HUGGINGFACE_API_KEY)"
],
"metadata": {
"id": "E_IH5n-P2u3N"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"os.chdir(\"/content/\")"
],
"metadata": {
"id": "tQtHvpnXA2fC"
},
"execution_count": 88,
"outputs": []
},
{
"cell_type": "code",
"source": [
""
],
"metadata": {
"id": "SJa0EMUZ-gYI"
},
"execution_count": null,
"outputs": []
}
]
}