111 lines
3.8 KiB
Python
111 lines
3.8 KiB
Python
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import base64
|
|||
|
|
import json
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from typing import Any, Iterable, Optional
|
|||
|
|
|
|||
|
|
import requests
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass(frozen=True)
class OcrBox:
    """Four-vertex box around one OCR text line.

    Vertices are expressed in the coordinate system of the screenshot image
    that was sent for recognition.
    """

    # Four (x, y) vertices, relative to the screenshot image coordinate system.
    points: tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int]]

    def center(self) -> tuple[int, int]:
        """Return the mean of the four vertices as an integer ``(x, y)`` pair.

        Each coordinate mean is computed with true division by 4 and then
        truncated toward zero by ``int()``.
        """
        total_x = 0
        total_y = 0
        for vertex in self.points:
            total_x += vertex[0]
            total_y += vertex[1]
        return int(total_x / 4), int(total_y / 4)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass(frozen=True)
class OcrItem:
    """One recognized text line paired with the box that locates it.

    Instances are immutable (frozen dataclass), so they can be hashed and
    compared by value.
    """

    # Recognized text of the line.
    text: str
    # Four-point box locating ``text`` in screenshot coordinates.
    box: OcrBox
|
|||
|
|
|
|||
|
|
|
|||
|
|
class UmiClient:
    """
    Client for the Umi-OCR HTTP API.

    Sends images to the OCR endpoint with ``data_format=dict`` and parses the
    response into ``OcrItem`` objects (text + four-point box).
    """

    # Umi-OCR response code meaning "no text found in the image".
    _CODE_NO_TEXT = 101

    def __init__(self, url: str = "http://127.0.0.1:1224/api/ocr", timeout_s: float = 15.0) -> None:
        """
        Args:
            url: Full URL of the Umi-OCR OCR endpoint.
            timeout_s: Timeout in seconds for requests made by ``ocr_bytes``.
        """
        self.url = url
        self.timeout_s = timeout_s

    def check_service(self, timeout_s: float = 3.0) -> None:
        """
        Probe the Umi-OCR HTTP endpoint and raise if the service is not usable.

        Umi-OCR has no documented health-check endpoint, so this sends one
        lightweight OCR request with an empty image. The service counts as up
        when a connection succeeds and the body parses as JSON, even if that
        JSON reports a business-level error.

        Args:
            timeout_s: Timeout in seconds for the probe request (default 3).

        Raises:
            RuntimeError: If the request fails or the response body is not JSON.
        """
        try:
            r = requests.post(self.url, json={"base64": "", "options": {"data_format": "dict"}}, timeout=timeout_s)
            # Enforce the documented contract: a bare HTTP answer is not
            # enough, the body must be JSON. JSON decode errors are
            # ValueError subclasses.
            r.json()
        except (requests.RequestException, ValueError) as e:
            raise RuntimeError(f"无法连接 Umi-OCR 服务:{self.url}。请先在 Umi-OCR 中开启 HTTP 服务。") from e

    def ocr_bytes(self, image_bytes: bytes) -> list[OcrItem]:
        """
        Run OCR on raw image bytes and return the recognized items.

        The image is base64-encoded and POSTed to ``self.url`` with
        ``self.timeout_s`` as the request timeout.

        Raises:
            requests.RequestException: On connection errors, timeouts, or an
                HTTP error status (via ``raise_for_status``).
            ValueError: If the body is not JSON or has an unexpected shape.
        """
        img64 = base64.b64encode(image_bytes).decode("utf-8")
        payload = {"base64": img64, "options": {"data_format": "dict"}}
        resp = requests.post(self.url, json=payload, timeout=self.timeout_s)
        resp.raise_for_status()
        return self._parse_umi_dict(resp.json())

    def _parse_umi_dict(self, data: dict[str, Any]) -> list[OcrItem]:
        """
        Parse a Umi-OCR ``data_format=dict`` response body into ``OcrItem``s.

        Entries that are not dicts, have empty (whitespace-only) text, or carry
        a box that ``_coerce_box_points`` rejects are skipped. A "no text"
        response (code 101) yields an empty list.

        Raises:
            ValueError: If ``data`` is not a dict, or (outside the no-text case)
                its ``data`` field is not a list.
        """
        # Fail loudly on a non-object body so the problem is easy to locate.
        if not isinstance(data, dict):
            raise ValueError(f"Umi-OCR 返回非 JSON 对象:{type(data)}")

        # "No text in image" is a normal outcome, not a malformed response.
        # Its ``data`` field may hold a message instead of a result list, so
        # short-circuit before the list check below.
        if data.get("code") == self._CODE_NO_TEXT:
            return []

        data_list = data.get("data", [])
        if not isinstance(data_list, list):
            raise ValueError(f"Umi-OCR 返回 data 字段不是 list:{json.dumps(data, ensure_ascii=False)[:500]}")

        items: list[OcrItem] = []
        for it in data_list:
            if not isinstance(it, dict):
                continue
            text = str(it.get("text", "")).strip()
            pts = _coerce_box_points(it.get("box"))
            if not text or pts is None:
                continue
            items.append(OcrItem(text=text, box=OcrBox(points=pts)))
        return items

    def find_text(
        self,
        target_name: str,
        items: Iterable[OcrItem],
        *,
        exact: bool = True,
        case_sensitive: bool = False,
    ) -> Optional[tuple[int, int]]:
        """
        Find the first item whose text matches ``target_name``.

        Args:
            target_name: Text to look for. An empty string never matches; as a
                substring it would otherwise match every item.
            items: OCR items to search, in order; the first match wins.
            exact: If True, the whole item text must equal ``target_name``;
                otherwise containing it as a substring is enough.
            case_sensitive: If False (default), both sides are lower-cased
                before comparing.

        Returns:
            The center ``(x, y)`` of the matching item's box in screenshot
            coordinates, or None when nothing matches.
        """
        if not target_name:
            return None
        t = target_name if case_sensitive else target_name.lower()
        for item in items:
            s = item.text if case_sensitive else item.text.lower()
            ok = (s == t) if exact else (t in s)
            if ok:
                return item.box.center()
        return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _coerce_box_points(box: Any) -> Optional[tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int]]]:
|
|||
|
|
"""
|
|||
|
|
Umi-OCR 的 box 在 data_format=dict 下通常是 4 个点:
|
|||
|
|
[[x1,y1],[x2,y2],[x3,y3],[x4,y4]]
|
|||
|
|
"""
|
|||
|
|
if not (isinstance(box, list) and len(box) == 4):
|
|||
|
|
return None
|
|||
|
|
pts: list[tuple[int, int]] = []
|
|||
|
|
for p in box:
|
|||
|
|
if not (isinstance(p, (list, tuple)) and len(p) == 2):
|
|||
|
|
return None
|
|||
|
|
pts.append((int(p[0]), int(p[1])))
|
|||
|
|
return (pts[0], pts[1], pts[2], pts[3])
|
|||
|
|
|