ayethuzar commited on
Commit
c2abfcd
1 Parent(s): 4cc0cc9

update milestone-3 notebook

Browse files
Files changed (1) hide show
  1. CS670_milestone_3_AyeThuzar.ipynb +532 -9
CS670_milestone_3_AyeThuzar.ipynb CHANGED
@@ -922,18 +922,18 @@
922
  "\n",
923
  "from transformers import pipeline, Trainer, TrainingArguments\n",
924
  "\n",
925
- "\n",
926
  "import torch\n",
927
  "import torch.nn.functional as F\n",
928
  "\n",
929
  "from transformers import logging\n",
930
  "\n",
931
- "logging.set_verbosity_warning()"
932
  ],
933
  "metadata": {
934
  "id": "FxZeFFTlFvz1"
935
  },
936
- "execution_count": 5,
937
  "outputs": []
938
  },
939
  {
@@ -1163,7 +1163,7 @@
1163
  "metadata": {
1164
  "colab": {
1165
  "base_uri": "https://localhost:8080/",
1166
- "height": 140
1167
  },
1168
  "id": "jDBvcgmP5Puh",
1169
  "outputId": "f4f73693-11f7-4918-a86d-2912e863b151"
@@ -1193,7 +1193,7 @@
1193
  "metadata": {
1194
  "colab": {
1195
  "base_uri": "https://localhost:8080/",
1196
- "height": 87
1197
  },
1198
  "id": "sBhSPSV-5XKS",
1199
  "outputId": "0057e051-3b36-4705-8636-19e7850fa0a9"
@@ -4118,7 +4118,7 @@
4118
  "id": "h7bzRvkItdir",
4119
  "outputId": "7495ec10-0ee5-4f1c-ffe9-50f4afe2cb83"
4120
  },
4121
- "execution_count": null,
4122
  "outputs": [
4123
  {
4124
  "output_type": "stream",
@@ -4253,7 +4253,469 @@
4253
  "batch_average_accuray: 0.5\n",
4254
  "batch_average_accuray: 0.5\n",
4255
  "batch_average_accuray: 0.625\n",
4256
- "batch_average_accuray: 0.75\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4257
  ]
4258
  }
4259
  ]
@@ -4277,7 +4739,7 @@
4277
  "metadata": {
4278
  "id": "KefqatP-YDSC"
4279
  },
4280
- "execution_count": 41,
4281
  "outputs": []
4282
  },
