/**
 * RetrievalParamConfig: form section for editing a RetrievalConfig for a given
 * retrieve method (`type`). Every edit is reported through `onChange` with a
 * full copy of `value` plus the changed field(s); `value` itself is not mutated.
 *
 * Render branches visible in this file:
 *  - `!isEconomical && !isHybridSearch`: rerank-model section. A control that
 *    sets `reranking_enable` is shown only when `canToggleRerankModalEnable`
 *    and is disabled when no `currentModel` resolves; a selector writes
 *    `reranking_model`.
 *  - `!isHybridSearch`: `top_k` editor, plus a `score_threshold` editor that is
 *    hidden for the economical method and for full-text search with reranking
 *    disabled.
 *  - `isHybridSearch`: one entry per `rerankingModeOptions` item, then either
 *    the weights editor (WeightedScore mode) or a rerank-model selector (any
 *    other mode), followed by `top_k` and `score_threshold` editors.
 *
 * NOTE(review): in this copy the JSX element tags and the generic argument of
 * `FC` appear to have been stripped (there are closing `/>` and `)}` without
 * matching openers). Which imported component renders each control is therefore
 * an assumption; restore the markup from version control before editing the
 * returned tree.
 */
'use client' import type { FC } from 'react' import React, { useCallback } from 'react' import { useTranslation } from 'react-i18next' import cn from '@/utils/classnames' import TopKItem from '@/app/components/base/param-item/top-k-item' import ScoreThresholdItem from '@/app/components/base/param-item/score-threshold-item' import { RETRIEVE_METHOD } from '@/types/app' import Switch from '@/app/components/base/switch' import Tooltip from '@/app/components/base/tooltip' import type { RetrievalConfig } from '@/types/app' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import { useCurrentProviderAndModel, useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { DEFAULT_WEIGHTED_SCORE, RerankingModeEnum, WeightedScoreEnum, } from '@/models/datasets' import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score' import Toast from '@/app/components/base/toast' /** Component props: `type` is the retrieve method being configured, `value` the current config, `onChange` receives the complete next config. */ type Props = { type: RETRIEVE_METHOD value: RetrievalConfig onChange: (value: RetrievalConfig) => void } const RetrievalParamConfig: FC = ({ type, value, onChange, }) => { const { t } = useTranslation() /* The reranking_enable toggle is offered for every method except hybrid. */ const canToggleRerankModalEnable = type !== RETRIEVE_METHOD.hybrid const isEconomical = type === RETRIEVE_METHOD.invertedIndex const { defaultModel: rerankDefaultModel, modelList: rerankModelList, } = useModelListAndDefaultModel(ModelTypeEnum.rerank) /* Resolve the default rerank model against the rerank model list; the hook is given the provider id string (provider.provider) rather than the provider object. */ const { currentModel, } = useCurrentProviderAndModel( rerankModelList, rerankDefaultModel ? 
{ ...rerankDefaultModel, provider: rerankDefaultModel.provider.provider, } : undefined, ) /* Raises an error toast when no current rerank model is available; does nothing otherwise. NOTE(review): rerankDefaultModel is in the dependency list but is not read inside the callback. */ const handleDisabledSwitchClick = useCallback(() => { if (!currentModel) Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) }, [currentModel, rerankDefaultModel, t]) const isHybridSearch = type === RETRIEVE_METHOD.hybrid /* Rerank model descriptor: the one stored on `value` if present, else the default rerank model, else undefined (no branch returns). */ const rerankModel = (() => { if (value.reranking_model) { return { provider_name: value.reranking_model.reranking_provider_name, model_name: value.reranking_model.reranking_model_name, } } else if (rerankDefaultModel) { return { provider_name: rerankDefaultModel.provider.provider, model_name: rerankDefaultModel.model, } } })() /* Changes reranking_mode; returns early when the mode is unchanged. When switching to WeightedScore and no weights exist yet, seeds customized weights from DEFAULT_WEIGHTED_SCORE.other with empty embedding provider/model names. */ const handleChangeRerankMode = (v: RerankingModeEnum) => { if (v === value.reranking_mode) return const result = { ...value, reranking_mode: v, } if (!result.weights && v === RerankingModeEnum.WeightedScore) { result.weights = { weight_type: WeightedScoreEnum.Customized, vector_setting: { vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic, embedding_provider_name: '', embedding_model_name: '', }, keyword_setting: { keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword, }, } } onChange(result) } /* Reranking modes listed in the hybrid-search branch; each entry carries the mode value, a translated label and translated tooltip text. */ const rerankingModeOptions = [ { value: RerankingModeEnum.WeightedScore, label: t('dataset.weightedScore.title'), tips: t('dataset.weightedScore.description'), }, { value: RerankingModeEnum.RerankingModel, label: t('common.modelProvider.rerankModel.key'), tips: t('common.modelProvider.rerankModel.tip'), }, ] return (
{!isEconomical && !isHybridSearch && (
{canToggleRerankModalEnable && (
{ onChange({ ...value, reranking_enable: v, }) }} disabled={!currentModel} />
)}
{t('common.modelProvider.rerankModel.key')} {t('common.modelProvider.rerankModel.tip')}
} />
{ onChange({ ...value, reranking_model: { reranking_provider_name: v.provider, reranking_model_name: v.model, }, }) }} />
)} { !isHybridSearch && (
{ onChange({ ...value, top_k: v, }) }} enable={true} /> {(!isEconomical && !(value.search_method === RETRIEVE_METHOD.fullText && !value.reranking_enable)) && ( { onChange({ ...value, score_threshold: v, }) }} enable={value.score_threshold_enabled} hasSwitch={true} onSwitchChange={(_key, v) => { onChange({ ...value, score_threshold_enabled: v, }) }} /> )}
) } { isHybridSearch && ( <>
{ rerankingModeOptions.map(option => (
handleChangeRerankMode(option.value)} >
{option.label}
{option.tips}
} triggerClassName='ml-0.5 w-3.5 h-3.5' />
)) } { value.reranking_mode === RerankingModeEnum.WeightedScore && ( { onChange({ ...value, weights: { ...value.weights!, vector_setting: { ...value.weights!.vector_setting, vector_weight: v.value[0], }, keyword_setting: { ...value.weights!.keyword_setting, keyword_weight: v.value[1], }, }, }) }} /> ) } { value.reranking_mode !== RerankingModeEnum.WeightedScore && ( { onChange({ ...value, reranking_model: { reranking_provider_name: v.provider, reranking_model_name: v.model, }, }) }} /> ) }
{ onChange({ ...value, top_k: v, }) }} enable={true} /> { onChange({ ...value, score_threshold: v, }) }} enable={value.score_threshold_enabled} hasSwitch={true} onSwitchChange={(_key, v) => { onChange({ ...value, score_threshold_enabled: v, }) }} />
) } ) } export default React.memo(RetrievalParamConfig)