
Designing Evaluation Benchmarks: Comprehensively Assessing the Performance and Quality of AI Systems


Introduction

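Without evaluation there is no optimization. A well-designed evaluation benchmark is the key to improving an AI system: it lets us assess the system's behavior objectively, uncover problems, and decide what to improve next.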

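I have designed evaluation benchmarks for several projects, and in this post I share some of the lessons and practices I picked up along the way.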

Evaluation Metrics

Retrieval Quality Metrics

```python
from typing import Dict, List, Optional, Sequence

import numpy as np


class RetrievalMetrics:
    """Retrieval evaluation metrics for a single query.

    `retrieved` is the ranked list of document IDs returned by the system;
    `relevant` lists the document IDs judged relevant for the query.
    """

    def precision_at_k(self, retrieved: List[str], relevant: List[str], k: int) -> float:
        """Precision@k: fraction of the top-k results that are relevant."""
        if k <= 0:
            return 0.0
        hits = len(set(retrieved[:k]) & set(relevant))
        return hits / k

    def recall_at_k(self, retrieved: List[str], relevant: List[str], k: int) -> float:
        """Recall@k: fraction of all relevant documents that appear in the top k."""
        if not relevant:
            return 0.0
        hits = len(set(retrieved[:k]) & set(relevant))
        return hits / len(set(relevant))

    def average_precision(self, retrieved: List[str], relevant: List[str]) -> float:
        """Average precision (AP) for one query; its mean over all queries is MAP."""
        relevant_set = set(relevant)
        if not relevant_set:
            return 0.0
        hits = 0
        precision_sum = 0.0
        seen = set()
        for rank, doc_id in enumerate(retrieved, start=1):
            if doc_id in relevant_set and doc_id not in seen:
                seen.add(doc_id)
                hits += 1
                precision_sum += hits / rank  # Precision@rank at each relevant hit
        # Divide by the total number of relevant docs, so docs never retrieved count as 0
        return precision_sum / len(relevant_set)

    def dcg_at_k(self, retrieved: List[str], relevance_scores: Dict[str, int], k: int) -> float:
        """Discounted cumulative gain (linear gain): sum of rel / log2(rank + 1)."""
        dcg = 0.0
        for i, doc_id in enumerate(retrieved[:k]):
            rel = relevance_scores.get(doc_id, 0)
            dcg += rel / np.log2(i + 2)  # i is 0-based, so rank = i + 1
        return dcg

    def ndcg_at_k(self, retrieved: List[str], relevance_scores: Dict[str, int], k: int) -> float:
        """Normalized DCG: DCG divided by the DCG of the ideal ranking."""
        dcg = self.dcg_at_k(retrieved, relevance_scores, k)
        # Ideal ranking: all judged documents sorted by descending relevance
        ideal = sorted(relevance_scores, key=relevance_scores.get, reverse=True)
        idcg = self.dcg_at_k(ideal, relevance_scores, k)
        return dcg / idcg if idcg > 0 else 0.0

    def compute_all_metrics(
        self,
        retrieved: List[str],
        relevant: List[str],
        relevance_scores: Optional[Dict[str, int]] = None,
        k_list: Sequence[int] = (1, 5, 10),  # tuple avoids a mutable default argument
    ) -> Dict[str, float]:
        """Compute every metric for one query."""
        metrics: Dict[str, float] = {}
        for k in k_list:
            metrics[f"precision@{k}"] = self.precision_at_k(retrieved, relevant, k)
            metrics[f"recall@{k}"] = self.recall_at_k(retrieved, relevant, k)
            if relevance_scores:  # nDCG needs graded relevance judgments
                metrics[f"ndcg@{k}"] = self.ndcg_at_k(retrieved, relevance_scores, k)
        # Per-query AP; averaging it across queries yields MAP
        metrics["map"] = self.average_precision(retrieved, relevant)
        return metrics
```
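For reference, these are the standard definitions behind the metrics above, written with $R_k$ for the top-$k$ retrieved documents, $G$ for the set of relevant documents, $d_i$ for the document at rank $i$ (1-based), $\mathrm{rel}_i$ for its graded relevance, and $Q$ for the set of queries. DCG uses the linear gain $\mathrm{rel}_i$, which is the variant `dcg_at_k` computes (`np.log2(i + 2)` with a 0-based `i` is $\log_2(\text{rank} + 1)$).

```latex
\mathrm{P@}k = \frac{|R_k \cap G|}{k}, \qquad
\mathrm{R@}k = \frac{|R_k \cap G|}{|G|}

\mathrm{AP} = \frac{1}{|G|} \sum_{i=1}^{n} \mathrm{P@}i \cdot \mathbb{1}[d_i \in G], \qquad
\mathrm{MAP} = \frac{1}{|Q|} \sum_{q \in Q} \mathrm{AP}(q)

\mathrm{DCG@}k = \sum_{i=1}^{k} \frac{\mathrm{rel}_i}{\log_2(i + 1)}, \qquad
\mathrm{nDCG@}k = \frac{\mathrm{DCG@}k}{\mathrm{IDCG@}k}
```

In the standard definition AP divides by $|G|$, so relevant documents that are never retrieved pull the score down. A small worked example: ranking $[A, B, C]$ with $G = \{A, C\}$ and grades $A = 3$, $C = 2$, $B = 0$ gives $\mathrm{P@}3 = 2/3$, $\mathrm{R@}3 = 1$, $\mathrm{AP} = \tfrac{1}{2}(1/1 + 2/3) = 5/6 \approx 0.833$, $\mathrm{DCG@}3 = 3/\log_2 2 + 2/\log_2 4 = 4$, $\mathrm{IDCG@}3 = 3 + 2/\log_2 3 \approx 4.262$, and therefore $\mathrm{nDCG@}3 \approx 0.939$.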

System Performance Metrics

```python
import time
from dataclasses import dataclass
from typing import List, Optional

import numpy as np


@dataclass
class LatencyStats:
    """Latency statistics, in whatever unit the caller records (e.g. milliseconds)."""
    avg_latency: float
    p50_latency: float
    p95_latency: float
    p99_latency: float
    min_latency: float
    max_latency: float


class PerformanceMetrics:
    """System performance metrics: latency, throughput and error rate."""

    def __init__(self):
        self.latencies: List[float] = []
        self.start_time: Optional[float] = None
        self.end_time: Optional[float] = None
        self.request_count = 0
        self.error_count = 0

    def start_benchmark(self):
        """Start the benchmark clock."""
        # perf_counter is monotonic, which suits interval measurement better than time.time
        self.start_time = time.perf_counter()

    def end_benchmark(self):
        """Stop the benchmark clock."""
        self.end_time = time.perf_counter()

    def record_request(self, latency: float, is_error: bool = False):
        """Record one request's latency and whether it failed."""
        self.latencies.append(latency)
        self.request_count += 1
        if is_error:
            self.error_count += 1

    def get_latency_stats(self) -> LatencyStats:
        """Summarize the recorded latencies."""
        if not self.latencies:
            return LatencyStats(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
        values = np.asarray(self.latencies, dtype=float)
        p50, p95, p99 = np.percentile(values, [50, 95, 99])
        return LatencyStats(
            avg_latency=float(values.mean()),
            p50_latency=float(p50),
            p95_latency=float(p95),
            p99_latency=float(p99),
            min_latency=float(values.min()),
            max_latency=float(values.max()),
        )

    def get_throughput(self) -> float:
        """Throughput in requests per second over the benchmark window."""
        if self.start_time is None or self.end_time is None:
            return 0.0
        duration = self.end_time - self.start_time
        return self.request_count / duration if duration > 0 else 0.0

    def get_error_rate(self) -> float:
        """Fraction of requests that failed."""
        return self.error_count / self.request_count if self.request_count > 0 else 0.0
```
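`get_latency_stats` delegates to `np.percentile`, whose default method interpolates linearly between the two nearest samples. With only a few hundred test queries, p95 and p99 rest on a handful of slow requests, so it helps to know which percentile definition a report uses. The sketch below reimplements, in plain Python, the linear-interpolation definition (intended to match NumPy's default) and the nearest-rank definition; both function names are illustrative, not part of the original code.

```python
import math
from typing import List


def percentile_linear(values: List[float], p: float) -> float:
    """Linear interpolation between closest ranks (the scheme NumPy uses by default)."""
    if not values or not 0 <= p <= 100:
        raise ValueError("need non-empty values and 0 <= p <= 100")
    xs = sorted(values)
    pos = (len(xs) - 1) * p / 100  # fractional 0-based index
    lo, hi = math.floor(pos), math.ceil(pos)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def percentile_nearest_rank(values: List[float], p: float) -> float:
    """Nearest-rank percentile: always returns an actually observed sample."""
    if not values or not 0 <= p <= 100:
        raise ValueError("need non-empty values and 0 <= p <= 100")
    xs = sorted(values)
    rank = max(1, math.ceil(p * len(xs) / 100))  # 1-based rank
    return xs[rank - 1]
```

For the latencies `[12, 15, 11, 14, 250]` ms, the linear p95 is about 203 ms and the nearest-rank p95 is 250 ms: a single slow request drives both, which is a reason to collect enough requests before trusting tail percentiles.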

Building the Evaluation Dataset

Test Case Design

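```python
from dataclasses import dataclass, field
from typing import Any, Dict, List

import numpy as np


@dataclass
class TestCase:
    """A single test case."""
    query: str
    relevant_docs: List[str]
    relevance_scores: Dict[str, int] = field(default_factory=dict)  # graded relevance, used for nDCG
    metadata: Dict[str, Any] = field(default_factory=dict)


class BenchmarkDataset:
    """Evaluation dataset."""

    def __init__(self):
        self.test_cases: List[TestCase] = []

    def add_test_case(self, test_case: TestCase):
        """Add a test case."""
        self.test_cases.append(test_case)

    def categorize_test_cases(self) -> Dict[str, List[TestCase]]:
        """Group test cases by query type.

        Complex and domain-specific queries cannot be detected from length alone,
        so they are read from metadata["category"] (an assumed convention);
        all other cases are bucketed by query length.
        """
        categories: Dict[str, List[TestCase]] = {
            "short_query": [],
            "medium_query": [],
            "long_query": [],
            "complex_query": [],
            "domain_specific": [],
        }
        for case in self.test_cases:
            label = (case.metadata or {}).get("category")
            if label in ("complex_query", "domain_specific"):
                categories[label].append(case)
                continue
            # Whitespace tokenization; Chinese queries need a tokenizer or a character count instead
            query_len = len(case.query.split())
            if query_len <= 3:
                categories["short_query"].append(case)
            elif query_len >= 10:
                categories["long_query"].append(case)
            else:
                categories["medium_query"].append(case)
        return categories

    def get_statistics(self) -> Dict[str, float]:
        """Return basic dataset statistics."""
        if not self.test_cases:
            return {"total_cases": 0, "avg_relevant_docs": 0.0, "avg_query_length": 0.0}
        return {
            "total_cases": len(self.test_cases),
            "avg_relevant_docs": float(np.mean([len(c.relevant_docs) for c in self.test_cases])),
            "avg_query_length": float(np.mean([len(c.query.split()) for c in self.test_cases])),
        }
```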
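Test cases usually live on disk rather than in code. A minimal sketch of loading them from JSON Lines follows, assuming a hypothetical one-object-per-line schema whose keys mirror the `TestCase` fields. The dataclass is repeated so the snippet runs on its own, and the sanity check (every relevant document must have a positive grade whenever grades are given) is an illustrative convention, not part of the original design.

```python
import json
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List


@dataclass
class TestCase:
    # Same fields as the TestCase above, repeated so this snippet is standalone
    query: str
    relevant_docs: List[str]
    relevance_scores: Dict[str, int] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)


def load_test_cases(lines: Iterable[str]) -> List[TestCase]:
    """Parse one JSON object per line into TestCase objects (hypothetical schema)."""
    cases = []
    for line_no, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        record = json.loads(line)
        case = TestCase(
            query=record["query"],
            relevant_docs=list(record["relevant_docs"]),
            relevance_scores={k: int(v) for k, v in record.get("relevance_scores", {}).items()},
            metadata=record.get("metadata", {}),
        )
        # Consistency check: a doc labeled relevant should not have grade 0
        if case.relevance_scores:
            bad = [d for d in case.relevant_docs if case.relevance_scores.get(d, 0) <= 0]
            if bad:
                raise ValueError(f"line {line_no}: relevant docs without a positive grade: {bad}")
        cases.append(case)
    return cases
```

Since `load_test_cases` accepts any iterable of strings, an open file handle (`open(path, encoding="utf-8")`) can be passed in directly.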

Automatic Labeling

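```python
from typing import Any, Dict, List, Optional


class AutoLabeler:
    """Automatic relevance labeler.

    `llm_client` is any object with a `generate(prompt: str) -> str` method
    (an assumed interface; adapt it to the SDK you actually use).
    Each candidate is a dict with "id", "content" and "score" (a similarity score).
    """

    def __init__(self, llm_client: Optional[Any] = None):
        self.llm_client = llm_client

    def filter_candidates(self, query: str, candidates: List[Dict], threshold: float = 0.7) -> List[str]:
        """Return the IDs of candidates judged relevant to the query."""
        if self.llm_client is None:
            # Fallback: simple labeling based on the similarity score
            return [c["id"] for c in candidates if c["score"] > threshold]

        # Label with the LLM
        labeled = []
        for candidate in candidates:
            prompt = (
                f"Query: {query}\n"
                f"Document: {candidate['content']}\n"
                "Is this document relevant to the query? Answer only yes or no."
            )
            response = self.llm_client.generate(prompt).strip().lower()
            # Check how the answer starts instead of searching for "yes" anywhere in it
            if response.startswith("yes"):
                labeled.append(candidate["id"])
        return labeled

    def generate_relevance_scores(self, query: str, candidates: List[Dict]) -> Dict[str, int]:
        """Map similarity to graded relevance: >0.8 -> 3, >0.6 -> 2, >0.4 -> 1, otherwise 0."""
        scores = {}
        for candidate in candidates:
            score = candidate["score"]
            if score > 0.8:
                grade = 3
            elif score > 0.6:
                grade = 2
            elif score > 0.4:
                grade = 1
            else:
                grade = 0
            scores[candidate["id"]] = grade
        return scores
```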
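The LLM path above only yields binary labels, while `ndcg_at_k` needs graded scores. One option is to ask the model for a 0-3 grade and parse it defensively. The sketch below assumes a hypothetical `generate` callable (prompt in, text out) and an illustrative prompt; answers without a usable digit fall back to grade 0, which is a conservative choice you may want to log and review instead.

```python
import re
from typing import Callable, Dict, List

# Illustrative prompt; tune the wording and scale descriptions for your domain
GRADE_PROMPT = (
    "Query: {query}\n"
    "Document: {content}\n"
    "Rate the document's relevance to the query on a 0-3 scale "
    "(0 = irrelevant, 3 = highly relevant). Answer with a single digit."
)


def parse_grade(response: str) -> int:
    """Extract the first standalone digit 0-3 from the answer; default to 0."""
    match = re.search(r"\b[0-3]\b", response)
    return int(match.group()) if match else 0


def llm_relevance_scores(
    query: str, candidates: List[Dict], generate: Callable[[str], str]
) -> Dict[str, int]:
    """Grade every candidate with an LLM via the hypothetical `generate` callable."""
    return {
        c["id"]: parse_grade(generate(GRADE_PROMPT.format(query=query, content=c["content"])))
        for c in candidates
    }
```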

The Complete Evaluation Framework

```python
import json
import time
from dataclasses import asdict
from typing import Callable, Dict, List

import numpy as np

# Assumes RetrievalMetrics, PerformanceMetrics and BenchmarkDataset from the sections above are in scope.


class BenchmarkRunner:
    """Evaluation runner."""

    def __init__(self, dataset: BenchmarkDataset):
        self.dataset = dataset
        self.retrieval_metrics = RetrievalMetrics()
        self.performance_metrics = PerformanceMetrics()
        self.results: List[Dict] = []

    def run_benchmark(self, search_func: Callable[[str], List[str]], include_performance: bool = True) -> Dict:
        """Run the benchmark; `search_func` maps a query to a ranked list of document IDs."""
        all_metrics = []
        if include_performance:
            self.performance_metrics.start_benchmark()

        for test_case in self.dataset.test_cases:
            # Run the search and time it
            start = time.perf_counter()
            try:
                retrieved = search_func(test_case.query)
                is_error = False
            except Exception:
                # A failed query is scored as an empty result list and counted as an error
                retrieved = []
                is_error = True
            latency_ms = (time.perf_counter() - start) * 1000

            if include_performance:
                self.performance_metrics.record_request(latency_ms, is_error)

            # Compute quality metrics for this query
            case_metrics = self.retrieval_metrics.compute_all_metrics(
                retrieved, test_case.relevant_docs, test_case.relevance_scores
            )
            all_metrics.append(case_metrics)
            self.results.append({
                "query": test_case.query,
                "retrieved": retrieved,
                "relevant": test_case.relevant_docs,
                "metrics": case_metrics,
            })

        if include_performance:
            self.performance_metrics.end_benchmark()

        # Aggregate across queries
        aggregated_metrics = self._aggregate_metrics(all_metrics)
        if include_performance:
            aggregated_metrics["performance"] = {
                "latency": asdict(self.performance_metrics.get_latency_stats()),  # milliseconds
                "throughput": self.performance_metrics.get_throughput(),  # requests per second
                "error_rate": self.performance_metrics.get_error_rate(),
            }
        return aggregated_metrics

    def _aggregate_metrics(self, all_metrics: List[Dict]) -> Dict:
        """Mean, std, min and max of every metric across queries."""
        aggregated = {}
        # Union of keys: nDCG only exists for queries that have graded relevance scores
        metric_names = sorted({name for case in all_metrics for name in case})
        for name in metric_names:
            values = [case[name] for case in all_metrics if name in case]
            aggregated[name] = {
                "mean": float(np.mean(values)),
                "std": float(np.std(values)),
                "min": float(np.min(values)),
                "max": float(np.max(values)),
            }
        return aggregated

    def save_results(self, filepath: str):
        """Save per-query results as JSON (UTF-8, so non-ASCII queries stay readable)."""
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(self.results, f, ensure_ascii=False, indent=2)
```
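The runner reports one aggregate over all queries, whereas the categorized test cases are meant to be evaluated per query type. A minimal sketch of a per-category breakdown, assuming per-query records in the same shape as `BenchmarkRunner.results` and a hypothetical `categories` mapping from query text to category name (for example built from `categorize_test_cases()`); keying by query text assumes queries are unique.

```python
from collections import defaultdict
from typing import Dict, List


def aggregate_by_category(
    results: List[Dict], categories: Dict[str, str], metric: str
) -> Dict[str, float]:
    """Mean of one metric per query category; unknown queries go to "uncategorized"."""
    buckets: Dict[str, List[float]] = defaultdict(list)
    for record in results:
        if metric in record["metrics"]:  # e.g. nDCG is missing for ungraded queries
            category = categories.get(record["query"], "uncategorized")
            buckets[category].append(record["metrics"][metric])
    return {category: sum(vals) / len(vals) for category, vals in buckets.items()}
```

A category whose mean lags far behind the overall mean is usually the first place to look for weak spots.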

Summary

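Key points of benchmark design:

  1. Comprehensive metrics: quality + performance + stability
  2. Categorized testing: evaluate each query type separately
  3. Automatic labeling: cut the cost of manual annotation
  4. Continuous comparison: validate improvements with A/B tests
  5. Result analysis: find the weak spots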
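For A/B comparisons it helps to check whether a gain in the mean metric could just be noise across queries. One common approach, swapped in here as an illustration rather than taken from the original, is a paired randomization (sign-flip permutation) test over per-query metric values, such as `results[i]["metrics"]["ndcg@10"]` from two saved runs on the same queries in the same order. The iteration count and seed are arbitrary defaults.

```python
import random
from typing import List


def paired_permutation_test(
    baseline: List[float], candidate: List[float], n_iter: int = 10000, seed: int = 0
) -> float:
    """Two-sided paired sign-flip test; returns a p-value for "no difference"."""
    if len(baseline) != len(candidate) or not baseline:
        raise ValueError("need non-empty score lists of equal length (same queries, same order)")
    diffs = [c - b for b, c in zip(baseline, candidate)]
    observed = abs(sum(diffs)) / len(diffs)
    rng = random.Random(seed)
    extreme = 0
    for _ in range(n_iter):
        # Under the null hypothesis each per-query difference is equally likely to flip sign
        permuted = sum(d if rng.random() < 0.5 else -d for d in diffs)
        if abs(permuted) / len(diffs) >= observed:
            extreme += 1
    # Add-one smoothing avoids reporting an impossible p-value of exactly 0
    return (extreme + 1) / (n_iter + 1)
```

A small p-value (say below 0.05) suggests the difference is unlikely to be chance alone; it says nothing about whether the difference is large enough to matter, so report the mean delta alongside it.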

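Practical advice:

  • Start with simple metrics and add more over time
  • Keep historical results so runs can be compared
  • Manually spot-check the automatic labels
  • Refresh the test set regularly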
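When spot-checking automatic labels against a human-labeled random sample, raw agreement can look high simply because most candidates are irrelevant. Cohen's kappa, used here as an illustrative choice rather than something the original prescribes, corrects for the agreement expected by chance. A minimal sketch for two label lists over the same sampled items:

```python
from collections import Counter
from typing import Hashable, List


def cohen_kappa(auto_labels: List[Hashable], human_labels: List[Hashable]) -> float:
    """Cohen's kappa between automatic labels and human spot-check labels."""
    if len(auto_labels) != len(human_labels) or not auto_labels:
        raise ValueError("need non-empty label lists of equal length")
    n = len(auto_labels)
    observed = sum(a == h for a, h in zip(auto_labels, human_labels)) / n
    auto_counts = Counter(auto_labels)
    human_counts = Counter(human_labels)
    # Chance agreement: both raters independently pick the same label
    expected = sum(auto_counts[c] * human_counts[c] for c in auto_counts) / (n * n)
    if expected == 1.0:
        return 1.0  # both raters used one identical label everywhere, so they agree fully
    return (observed - expected) / (1 - expected)
```

Values near 1 mean the automatic labeler tracks the human judgments well; values near 0 mean its agreement is no better than chance, even if raw agreement looks high.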