4283
  {
@@ -4289,9 +4751,70 @@
4289
  "metadata": {
4290
  "id": "Km8eScKJl4VP"
4291
  },
4292
- "execution_count": 42,
4293
  "outputs": []
4294
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4295
  {
4296
  "cell_type": "markdown",
4297
  "source": [
 
922
  "\n",
923
  "from transformers import pipeline, Trainer, TrainingArguments\n",
924
  "\n",
925
+ "import numpy as np\n",
926
  "import torch\n",
927
  "import torch.nn.functional as F\n",
928
  "\n",
929
  "from transformers import logging\n",
930
  "\n",
931
+ "logging.set_verbosity_warning()\n"
932
  ],
933
  "metadata": {
934
  "id": "FxZeFFTlFvz1"
935
  },
936
+ "execution_count": 92,
937
  "outputs": []
938
  },
939
  {
 
1163
  "metadata": {
1164
  "colab": {
1165
  "base_uri": "https://localhost:8080/",
1166
+ "height": 157
1167
  },
1168
  "id": "jDBvcgmP5Puh",
1169
  "outputId": "f4f73693-11f7-4918-a86d-2912e863b151"
 
1193
  "metadata": {
1194
  "colab": {
1195
  "base_uri": "https://localhost:8080/",
1196
+ "height": 105
1197
  },
1198
  "id": "sBhSPSV-5XKS",
1199
  "outputId": "0057e051-3b36-4705-8636-19e7850fa0a9"
 
4118
  "id": "h7bzRvkItdir",
4119
  "outputId": "7495ec10-0ee5-4f1c-ffe9-50f4afe2cb83"
4120
  },
4121
+ "execution_count": 88,
4122
  "outputs": [
4123
  {
4124
  "output_type": "stream",
 
4253
  "batch_average_accuray: 0.5\n",
4254
  "batch_average_accuray: 0.5\n",
4255
  "batch_average_accuray: 0.625\n",
4256
+ "batch_average_accuray: 0.75\n",
4257
+ "batch_average_accuray: 0.375\n",
4258
+ "batch_average_accuray: 0.5\n",
4259
+ "batch_average_accuray: 0.25\n",
4260
+ "batch_average_accuray: 0.5625\n",
4261
+ "batch_average_accuray: 0.4375\n",
4262
+ "batch_average_accuray: 0.75\n",
4263
+ "batch_average_accuray: 0.375\n",
4264
+ "batch_average_accuray: 0.5625\n",
4265
+ "batch_average_accuray: 0.8125\n",
4266
+ "batch_average_accuray: 0.5625\n",
4267
+ "batch_average_accuray: 0.5625\n",
4268
+ "batch_average_accuray: 0.5\n",
4269
+ "batch_average_accuray: 0.625\n",
4270
+ "batch_average_accuray: 0.6875\n",
4271
+ "batch_average_accuray: 0.4375\n",
4272
+ "batch_average_accuray: 0.625\n",
4273
+ "batch_average_accuray: 0.625\n",
4274
+ "batch_average_accuray: 0.5625\n",
4275
+ "batch_average_accuray: 0.5\n",
4276
+ "batch_average_accuray: 0.5\n",
4277
+ "batch_average_accuray: 0.75\n",
4278
+ "batch_average_accuray: 0.5625\n",
4279
+ "batch_average_accuray: 0.5625\n",
4280
+ "batch_average_accuray: 0.5625\n",
4281
+ "batch_average_accuray: 0.375\n",
4282
+ "batch_average_accuray: 0.5625\n",
4283
+ "batch_average_accuray: 0.625\n",
4284
+ "batch_average_accuray: 0.375\n",
4285
+ "batch_average_accuray: 0.6875\n",
4286
+ "batch_average_accuray: 0.5\n",
4287
+ "batch_average_accuray: 0.625\n",
4288
+ "batch_average_accuray: 0.4375\n",
4289
+ "batch_average_accuray: 0.375\n",
4290
+ "batch_average_accuray: 0.375\n",
4291
+ "batch_average_accuray: 0.5625\n",
4292
+ "batch_average_accuray: 0.5625\n",
4293
+ "batch_average_accuray: 0.375\n",
4294
+ "batch_average_accuray: 0.4375\n",
4295
+ "batch_average_accuray: 0.75\n",
4296
+ "batch_average_accuray: 0.4375\n",
4297
+ "batch_average_accuray: 0.4375\n",
4298
+ "batch_average_accuray: 0.5625\n",
4299
+ "batch_average_accuray: 0.4375\n",
4300
+ "batch_average_accuray: 0.6875\n",
4301
+ "batch_average_accuray: 0.625\n",
4302
+ "batch_average_accuray: 0.6875\n",
4303
+ "batch_average_accuray: 0.625\n",
4304
+ "batch_average_accuray: 0.5\n",
4305
+ "batch_average_accuray: 0.4375\n",
4306
+ "batch_average_accuray: 0.375\n",
4307
+ "batch_average_accuray: 0.4375\n",
4308
+ "batch_average_accuray: 0.625\n",
4309
+ "batch_average_accuray: 0.625\n",
4310
+ "batch_average_accuray: 0.625\n",
4311
+ "batch_average_accuray: 0.75\n",
4312
+ "batch_average_accuray: 0.6875\n",
4313
+ "batch_average_accuray: 0.5625\n",
4314
+ "batch_average_accuray: 0.5\n",
4315
+ "batch_average_accuray: 0.4375\n",
4316
+ "batch_average_accuray: 0.5625\n",
4317
+ "batch_average_accuray: 0.6875\n",
4318
+ "batch_average_accuray: 0.625\n",
4319
+ "batch_average_accuray: 0.75\n",
4320
+ "batch_average_accuray: 0.4375\n",
4321
+ "batch_average_accuray: 0.4375\n",
4322
+ "batch_average_accuray: 0.6875\n",
4323
+ "batch_average_accuray: 0.4375\n",
4324
+ "batch_average_accuray: 0.5625\n",
4325
+ "batch_average_accuray: 0.6875\n",
4326
+ "batch_average_accuray: 0.375\n",
4327
+ "batch_average_accuray: 0.3125\n",
4328
+ "batch_average_accuray: 0.5625\n",
4329
+ "batch_average_accuray: 0.5625\n",
4330
+ "batch_average_accuray: 0.625\n",
4331
+ "batch_average_accuray: 0.5\n",
4332
+ "batch_average_accuray: 0.4375\n",
4333
+ "batch_average_accuray: 0.5625\n",
4334
+ "batch_average_accuray: 0.625\n",
4335
+ "batch_average_accuray: 0.5\n",
4336
+ "batch_average_accuray: 0.6875\n",
4337
+ "batch_average_accuray: 0.5625\n",
4338
+ "batch_average_accuray: 0.375\n",
4339
+ "batch_average_accuray: 0.5\n",
4340
+ "batch_average_accuray: 0.4375\n",
4341
+ "batch_average_accuray: 0.5\n",
4342
+ "batch_average_accuray: 0.625\n",
4343
+ "batch_average_accuray: 0.5625\n",
4344
+ "batch_average_accuray: 0.4375\n",
4345
+ "batch_average_accuray: 0.5\n",
4346
+ "batch_average_accuray: 0.5\n",
4347
+ "batch_average_accuray: 0.4375\n",
4348
+ "batch_average_accuray: 0.8125\n",
4349
+ "batch_average_accuray: 0.625\n",
4350
+ "batch_average_accuray: 0.5\n",
4351
+ "batch_average_accuray: 0.6875\n",
4352
+ "batch_average_accuray: 0.5625\n",
4353
+ "batch_average_accuray: 0.4375\n",
4354
+ "batch_average_accuray: 0.5\n",
4355
+ "batch_average_accuray: 0.625\n",
4356
+ "batch_average_accuray: 0.6875\n",
4357
+ "batch_average_accuray: 0.25\n",
4358
+ "batch_average_accuray: 0.625\n",
4359
+ "batch_average_accuray: 0.5625\n",
4360
+ "batch_average_accuray: 0.25\n",
4361
+ "batch_average_accuray: 0.375\n",
4362
+ "batch_average_accuray: 0.75\n",
4363
+ "batch_average_accuray: 0.625\n",
4364
+ "batch_average_accuray: 0.75\n",
4365
+ "batch_average_accuray: 0.375\n",
4366
+ "batch_average_accuray: 0.4375\n",
4367
+ "batch_average_accuray: 0.625\n",
4368
+ "batch_average_accuray: 0.4375\n",
4369
+ "batch_average_accuray: 0.5\n",
4370
+ "batch_average_accuray: 0.75\n",
4371
+ "batch_average_accuray: 0.3125\n",
4372
+ "batch_average_accuray: 0.5625\n",
4373
+ "batch_average_accuray: 0.75\n",
4374
+ "batch_average_accuray: 0.5625\n",
4375
+ "batch_average_accuray: 0.75\n",
4376
+ "batch_average_accuray: 0.5625\n",
4377
+ "batch_average_accuray: 0.625\n",
4378
+ "batch_average_accuray: 0.75\n",
4379
+ "batch_average_accuray: 0.6875\n",
4380
+ "batch_average_accuray: 0.5625\n",
4381
+ "batch_average_accuray: 0.6875\n",
4382
+ "batch_average_accuray: 0.3125\n",
4383
+ "batch_average_accuray: 0.5\n",
4384
+ "batch_average_accuray: 0.5\n",
4385
+ "batch_average_accuray: 0.25\n",
4386
+ "batch_average_accuray: 0.4375\n",
4387
+ "batch_average_accuray: 0.375\n",
4388
+ "batch_average_accuray: 0.375\n",
4389
+ "batch_average_accuray: 0.5625\n",
4390
+ "batch_average_accuray: 0.5\n",
4391
+ "batch_average_accuray: 0.625\n",
4392
+ "batch_average_accuray: 0.4375\n",
4393
+ "batch_average_accuray: 0.5\n",
4394
+ "batch_average_accuray: 0.6875\n",
4395
+ "batch_average_accuray: 0.5625\n",
4396
+ "batch_average_accuray: 0.625\n",
4397
+ "batch_average_accuray: 0.5625\n",
4398
+ "batch_average_accuray: 0.5625\n",
4399
+ "batch_average_accuray: 0.6875\n",
4400
+ "batch_average_accuray: 0.625\n",
4401
+ "batch_average_accuray: 0.5625\n",
4402
+ "batch_average_accuray: 0.5\n",
4403
+ "batch_average_accuray: 0.5625\n",
4404
+ "batch_average_accuray: 0.6875\n",
4405
+ "batch_average_accuray: 0.6875\n",
4406
+ "batch_average_accuray: 0.75\n",
4407
+ "batch_average_accuray: 0.25\n",
4408
+ "batch_average_accuray: 0.5\n",
4409
+ "batch_average_accuray: 0.625\n",
4410
+ "batch_average_accuray: 0.625\n",
4411
+ "batch_average_accuray: 0.5625\n",
4412
+ "batch_average_accuray: 0.5\n",
4413
+ "batch_average_accuray: 0.375\n",
4414
+ "batch_average_accuray: 0.6875\n",
4415
+ "batch_average_accuray: 0.75\n",
4416
+ "batch_average_accuray: 0.375\n",
4417
+ "batch_average_accuray: 0.625\n",
4418
+ "batch_average_accuray: 0.5625\n",
4419
+ "batch_average_accuray: 0.5\n",
4420
+ "batch_average_accuray: 0.5\n",
4421
+ "batch_average_accuray: 0.5\n",
4422
+ "batch_average_accuray: 0.5625\n",
4423
+ "batch_average_accuray: 0.375\n",
4424
+ "batch_average_accuray: 0.625\n",
4425
+ "batch_average_accuray: 0.5625\n",
4426
+ "batch_average_accuray: 0.75\n",
4427
+ "batch_average_accuray: 0.6875\n",
4428
+ "batch_average_accuray: 0.375\n",
4429
+ "batch_average_accuray: 0.5625\n",
4430
+ "batch_average_accuray: 0.5625\n",
4431
+ "batch_average_accuray: 0.5\n",
4432
+ "batch_average_accuray: 0.625\n",
4433
+ "batch_average_accuray: 0.5625\n",
4434
+ "batch_average_accuray: 0.625\n",
4435
+ "batch_average_accuray: 0.625\n",
4436
+ "batch_average_accuray: 0.25\n",
4437
+ "batch_average_accuray: 0.3125\n",
4438
+ "batch_average_accuray: 0.5625\n",
4439
+ "batch_average_accuray: 0.375\n",
4440
+ "batch_average_accuray: 0.4375\n",
4441
+ "batch_average_accuray: 0.4375\n",
4442
+ "batch_average_accuray: 0.375\n",
4443
+ "batch_average_accuray: 0.8125\n",
4444
+ "batch_average_accuray: 0.6875\n",
4445
+ "batch_average_accuray: 0.4375\n",
4446
+ "batch_average_accuray: 0.5625\n",
4447
+ "batch_average_accuray: 0.6875\n",
4448
+ "batch_average_accuray: 0.5\n",
4449
+ "batch_average_accuray: 0.4375\n",
4450
+ "batch_average_accuray: 0.375\n",
4451
+ "batch_average_accuray: 0.5\n",
4452
+ "batch_average_accuray: 0.4375\n",
4453
+ "batch_average_accuray: 0.4375\n",
4454
+ "batch_average_accuray: 0.375\n",
4455
+ "batch_average_accuray: 0.5\n",
4456
+ "batch_average_accuray: 0.4375\n",
4457
+ "batch_average_accuray: 0.5\n",
4458
+ "batch_average_accuray: 0.4375\n",
4459
+ "batch_average_accuray: 0.5625\n",
4460
+ "batch_average_accuray: 0.6875\n",
4461
+ "batch_average_accuray: 0.5\n",
4462
+ "batch_average_accuray: 0.75\n",
4463
+ "batch_average_accuray: 0.625\n",
4464
+ "batch_average_accuray: 0.625\n",
4465
+ "batch_average_accuray: 0.5\n",
4466
+ "batch_average_accuray: 0.375\n",
4467
+ "batch_average_accuray: 0.5\n",
4468
+ "batch_average_accuray: 0.8125\n",
4469
+ "batch_average_accuray: 0.375\n",
4470
+ "batch_average_accuray: 0.6875\n",
4471
+ "batch_average_accuray: 0.6875\n",
4472
+ "batch_average_accuray: 0.5625\n",
4473
+ "batch_average_accuray: 0.5625\n",
4474
+ "batch_average_accuray: 0.5625\n",
4475
+ "batch_average_accuray: 0.5\n",
4476
+ "batch_average_accuray: 0.5625\n",
4477
+ "batch_average_accuray: 0.5625\n",
4478
+ "batch_average_accuray: 0.5\n",
4479
+ "batch_average_accuray: 0.5625\n",
4480
+ "batch_average_accuray: 0.4375\n",
4481
+ "batch_average_accuray: 0.375\n",
4482
+ "batch_average_accuray: 0.875\n",
4483
+ "batch_average_accuray: 0.5\n",
4484
+ "batch_average_accuray: 0.4375\n",
4485
+ "batch_average_accuray: 0.5\n",
4486
+ "batch_average_accuray: 0.625\n",
4487
+ "batch_average_accuray: 0.5\n",
4488
+ "batch_average_accuray: 0.4375\n",
4489
+ "batch_average_accuray: 0.6875\n",
4490
+ "batch_average_accuray: 0.625\n",
4491
+ "batch_average_accuray: 0.4375\n",
4492
+ "batch_average_accuray: 0.4375\n",
4493
+ "batch_average_accuray: 0.4375\n",
4494
+ "batch_average_accuray: 0.625\n",
4495
+ "batch_average_accuray: 0.4375\n",
4496
+ "batch_average_accuray: 0.6875\n",
4497
+ "batch_average_accuray: 0.625\n",
4498
+ "batch_average_accuray: 0.5625\n",
4499
+ "batch_average_accuray: 0.5\n",
4500
+ "batch_average_accuray: 0.4375\n",
4501
+ "batch_average_accuray: 0.375\n",
4502
+ "batch_average_accuray: 0.75\n",
4503
+ "batch_average_accuray: 0.625\n",
4504
+ "batch_average_accuray: 0.75\n",
4505
+ "batch_average_accuray: 0.4375\n",
4506
+ "batch_average_accuray: 0.4375\n",
4507
+ "batch_average_accuray: 0.3125\n",
4508
+ "batch_average_accuray: 0.5\n",
4509
+ "batch_average_accuray: 0.375\n",
4510
+ "batch_average_accuray: 0.5\n",
4511
+ "batch_average_accuray: 0.8125\n",
4512
+ "batch_average_accuray: 0.4375\n",
4513
+ "batch_average_accuray: 0.8125\n",
4514
+ "batch_average_accuray: 0.4375\n",
4515
+ "batch_average_accuray: 0.75\n",
4516
+ "batch_average_accuray: 0.625\n",
4517
+ "batch_average_accuray: 0.6875\n",
4518
+ "batch_average_accuray: 0.75\n",
4519
+ "batch_average_accuray: 0.5625\n",
4520
+ "batch_average_accuray: 0.5625\n",
4521
+ "batch_average_accuray: 0.6875\n",
4522
+ "batch_average_accuray: 0.4375\n",
4523
+ "batch_average_accuray: 0.375\n",
4524
+ "batch_average_accuray: 0.5\n",
4525
+ "batch_average_accuray: 0.75\n",
4526
+ "batch_average_accuray: 0.5\n",
4527
+ "batch_average_accuray: 0.625\n",
4528
+ "batch_average_accuray: 0.5\n",
4529
+ "batch_average_accuray: 0.5625\n",
4530
+ "batch_average_accuray: 0.25\n",
4531
+ "batch_average_accuray: 0.6875\n",
4532
+ "batch_average_accuray: 0.5625\n",
4533
+ "batch_average_accuray: 0.5\n",
4534
+ "batch_average_accuray: 0.5\n",
4535
+ "batch_average_accuray: 0.4375\n",
4536
+ "batch_average_accuray: 0.375\n",
4537
+ "batch_average_accuray: 0.625\n",
4538
+ "batch_average_accuray: 0.6875\n",
4539
+ "batch_average_accuray: 0.5625\n",
4540
+ "batch_average_accuray: 0.5\n",
4541
+ "batch_average_accuray: 0.5\n",
4542
+ "batch_average_accuray: 0.6875\n",
4543
+ "batch_average_accuray: 0.5\n",
4544
+ "batch_average_accuray: 0.5\n",
4545
+ "batch_average_accuray: 0.5625\n",
4546
+ "batch_average_accuray: 0.5\n",
4547
+ "batch_average_accuray: 0.5\n",
4548
+ "batch_average_accuray: 0.75\n",
4549
+ "batch_average_accuray: 0.625\n",
4550
+ "batch_average_accuray: 0.4375\n",
4551
+ "batch_average_accuray: 0.5625\n",
4552
+ "batch_average_accuray: 0.625\n",
4553
+ "batch_average_accuray: 0.625\n",
4554
+ "batch_average_accuray: 0.4375\n",
4555
+ "batch_average_accuray: 0.5\n",
4556
+ "batch_average_accuray: 0.25\n",
4557
+ "batch_average_accuray: 0.5\n",
4558
+ "batch_average_accuray: 0.4375\n",
4559
+ "batch_average_accuray: 0.8125\n",
4560
+ "batch_average_accuray: 0.75\n",
4561
+ "batch_average_accuray: 0.6875\n",
4562
+ "batch_average_accuray: 0.625\n",
4563
+ "batch_average_accuray: 0.5625\n",
4564
+ "batch_average_accuray: 0.6875\n",
4565
+ "batch_average_accuray: 0.625\n",
4566
+ "batch_average_accuray: 0.5625\n",
4567
+ "batch_average_accuray: 0.625\n",
4568
+ "batch_average_accuray: 0.4375\n",
4569
+ "batch_average_accuray: 0.6875\n",
4570
+ "batch_average_accuray: 0.3125\n",
4571
+ "batch_average_accuray: 0.75\n",
4572
+ "batch_average_accuray: 0.4375\n",
4573
+ "batch_average_accuray: 0.5625\n",
4574
+ "batch_average_accuray: 0.5\n",
4575
+ "batch_average_accuray: 0.6875\n",
4576
+ "batch_average_accuray: 0.5625\n",
4577
+ "batch_average_accuray: 0.4375\n",
4578
+ "batch_average_accuray: 0.75\n",
4579
+ "batch_average_accuray: 0.5625\n",
4580
+ "batch_average_accuray: 0.4375\n",
4581
+ "batch_average_accuray: 0.625\n",
4582
+ "batch_average_accuray: 0.5625\n",
4583
+ "batch_average_accuray: 0.5\n",
4584
+ "batch_average_accuray: 0.4375\n",
4585
+ "batch_average_accuray: 0.625\n",
4586
+ "batch_average_accuray: 0.8125\n",
4587
+ "batch_average_accuray: 0.8125\n",
4588
+ "batch_average_accuray: 0.5625\n",
4589
+ "batch_average_accuray: 0.5625\n",
4590
+ "batch_average_accuray: 0.5625\n",
4591
+ "batch_average_accuray: 0.6875\n",
4592
+ "batch_average_accuray: 0.375\n",
4593
+ "batch_average_accuray: 0.5625\n",
4594
+ "batch_average_accuray: 0.5625\n",
4595
+ "batch_average_accuray: 0.375\n",
4596
+ "batch_average_accuray: 0.625\n",
4597
+ "batch_average_accuray: 0.4375\n",
4598
+ "batch_average_accuray: 0.375\n",
4599
+ "batch_average_accuray: 0.5625\n",
4600
+ "batch_average_accuray: 0.6875\n",
4601
+ "batch_average_accuray: 0.625\n",
4602
+ "batch_average_accuray: 0.375\n",
4603
+ "batch_average_accuray: 0.625\n",
4604
+ "batch_average_accuray: 0.5625\n",
4605
+ "batch_average_accuray: 0.5\n",
4606
+ "batch_average_accuray: 0.625\n",
4607
+ "batch_average_accuray: 0.4375\n",
4608
+ "batch_average_accuray: 0.5\n",
4609
+ "batch_average_accuray: 0.5625\n",
4610
+ "batch_average_accuray: 0.5\n",
4611
+ "batch_average_accuray: 0.4375\n",
4612
+ "batch_average_accuray: 0.4375\n",
4613
+ "batch_average_accuray: 0.3125\n",
4614
+ "batch_average_accuray: 0.75\n",
4615
+ "batch_average_accuray: 0.75\n",
4616
+ "batch_average_accuray: 0.625\n",
4617
+ "batch_average_accuray: 0.5\n",
4618
+ "batch_average_accuray: 0.25\n",
4619
+ "batch_average_accuray: 0.5625\n",
4620
+ "batch_average_accuray: 0.75\n",
4621
+ "batch_average_accuray: 0.625\n",
4622
+ "batch_average_accuray: 0.375\n",
4623
+ "batch_average_accuray: 0.625\n",
4624
+ "batch_average_accuray: 0.625\n",
4625
+ "batch_average_accuray: 0.5625\n",
4626
+ "batch_average_accuray: 0.625\n",
4627
+ "batch_average_accuray: 0.625\n",
4628
+ "batch_average_accuray: 0.4375\n",
4629
+ "batch_average_accuray: 0.5\n",
4630
+ "batch_average_accuray: 0.75\n",
4631
+ "batch_average_accuray: 0.4375\n",
4632
+ "batch_average_accuray: 0.625\n",
4633
+ "batch_average_accuray: 0.375\n",
4634
+ "batch_average_accuray: 0.625\n",
4635
+ "batch_average_accuray: 0.625\n",
4636
+ "batch_average_accuray: 0.4375\n",
4637
+ "batch_average_accuray: 0.5625\n",
4638
+ "batch_average_accuray: 0.3125\n",
4639
+ "batch_average_accuray: 0.5625\n",
4640
+ "batch_average_accuray: 0.75\n",
4641
+ "batch_average_accuray: 0.6875\n",
4642
+ "batch_average_accuray: 0.375\n",
4643
+ "batch_average_accuray: 0.5625\n",
4644
+ "batch_average_accuray: 0.6875\n",
4645
+ "batch_average_accuray: 0.625\n",
4646
+ "batch_average_accuray: 0.625\n",
4647
+ "batch_average_accuray: 0.5625\n",
4648
+ "batch_average_accuray: 0.375\n",
4649
+ "batch_average_accuray: 0.5\n",
4650
+ "batch_average_accuray: 0.5\n",
4651
+ "batch_average_accuray: 0.5625\n",
4652
+ "batch_average_accuray: 0.5625\n",
4653
+ "batch_average_accuray: 0.5625\n",
4654
+ "batch_average_accuray: 0.4375\n",
4655
+ "batch_average_accuray: 0.5625\n",
4656
+ "batch_average_accuray: 0.5\n",
4657
+ "batch_average_accuray: 0.6875\n",
4658
+ "batch_average_accuray: 0.375\n",
4659
+ "batch_average_accuray: 0.4375\n",
4660
+ "batch_average_accuray: 0.5625\n",
4661
+ "batch_average_accuray: 0.4375\n",
4662
+ "batch_average_accuray: 0.6875\n",
4663
+ "batch_average_accuray: 0.5\n",
4664
+ "batch_average_accuray: 0.5625\n",
4665
+ "batch_average_accuray: 0.875\n",
4666
+ "batch_average_accuray: 0.75\n",
4667
+ "batch_average_accuray: 0.25\n",
4668
+ "batch_average_accuray: 0.5\n",
4669
+ "batch_average_accuray: 0.625\n",
4670
+ "batch_average_accuray: 0.375\n",
4671
+ "batch_average_accuray: 0.5625\n",
4672
+ "batch_average_accuray: 0.5625\n",
4673
+ "batch_average_accuray: 0.5625\n",
4674
+ "batch_average_accuray: 0.4375\n",
4675
+ "batch_average_accuray: 0.5625\n",
4676
+ "batch_average_accuray: 0.625\n",
4677
+ "batch_average_accuray: 0.4375\n",
4678
+ "batch_average_accuray: 0.5625\n",
4679
+ "batch_average_accuray: 0.375\n",
4680
+ "batch_average_accuray: 0.625\n",
4681
+ "batch_average_accuray: 0.4375\n",
4682
+ "batch_average_accuray: 0.625\n",
4683
+ "batch_average_accuray: 0.6875\n",
4684
+ "batch_average_accuray: 0.375\n",
4685
+ "batch_average_accuray: 0.6875\n",
4686
+ "batch_average_accuray: 0.5625\n",
4687
+ "batch_average_accuray: 0.6875\n",
4688
+ "batch_average_accuray: 0.6875\n",
4689
+ "batch_average_accuray: 0.4375\n",
4690
+ "batch_average_accuray: 0.5\n",
4691
+ "batch_average_accuray: 0.625\n",
4692
+ "batch_average_accuray: 0.5625\n",
4693
+ "batch_average_accuray: 0.5625\n",
4694
+ "batch_average_accuray: 0.5625\n",
4695
+ "batch_average_accuray: 0.125\n"
4696
+ ]
4697
+ }
4698
+ ]
4699
+ },
4700
+ {
4701
+ "cell_type": "code",
4702
+ "source": [
4703
+ "print(f\"average accuracy: {np.mean(accuracy)}\")"
4704
+ ],
4705
+ "metadata": {
4706
+ "colab": {
4707
+ "base_uri": "https://localhost:8080/"
4708
+ },
4709
+ "id": "-Ow1N7MnEc98",
4710
+ "outputId": "01fddc67-f273-4659-ecfa-fd89e6c78935"
4711
+ },
4712
+ "execution_count": 93,
4713
+ "outputs": [
4714
+ {
4715
+ "output_type": "stream",
4716
+ "name": "stdout",
4717
+ "text": [
4718
+ "average accuracy: 0.5421792618629174\n"
4719
  ]
4720
  }
4721
  ]
 
4739
  "metadata": {
4740
  "id": "KefqatP-YDSC"
4741
  },
4742
+ "execution_count": 94,
4743
  "outputs": []
4744
  },
4745
  {
 
4751
  "metadata": {
4752
  "id": "Km8eScKJl4VP"
4753
  },
4754
+ "execution_count": 95,
4755
  "outputs": []
4756
  },
4757
+ {
4758
+ "cell_type": "markdown",
4759
+ "source": [
4760
+ "## Testing the saved model"
4761
+ ],
4762
+ "metadata": {
4763
+ "id": "dCZQwr_ZE-cB"
4764
+ }
4765
+ },
4766
+ {
4767
+ "cell_type": "code",
4768
+ "source": [
4769
+ "with torch.no_grad():\n",
4770
+ " outputs = model_saved(batch['input_ids']).logits\n",
4771
+ " print(outputs)\n",
4772
+ " predictions = F.softmax(outputs, dim = 1)\n",
4773
+ " print(predictions)\n",
4774
+ " labels = torch.argmax(predictions, dim = 1)\n",
4775
+ " print(labels)\n",
4776
+ " print(\"--------\")\n",
4777
+ " print(batch['decision'])\n",
4778
+ " print(\"--------\")\n",
4779
+ " res = labels == batch['decision']\n",
4780
+ " print(res)\n",
4781
+ " print(res.sum() / batch_size)"
4782
+ ],
4783
+ "metadata": {
4784
+ "colab": {
4785
+ "base_uri": "https://localhost:8080/"
4786
+ },
4787
+ "id": "u_iN3BSHFB27",
4788
+ "outputId": "d73153a7-f156-413c-9e3c-6f2930e8905d"
4789
+ },
4790
+ "execution_count": 96,
4791
+ "outputs": [
4792
+ {
4793
+ "output_type": "stream",
4794
+ "name": "stdout",
4795
+ "text": [
4796
+ "tensor([[-0.2934, 0.9680, 4.0130, -8.2634, -8.1291, -8.6447],\n",
4797
+ " [ 0.5176, 3.2941, 1.8334, -8.3832, -8.6352, -8.5553],\n",
4798
+ " [-0.4728, 0.9731, 4.1658, -8.1353, -7.9516, -8.5336],\n",
4799
+ " [-0.4363, 1.1413, 4.1972, -8.3214, -8.2106, -8.7486],\n",
4800
+ " [-0.3831, 1.4167, 4.0593, -8.5625, -8.5613, -9.0239],\n",
4801
+ " [ 0.3174, 3.2739, 2.2290, -8.6113, -8.8512, -8.8537]])\n",
4802
+ "tensor([[1.2706e-02, 4.4856e-02, 9.4243e-01, 4.3923e-06, 5.0237e-06, 2.9996e-06],\n",
4803
+ " [4.8101e-02, 7.7258e-01, 1.7930e-01, 6.5550e-06, 5.0946e-06, 5.5186e-06],\n",
4804
+ " [9.2039e-03, 3.9077e-02, 9.5171e-01, 4.3269e-06, 5.1996e-06, 2.9054e-06],\n",
4805
+ " [9.1980e-03, 4.4548e-02, 9.4624e-01, 3.4612e-06, 3.8667e-06, 2.2579e-06],\n",
4806
+ " [1.0866e-02, 6.5728e-02, 9.2340e-01, 3.0465e-06, 3.0504e-06, 1.9206e-06],\n",
4807
+ " [3.7043e-02, 7.1237e-01, 2.5057e-01, 4.9094e-06, 3.8624e-06, 3.8528e-06]])\n",
4808
+ "tensor([2, 1, 2, 2, 2, 1])\n",
4809
+ "--------\n",
4810
+ "tensor([2, 2, 0, 1, 1, 1])\n",
4811
+ "--------\n",
4812
+ "tensor([ True, False, False, False, False, True])\n",
4813
+ "tensor(0.1250)\n"
4814
+ ]
4815
+ }
4816
+ ]
4817
+ },
4818
  {
4819
  "cell_type": "markdown",
4820
  "source": [