"""
Helper functions for accessing LLMs via the Hugging Face Inference API and
LangChain's HuggingFaceEndpoint.
"""
import logging

import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_core.language_models import LLM

from global_config import GlobalConfig


HF_API_URL = f'https://api-inference.huggingface.co/models/{GlobalConfig.HF_LLM_MODEL_NAME}'
HF_API_HEADERS = {'Authorization': f'Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}'}
REQUEST_TIMEOUT = 35  # seconds

logger = logging.getLogger(__name__)

# Retry POSTs to the HF API on transient gateway errors (502/503/504) with
# jittered exponential backoff; a shared session reuses connections across calls
retries = Retry(
    total=5,
    backoff_factor=0.25,
    backoff_jitter=0.3,
    status_forcelist=[502, 503, 504],
    allowed_methods={'POST'},
)
adapter = HTTPAdapter(max_retries=retries)
http_session = requests.Session()
http_session.mount('https://', adapter)
http_session.mount('http://', adapter)


def get_hf_endpoint() -> LLM:
    """
    Get an LLM via LangChain's HuggingFaceEndpoint.

    :return: The streaming-enabled LLM.
    """

    logger.debug('Getting LLM via HF endpoint')

    return HuggingFaceEndpoint(
        repo_id=GlobalConfig.HF_LLM_MODEL_NAME,
        max_new_tokens=GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
        top_k=40,
        top_p=0.95,
        temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
        repetition_penalty=1.03,
        streaming=True,
        huggingfacehub_api_token=GlobalConfig.HUGGINGFACEHUB_API_TOKEN,
        return_full_text=False,
        stop_sequences=['</s>'],
    )
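
# A minimal usage sketch (illustrative, not part of the original module): the
# returned object supports LangChain's standard streaming interface, so tokens
# can be printed as they arrive. The prompt below is made up.
#
#     llm = get_hf_endpoint()
#     for chunk in llm.stream('List three applications of 5G.'):
#         print(chunk, end='', flush=True)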


def hf_api_query(payload: dict) -> list:
    """
    Invoke the HF Inference API endpoint.

    :param payload: The prompt for the LLM and related parameters.
    :return: The output from the LLM: a list of generation dicts, or an empty
     list if the request timed out.
    """

    try:
        response = http_session.post(
            HF_API_URL,
            headers=HF_API_HEADERS,
            json=payload,
            timeout=REQUEST_TIMEOUT
        )
        result = response.json()
    except requests.exceptions.Timeout as te:
        logger.error('*** Error: hf_api_query timeout! %s', str(te))
        result = []

    return result
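
# Illustrative call (hypothetical prompt, not from the original module). For
# text-generation models, the HF Inference API responds with a list of dicts
# carrying a 'generated_text' key, hence the indexing below:
#
#     result = hf_api_query({
#         'inputs': 'Write a one-sentence summary of WiFi 6.',
#         'parameters': {'max_new_tokens': 64, 'return_full_text': False},
#     })
#     if result:
#         print(result[0]['generated_text'])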


def generate_slides_content(topic: str) -> str:
    """
    Generate the outline/contents of slides for a presentation on a given topic.

    :param topic: Topic on which slides are to be generated.
    :return: The content in JSON format.
    """

    with open(GlobalConfig.SLIDES_TEMPLATE_FILE, 'r', encoding='utf-8') as in_file:
        template_txt = in_file.read().strip()
        template_txt = template_txt.replace('<REPLACE_PLACEHOLDER>', topic)

    output = hf_api_query({
        'inputs': template_txt,
        'parameters': {
            'temperature': GlobalConfig.LLM_MODEL_TEMPERATURE,
            'min_length': GlobalConfig.LLM_MODEL_MIN_OUTPUT_LENGTH,
            'max_length': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
            'max_new_tokens': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
            'num_return_sequences': 1,
            'return_full_text': False,
            # "repetition_penalty": 0.0001
        },
        'options': {
            'wait_for_model': True,
            'use_cache': True
        }
    })

    # A timed-out request yields an empty list; bail out instead of crashing
    if not output:
        logger.error('generate_slides_content: no output received from the LLM')
        return ''

    output = output[0]['generated_text'].strip()
    # output = output[len(template_txt):]

    # Truncate at the trailing ``` of a Markdown code fence, if present
    json_end_idx = output.rfind('```')
    if json_end_idx != -1:
        # logging.debug(f'{json_end_idx=}')
        output = output[:json_end_idx]

    logger.debug('generate_slides_content: output: %s', output)

    return output
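
# Hedged end-to-end example (assumes GlobalConfig points to a valid slides
# template and HF credentials; the topic is made up):
#
#     slides_json = generate_slides_content('Introduction to edge computing')
#     print(slides_json)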


if __name__ == '__main__':
    # results = get_related_websites('5G AI WiFi 6')
    #
    # for a_result in results.results:
    #     print(a_result.title, a_result.url, a_result.extract)

    # get_ai_image('A talk on AI, covering pros and cons')
    pass