File size: 3,639 Bytes
aa4f694
8537019
e690364
 
 
9c0dccd
 
3e68ccf
 
 
 
724babe
 
4bd6659
724babe
9c0dccd
aa4f694
e690364
 
 
 
 
 
 
 
 
 
 
 
9c0dccd
 
 
e690364
9c0dccd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e68ccf
 
9c0dccd
724babe
 
 
9c0dccd
 
724babe
 
6d7d653
4bd6659
 
 
 
 
 
6d7d653
 
9c0dccd
e690364
6d7d653
 
724babe
 
3e68ccf
 
 
 
9c0dccd
 
3e68ccf
 
6d7d653
e55d16a
 
3e68ccf
724babe
9c0dccd
 
724babe
 
 
 
 
 
 
 
9c0dccd
724babe
 
 
 
 
 
 
 
 
 
 
 
 
9c0dccd
469fc38
724babe
3e68ccf
 
8537019
 
 
 
 
3e68ccf
8537019
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import logging
import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry

from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_core.language_models import LLM

from global_config import GlobalConfig


# HF serverless inference endpoint for the configured model, plus the
# bearer-token header read from the global config at import time.
HF_API_URL = f"https://api-inference.huggingface.co/models/{GlobalConfig.HF_LLM_MODEL_NAME}"
HF_API_HEADERS = {"Authorization": f"Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}"}
# Per-request timeout (seconds) for calls made through `http_session`.
REQUEST_TIMEOUT = 35

logger = logging.getLogger(__name__)

# Retry transient upstream failures (502/503/504) with jittered exponential
# backoff. POST must be allowed explicitly because urllib3 does not retry
# non-idempotent methods by default.
retries = Retry(
    total=5,
    backoff_factor=0.25,
    backoff_jitter=0.3,
    status_forcelist=[502, 503, 504],
    allowed_methods={'POST'},
)
adapter = HTTPAdapter(max_retries=retries)
# Module-wide session so connections are pooled across API calls; the same
# retrying adapter handles both http and https URLs.
http_session = requests.Session()
http_session.mount('https://', adapter)
http_session.mount('http://', adapter)


def get_hf_endpoint() -> LLM:
    """
    Build a streaming LLM client backed by LangChain's HuggingFaceEndpoint.

    Model name, token, temperature, and output length come from
    ``GlobalConfig``; the remaining sampling parameters are fixed here.

    :return: The configured LLM.
    """

    logger.debug('Getting LLM via HF endpoint')

    # Collect the endpoint configuration in one place before construction.
    endpoint_config = {
        'repo_id': GlobalConfig.HF_LLM_MODEL_NAME,
        'max_new_tokens': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
        'top_k': 40,
        'top_p': 0.95,
        'temperature': GlobalConfig.LLM_MODEL_TEMPERATURE,
        'repetition_penalty': 1.03,
        'streaming': True,
        'huggingfacehub_api_token': GlobalConfig.HUGGINGFACEHUB_API_TOKEN,
        'return_full_text': False,
        'stop_sequences': ['</s>'],
    }

    return HuggingFaceEndpoint(**endpoint_config)


def hf_api_query(payload: dict) -> dict:
    """
    Invoke the HF inference endpoint API.

    On success, returns the decoded JSON response (for text-generation
    models this is typically a *list* of dicts carrying a
    ``generated_text`` key). On timeout, connection failure, HTTP error
    status, or a non-JSON body, an empty list is returned so callers can
    detect the absence of output instead of crashing mid-parse.

    :param payload: The prompt for the LLM and related parameters.
    :return: The output from the LLM, or ``[]`` on failure.
    """

    try:
        response = http_session.post(
            HF_API_URL,
            headers=HF_API_HEADERS,
            json=payload,
            timeout=REQUEST_TIMEOUT
        )
        # Surface HTTP errors (e.g. 401 from a bad token, 429 rate limit)
        # instead of handing an error payload to the caller as model output.
        response.raise_for_status()
        result = response.json()
    except requests.exceptions.Timeout as te:
        logger.error('*** Error: hf_api_query timeout! %s', str(te))
        result = []
    except (requests.exceptions.RequestException, ValueError) as ex:
        # Connection errors, HTTP error statuses, or non-JSON bodies:
        # previously only Timeout was caught, so these propagated or
        # crashed inside response.json(). Treat them like timeouts.
        logger.error('*** Error: hf_api_query failed! %s', str(ex))
        result = []

    return result


def generate_slides_content(topic: str) -> str:
    """
    Generate the outline/contents of slides for a presentation on a given topic.

    :param topic: Topic on which slides are to be generated.
    :return: The content in JSON format, or an empty string if the LLM
        call failed or returned nothing usable.
    """

    # Load the prompt template and inject the requested topic.
    with open(GlobalConfig.SLIDES_TEMPLATE_FILE, 'r', encoding='utf-8') as in_file:
        template_txt = in_file.read().strip()
        template_txt = template_txt.replace('<REPLACE_PLACEHOLDER>', topic)

    output = hf_api_query({
        'inputs': template_txt,
        'parameters': {
            'temperature': GlobalConfig.LLM_MODEL_TEMPERATURE,
            'min_length': GlobalConfig.LLM_MODEL_MIN_OUTPUT_LENGTH,
            'max_length': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
            'max_new_tokens': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
            'num_return_sequences': 1,
            'return_full_text': False,
        },
        'options': {
            'wait_for_model': True,
            'use_cache': True
        }
    })

    # hf_api_query yields [] on timeout/failure (and an error dict is
    # possible from the API); previously this crashed with IndexError on
    # output[0]. Bail out with an empty string instead.
    if not isinstance(output, list) or not output:
        logger.error('generate_slides_content: no usable output from the LLM: %s', output)
        return ''

    output = output[0]['generated_text'].strip()

    # Trim everything from the final ``` fence onward so the result is
    # bare JSON rather than a fenced Markdown code block.
    json_end_idx = output.rfind('```')
    if json_end_idx != -1:
        output = output[:json_end_idx]

    logger.debug('generate_slides_content: output: %s', output)

    return output


if __name__ == '__main__':
    # This module is intended to be imported, not run directly; earlier
    # ad-hoc experiments that lived here were removed (see VCS history).
    pass