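# Page-status residue from the page this file was copied from
# (originally bare lines: "Spaces:", "Running", "Running"); not code.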
import unittest | |
from unittest.mock import patch | |
import pandas as pd | |
import src.backend.evaluate_model as evaluate_model | |
class TestSummaryGenerator(unittest.TestCase):
    """Unit tests for ``evaluate_model.SummaryGenerator``.

    Every test constructs a ``SummaryGenerator``, and ``test_init`` shows that
    construction calls ``from_pretrained`` on the tokenizer and model classes.
    Each test therefore patches those classes, so the real loaders are never
    invoked.
    """

    # NOTE(review): these patch targets were reconstructed. In the original,
    # the mock parameters in each test signature had no matching @patch
    # decorators, so unittest called the tests with missing arguments
    # (TypeError). The attribute names (AutoTokenizer, AutoModelForCausalLM,
    # nlp) are assumed from the mock parameter names. Confirm that they are
    # the names SummaryGenerator looks up in the module that defines it
    # (patch where the name is looked up, not where it is defined).
    _MODULE = "src.backend.evaluate_model"
    _TOKENIZER = _MODULE + ".AutoTokenizer"
    _MODEL = _MODULE + ".AutoModelForCausalLM"
    _NLP = _MODULE + ".nlp"

    def setUp(self):
        """Provide the dummy model id and revision shared by all tests."""
        self.model_id = "test_model"
        self.revision = "test_revision"

    def _generator_with_one_summary(self):
        """Build a SummaryGenerator whose summaries_df holds one known row.

        Must be called from inside a test that patches the tokenizer and
        model classes, because it runs the SummaryGenerator constructor.
        """
        generator = evaluate_model.SummaryGenerator(self.model_id, self.revision)
        generator.summaries_df = pd.DataFrame({
            'source': ['text'],
            'summary': ['This is a test.'],
            'dataset': ['dataset'],
        })
        return generator

    # Stacked @patch decorators pass their mocks bottom-up: the decorator
    # closest to ``def`` supplies the first mock argument.
    @patch(_TOKENIZER)
    @patch(_MODEL)
    def test_init(self, mock_model, mock_tokenizer):
        """The constructor loads the tokenizer and model exactly once each."""
        evaluate_model.SummaryGenerator(self.model_id, self.revision)
        # Pins that the id and revision are forwarded positionally.
        mock_tokenizer.from_pretrained.assert_called_once_with(self.model_id,
                                                               self.revision)
        mock_model.from_pretrained.assert_called_once_with(self.model_id,
                                                           self.revision)

    @patch(_NLP)
    @patch(_TOKENIZER)
    @patch(_MODEL)
    def test_generate_summaries(self, mock_model, mock_tokenizer, mock_nlp):
        """generate_summaries produces one summaries_df row per input row."""
        # NOTE(review): this input uses a 'text' column, while the other tests
        # build summaries_df with a 'source' column. Confirm which column
        # generate_summaries reads.
        df = pd.DataFrame({'text': ['text1', 'text2'],
                           'dataset': ['dataset1', 'dataset2']})
        generator = evaluate_model.SummaryGenerator(self.model_id, self.revision)
        generator.generate_summaries(df)
        self.assertEqual(len(generator.summaries_df), len(df))

    @patch(_TOKENIZER)
    @patch(_MODEL)
    def test_compute_avg_length(self, mock_model, mock_tokenizer):
        """avg_length is 4 for the single four-word summary 'This is a test.'."""
        generator = self._generator_with_one_summary()
        generator._compute_avg_length()
        self.assertEqual(generator.avg_length, 4)

    @patch(_TOKENIZER)
    @patch(_MODEL)
    def test_compute_answer_rate(self, mock_model, mock_tokenizer):
        """answer_rate is 1 when the only row has a non-empty summary."""
        generator = self._generator_with_one_summary()
        generator._compute_answer_rate()
        self.assertEqual(generator.answer_rate, 1)

    @patch(_TOKENIZER)
    @patch(_MODEL)
    def test_error_rate(self, mock_model, mock_tokenizer):
        """_compute_error_rate(0) yields an error_rate of 0."""
        generator = self._generator_with_one_summary()
        # NOTE(review): the meaning of the positional argument is not visible
        # here (presumably a count of failed generations) — confirm.
        generator._compute_error_rate(0)
        self.assertEqual(generator.error_rate, 0)
if __name__ == "__main__":
    # Running this file directly hands its TestCase classes to the unittest CLI.
    unittest.main()