Weaviate vector database
Project: https://github.com/weaviate/weaviate
Docs: https://weaviate.io/developers/weaviate/api/rest#tag/objects/GET/objects
Python SDK (LangChain integration) docs: https://python.langchain.com/docs/integrations/vectorstores/weaviate/
Running in a container
docker run -d -p 8080:8080 -p 50051:50051 cr.weaviate.io/semitechnologies/weaviate:1.32.0
Dify ships with Weaviate built in, so there you only need to expose the ports. To deploy Weaviate standalone, add the following to docker-compose.yml:
services:
weaviate:
command:
- --host
- 0.0.0.0
- --port
- '8080'
- --scheme
- http
image: cr.weaviate.io/semitechnologies/weaviate:1.30.1
ports:
- 8080:8080
- 50051:50051
volumes:
- weaviate_data:/var/lib/weaviate
restart: on-failure:0
environment:
QUERY_DEFAULTS_LIMIT: 25
AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true'
PERSISTENCE_DATA_PATH: '/var/lib/weaviate'
ENABLE_API_BASED_MODULES: 'true'
CLUSTER_HOSTNAME: 'node1'
volumes:
weaviate_data: {}
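Once the container is up, you can check readiness before sending traffic; a minimal sketch with httpx against the default local port:
import httpx

# Weaviate exposes readiness/liveness probes under /v1/.well-known/
resp = httpx.get("http://127.0.0.1:8080/v1/.well-known/ready", timeout=5)
print(resp.status_code)  # 200 means the node is ready to serve requests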
REST API
# Assumes FastAPI is installed; imports and config needed by the endpoints below.
# (OPENAI_BASE_URL, OPENAI_API_KEY, WEAVIATE_URL, WEAVIATE_API_KEY are assumed to
#  be defined elsewhere, e.g. loaded from environment variables.)
import httpx
from httpx import HTTPStatusError
from fastapi import FastAPI, HTTPException, Query, status

app = FastAPI()

# Save text to Weaviate
@app.get("/save")
async def save_embedding(
    you: str = Query(..., description="text sent by the user"),
    me: str = Query(..., description="text returned by the model"),
    conversation_id: str = Query(..., description="conversation ID")
):
    """
    Call the embedding API and save the result to Weaviate.
    """
text = f"{you}。{me}"
async with httpx.AsyncClient() as client:
        # 1. Call the embedding API
embed_url = f"{OPENAI_BASE_URL}/v1/embeddings"
embed_headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {OPENAI_API_KEY}"
}
embed_payload = {"input": text, "model": "text-embedding-3-small"}
try:
embed_resp = await client.post(embed_url, headers=embed_headers, json=embed_payload)
embed_resp.raise_for_status()
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Embedding API 调用失败: {str(e)}"
)
embed_result = embed_resp.json()
embedding = embed_result.get("data", [])[0].get("embedding")
        # 2. Save to Weaviate with an async httpx request
weaviate_url = f"{WEAVIATE_URL}/objects"
weaviate_headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {WEAVIATE_API_KEY}"
}
weaviate_payload = {
"class": "Memory",
"properties": {"you": you, "me": me, "conversation_id": conversation_id},
"vector": embedding
}
try:
weaviate_resp = await client.post(weaviate_url, headers=weaviate_headers, json=weaviate_payload)
weaviate_resp.raise_for_status()
except HTTPStatusError as e:
raise HTTPException(
status_code=e.response.status_code,
detail=f"Weaviate 保存失败: {e.response.text}"
)
return weaviate_resp.json()
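# Example call to this endpoint (hypothetical values):
# GET /save?you=hello&me=hi%20there&conversation_id=abc123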
# List all matching records for a conversation, sorted by creation time (earliest first)
@app.get("/queryall_sort")
async def query_all_sort(
    conversation_id: str = Query(..., description="conversation ID"),
    text: str = Query(..., description="query text"),
    limit: int = Query(10, description="number of records to return")  # note: not currently applied to the GraphQL query below
):
async with httpx.AsyncClient() as client:
        # 1. Get the embedding for the query text
embed_url = f"{OPENAI_BASE_URL}/v1/embeddings"
embed_headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {OPENAI_API_KEY}"
}
embed_payload = {"input": text, "model": "text-embedding-3-small"}
embed_resp = await client.post(embed_url, headers=embed_headers, json=embed_payload)
embed_resp.raise_for_status()
embed_result = embed_resp.json()
vector = embed_result.get("data", [])[0].get("embedding")
        # 2. Query Weaviate via GraphQL
        graphql_url = f"{WEAVIATE_URL}/graphql"
        # Build the GraphQL query text
        vector_str = ",".join(map(str, vector))
graphql_query = f'''{{
Get {{
Memory(
where: {{ path: ["conversation_id"], operator: Equal, valueString: "{conversation_id}" }},
nearVector: {{ vector: [{vector_str}], certainty: 0.7 }},
limit: 100,
sort: [{{
path: "_creationTimeUnix",
order: asc
}}]
) {{
you
me
_additional {{
creationTimeUnix
}}
}}
}}
}}'''
graphql_payload = {"query": graphql_query}
        # Print the raw request for debugging
        # print(graphql_payload)
graphql_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {WEAVIATE_API_KEY}"}
graphql_resp = await client.post(graphql_url, headers=graphql_headers, json=graphql_payload)
graphql_resp.raise_for_status()
jsonstr = graphql_resp.json()
# print(f"jsonstr: {jsonstr}")
# 格式化记忆,取最后10条
return format_memory(jsonstr)
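format_memory is not defined above; a minimal sketch of what it might look like, assuming the GraphQL response shape produced by the query above (an assumption, not the original implementation):
def format_memory(jsonstr: dict, last_n: int = 10) -> str:
    # Pull the hit list out of the GraphQL response (empty list when nothing matched)
    hits = (jsonstr.get("data", {}).get("Get", {}) or {}).get("Memory") or []
    # The query sorts ascending by creation time, so the last entries are the newest
    lines = [f"user: {h.get('you')}\nassistant: {h.get('me')}" for h in hits[-last_n:]]
    return "\n".join(lines)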
List all memories (50 records), sorted by creation time, as a raw HTTP request:
GET /v1/objects?class=Memory&limit=50&sort=_creationTimeUnix
Authorization: Bearer WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
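The same call from Python; a small sketch using httpx (the bearer token above is a sample key; drop the header if anonymous access is enabled):
import httpx

params = {"class": "Memory", "limit": 50, "sort": "_creationTimeUnix"}
headers = {"Authorization": "Bearer WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih"}
print(httpx.get("http://127.0.0.1:8080/v1/objects", params=params, headers=headers).json())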
HTTP API
http://127.0.0.1:8080/v1: base URL
/graphql: POST, execute a GraphQL query
/objects: GET, query objects
/objects/Hot/<id>: DELETE, delete an object from class Hot by its ID
/schema: GET, list all classes
/schema/<class name>: DELETE, delete a class
Quick server-side operations
If authentication is enabled:
curl -X <METHOD> \
-H "Authorization: Bearer <API key>" \
"http://127.0.0.1:8080/v1/schema/<collection name>"
If authentication is disabled, quickly delete a collection:
curl -X DELETE "http://127.0.0.1:8080/v1/schema/<collection name>"
List collections:
curl -s "http://127.0.0.1:8080/v1/schema"
Run the container:
docker run -d --name weaviate -p 8080:8080 -p 50051:50051 -v weaviate-data:/var/lib/weaviate -e AUTHENTICATION_APIKEY_ENABLED=false -e AUTHENTICATION_METHOD=none --restart unless-stopped semitechnologies/weaviate:1.33.0-dev-5180ccb.amd64
Python SDK
uv pip install langchain-openai langchain-weaviate langchain-community langchain-text-splitters
Complete synchronous example:
from langchain_weaviate import WeaviateVectorStore
import weaviate
from langchain_community.document_loaders import TextLoader
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
import uuid  # for generating unique document IDs
from datetime import datetime  # for timestamps
from weaviate.classes.query import Filter
from weaviate.classes.aggregate import GroupByAggregate  # group-by aggregation, used for dedup
# Connect to a local Weaviate instance
weaviate_client = weaviate.connect_to_local(
    host="127.0.0.1",
    port=8080,
    grpc_port=50051,
    skip_init_checks=True,  # recommended: skip the startup checks (e.g. the PyPI version check)
)
# Configure the embedding model
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-small",
    api_key="sk-proj-your-key",  # your API key
    base_url="https://api.openai.com/v1",
)
# Get a multi-tenant collection instance (created if it does not exist)
store = WeaviateVectorStore(
    client=weaviate_client,  # the Weaviate connection created above
    index_name="Test",  # collection name
    text_key="text",  # name of the text field
    embedding=embeddings,  # embedding model; can also be passed to individual methods
    use_multi_tenancy=True,  # multi-tenancy must be enabled when the collection is created
)
# Get a regular (single-tenant) collection instance (created if it does not exist)
store2 = WeaviateVectorStore(
    client=weaviate_client,  # the Weaviate connection created above
    index_name="Test2",  # collection name
    text_key="text",  # name of the text field
    embedding=embeddings,  # embedding model; can also be passed to individual methods
)
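# Tenants can also be managed explicitly with the native v4 client; whether
# langchain_weaviate creates missing tenants on write is version-dependent, so
# this hypothetical helper (an assumption, not used elsewhere) makes it explicit:
from weaviate.classes.tenants import Tenant

def ensure_tenant(collection_name: str, tenant: str):
    coll = weaviate_client.collections.get(collection_name)
    # tenants.get() returns a dict keyed by tenant name; create the tenant if missing
    if tenant not in coll.tenants.get():
        coll.tenants.create(Tenant(name=tenant))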
# Delete a collection
def delete_collection(collection_name:str) -> bool:
weaviate_client.collections.delete(collection_name)
return True
# Split a document into chunks
def split_documents(path: str):
    loader = TextLoader(path)  # create the loader
    documents = loader.load()  # load the document
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=2048,  # size of each chunk
        chunk_overlap=128,  # overlap between chunks
        # separators=["\n\n", "\n", "。", "!", "?", ",","!","?",".",",",";",";", " ", ""],  # recursive fallback: paragraph / line / sentence / word / none
    )  # create the splitter
    docs = text_splitter.split_documents(documents)  # split the document
    # Add metadata to every chunk
    doc_id: str = str(uuid.uuid4())  # unique document ID
    file_name = path.split("/")[-1]  # file name
    created_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")  # shared creation time
    for i, doc in enumerate(docs):
        # Make sure metadata exists
        if doc.metadata is None:
            doc.metadata = {}
        # Add custom fields
        doc.metadata["chunk_index"] = i  # chunk index: position of the chunk within the document
        doc.metadata["doc_id"] = doc_id  # unique document ID, links all chunks of the same document
        doc.metadata["file_name"] = file_name  # original file name, for source tracking
        doc.metadata["created_at"] = created_time  # creation time, format: YYYY-MM-DD HH:MM:SS
    return docs
# Write one tenant's document data (write the split chunks into the collection)
def add_documents(docs, tenant):
    ids = store.add_documents(
        documents=docs,  # document chunks
        # embedding=embeddings,  # already set on the WeaviateVectorStore
        tenant=tenant,  # tenant name
    )
    return ids
# Delete one document (by doc_id) for a tenant, using the native Weaviate client
# (langchain_weaviate would first query all UUIDs and then delete them, which seems inefficient)
def delete_documents(doc_id: str, tenant: str, class_name: str):
    coll = weaviate_client.collections.get(class_name).with_tenant(tenant)  # get the tenant's collection
    # Build the delete filter
    where = Filter.by_property("doc_id").equal(doc_id)
    # Execute the delete
    res = coll.data.delete_many(where=where)
    return res
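# Usage sketch (an assumption about the v4 return type): delete_many returns a
# summary object reporting how many objects were matched/deleted/failed, e.g.:
# res = delete_documents("42ed41c3-7c77-4d08-866e-dc0708f7c582", "User_1", "Test")
# print(res.matches, res.successful, res.failed)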
# Delete a single record
def delete_one_info(id: str, tenant: str):
    # Delete via langchain_weaviate
    store.delete(ids=[id], tenant=tenant)
# List all of a tenant's documents, deduplicated by doc_id, with file name and
# creation time filled in, using the native Weaviate client
def list_tenant_files(class_name: str, tenant: str):
    # Inner helper: normalize to uuid.UUID
    def _to_uuid(v):
        """Normalize doc_id to uuid.UUID (input may be str or uuid.UUID)."""
        if isinstance(v, uuid.UUID):
            return v  # already a UUID
        if isinstance(v, str):
            return uuid.UUID(v)  # str -> UUID; raises ValueError on malformed input
        # Some Weaviate return types wrap UUIDs; go through str first
        return uuid.UUID(str(v))
    # List the tenant's documents (deduplicated by doc_id, not individual records), native client
    def list_files_groupby(class_name: str, tenant: str):
        coll = weaviate_client.collections.use(class_name).with_tenant(tenant)  # bind the tenant
        resp = coll.aggregate.over_all(
            group_by=GroupByAggregate(prop="doc_id")  # group by doc_id
        )
        rows = []
        for g in resp.groups:  # each group is one file
            did = _to_uuid(g.grouped_by.value)  # normalize the group value to uuid.UUID
            rows.append({
                "doc_id": did,  # keep as UUID to avoid type mismatches later
                "chunks": g.total_count,  # number of chunks for this file
            })
        return rows
    def _or_filters(filters):
        cur = filters[0]
        for f in filters[1:]:
            cur = cur | f  # v4 client: filters combine with the | (OR) operator
        return cur
    # Fill in the document-list data
    def enrich_with_filename(class_name: str, tenant: str, rows: list[dict], batch_size: int = 64):
        """
        Batch-fill file_name/created_at for [{'doc_id' (UUID), 'chunks'}, ...]; doc_id is converted to str on return.
        """
        coll = weaviate_client.collections.use(class_name).with_tenant(tenant)
        out = {r["doc_id"]: {**r} for r in rows}  # key: uuid.UUID
        doc_ids = list(out.keys())
        for i in range(0, len(doc_ids), batch_size):
            part = doc_ids[i:i+batch_size]  # UUIDs in this batch
            where = _or_filters([Filter.by_property("doc_id").equal(d) for d in part])  # equality filter on the UUIDs directly
            # Inflate the limit so hits concentrated on a few doc_ids don't crowd out the rest
            res = coll.query.fetch_objects(
                filters=where,
                limit=max(1, len(part) * 3),  # slightly inflated; anything still missing is fetched individually below
                return_properties=["doc_id", "file_name", "created_at"],
            )
# print(f"批量补齐返回: {res}") # 如需调试可打开 # 中文注释
for o in (res.objects or []):
p = o.properties or {}
did = _to_uuid(p.get("doc_id")) # 统一为 UUID 再索引 # 中文注释
if did in out and "file_name" not in out[did]:
out[did]["file_name"] = p.get("file_name")
out[did]["created_at"] = p.get("created_at")
# 本批仍然缺失的 doc_id,兜底逐个拉 1 条代表记录 # 中文注释
missing = [d for d in part if "file_name" not in out[d]]
for did in missing:
flt = Filter.by_property("doc_id").equal(did) # did 是 UUID # 中文注释
r1 = coll.query.fetch_objects(
filters=flt, limit=1,
return_properties=["doc_id", "file_name", "created_at"],
)
if r1.objects:
p = r1.objects[0].properties or {}
out[did]["file_name"] = p.get("file_name")
out[did]["created_at"] = p.get("created_at")
# 返回时把 UUID 友好化为字符串(如果你更想保留 UUID,就去掉 str(...)) # 中文注释
        return [
            {
                "doc_id": str(k),  # UUID -> str for display/serialization
                "chunks": v["chunks"],
                "file_name": v.get("file_name"),
                "created_at": v.get("created_at"),
            }
            for k, v in out.items()
        ]
    rows = list_files_groupby(class_name, tenant)  # list the tenant's documents, native client
    res = enrich_with_filename(class_name, tenant, rows)  # fill in the document-list data
    return res
# Count how many records a tenant has in total, using the native Weaviate client
def list_tenant_num(class_name: str, tenant: str):
    coll = weaviate_client.collections.use(class_name).with_tenant(tenant)  # bind the tenant
    ag = coll.aggregate.over_all()
    # Field names differ between versions; stay compatible
    total = getattr(ag, "total_count", None) or getattr(ag, "totalCount", None)
    if isinstance(total, int):
        return total
    return 0
# Given a doc_id, return its chunks
def list_chunks_by_doc_id(class_name: str,
                          tenant: str,
                          doc_id,
                          page_size: int = 1000) -> list[dict]:
    """
    List all chunks for a given doc_id under a tenant, fetched page by page.
    Returns a list of property dicts (the fields requested in return_properties).
    """
    coll = weaviate_client.collections.use(class_name).with_tenant(tenant)
    flt = Filter.by_property("doc_id").equal(doc_id)
    items, offset = [], 0
    while True:
        # Note: Weaviate's cursor API (after=...) cannot be combined with filters,
        # so paginate with limit/offset instead
        res = coll.query.fetch_objects(
            filters=flt,
            limit=page_size,
            offset=offset,
            return_properties=["text", "chunk_index"],  # fields to return
        )
        batch = res.objects or []
        for o in batch:
            props = o.properties or {}
            # Return the object UUID too, for later single-record deletes/lookups
            props["id"] = str(o.uuid)
            items.append(props)
        if len(batch) < page_size:
            break
        offset += page_size
    # # Sorting by chunk_index is cheaper on the client than on the server:
    # if items and ("chunk_index" in items[0]):
    #     items.sort(key=lambda x: (x.get("chunk_index") is None, x.get("chunk_index")))
    return items
# Query a tenant's data (vector search)
def similarity_search(query, k, tenant):
    results = store.similarity_search(
        query=query,  # query text
        k=k,  # max number of results
        tenant=tenant,  # tenant name
    )
    return results
# Search with a score threshold
def similarity_search_threshold(query: str, k: int, tenant: str, score_threshold: float = 0.6):
    """
    Prefer the VectorStore's built-in thresholded search where available.
    Returns: List[Document]
    """
    # Let the VectorStore filter; scores are normally in [0, 1], higher is more relevant
    docs_and_scores = store.similarity_search_with_relevance_scores(
        query=query,
        k=k,
        score_threshold=score_threshold,
        tenant=tenant,
    )
    # Some implementations return [(Document, score)]; unpack if so
    return [doc for doc, _score in docs_and_scores] if docs_and_scores and isinstance(docs_and_scores[0], tuple) else docs_and_scores
# Pure keyword search without embeddings; requires the native weaviate client
# (langchain_weaviate would automatically use the embeddings)
def keyword_query(coll: str, tenant: str, keyword: str) -> list[dict]:
    # Get the tenant-scoped collection
    coll = weaviate_client.collections.get(coll).with_tenant(tenant)
    # Pure BM25 keyword search (no vectors, no vectorizer needed)
    res = coll.query.bm25(
        query=keyword,  # keyword query text
        limit=5,  # max number of results
    )
    # Assemble the results into a custom JSON-friendly format
    datas = []
    for r in res.objects:
        properties = r.properties
        data = {
            "id": str(r.uuid),  # Weaviate object identifier (UUID string)
            "text": properties.get("text"),  # main chunk text
            "source": properties.get("source"),  # original file path (added automatically by TextLoader)
            "chunk_index": properties.get("chunk_index"),  # chunk index: position of the chunk within the document
            "doc_id": properties.get("doc_id"),  # unique document ID, links all chunks of the same document
            "file_name": properties.get("file_name"),  # original file name, for source tracking
            "created_at": properties.get("created_at"),  # creation time, format: YYYY-MM-DD HH:MM:SS
        }
        datas.append(data)
    print(datas)
    return datas
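# Hybrid (BM25 + vector) search sketch: since this setup brings its own vectors
# (no server-side vectorizer), the query vector must be passed in explicitly.
# hybrid_query is a hypothetical helper, not used elsewhere in this file:
def hybrid_query(coll_name: str, tenant: str, text: str, alpha: float = 0.5) -> list[dict]:
    coll = weaviate_client.collections.get(coll_name).with_tenant(tenant)
    res = coll.query.hybrid(
        query=text,                           # keyword part (BM25)
        vector=embeddings.embed_query(text),  # vector part, computed client-side
        alpha=alpha,                          # 0 = pure BM25, 1 = pure vector search
        limit=5,
    )
    return [{"id": str(o.uuid), **(o.properties or {})} for o in res.objects]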
if __name__ == "__main__":
    # docs = split_documents("test.txt")
    # print(f"chunks after splitting: {docs}")
    # print("metadata of the first chunk:")
    # print(docs[0].metadata)
    # add_documents(docs, "User_1")
    # delete_documents("42ed41c3-7c77-4d08-866e-dc0708f7c582", "User_1", "Test")
    # delete_one_info("6ccef603-a4c7-4630-bfc1-d6dd4b144f1a", "User_1")
    # print(keyword_query("Test", "User_1", "自然语言处理"))  # pure keyword search without embeddings, native client
    # print(similarity_search("自编码器适用任务", 2, "User_1"))  # tenant-scoped vector search
    # print(list_tenant_files("Test", "User_1"))  # list the tenant's documents
    # print(list_tenant_num("Test", "User_1"))  # count the tenant's records
    # print(list_chunks_by_doc_id("Test", "User_1", "017a949e-be1f-4bba-8d07-12ca569adb4c"))  # list all chunks for a doc_id
    # print(similarity_search_threshold("自编码器适用任务", 2, "User_1", 0.6))  # thresholded search
    # When the FastAPI lifecycle ends, the Weaviate connection should be closed
    weaviate_client.close()
Asynchronous example
import weaviate
from langchain_openai import OpenAIEmbeddings
from langchain_weaviate import WeaviateVectorStore
from weaviate.classes.query import Filter
from weaviate.classes.aggregate import GroupByAggregate
import env
import uuid
from langchain.docstore.document import Document
import os
from typing import Any
# Create the async client
async_client = weaviate.use_async_with_local(
    host=env.WEAVIATE_HOST,
    port=env.WEAVIATE_PORT,
    grpc_port=env.WEAVIATE_GRPC_PORT,
    skip_init_checks=True,  # recommended: skip the startup checks to avoid timeouts
)
# Create the sync client (for langchain_weaviate)
sync_client = weaviate.connect_to_local(
    host=env.WEAVIATE_HOST,
    port=env.WEAVIATE_PORT,
    grpc_port=env.WEAVIATE_GRPC_PORT,
    skip_init_checks=True,  # recommended: skip the startup checks to avoid timeouts
)
# Configure the embedding model
embeddings = OpenAIEmbeddings(
    model=env.EMBEDDING_MODEL,
    api_key=env.OPENAI_API_KEY,
    base_url=env.OPENAI_BASE_URL,
)
# Get a multi-tenant collection instance (created if it does not exist), using the sync client
store = WeaviateVectorStore(
    client=sync_client,  # sync client; LangChain does not support the async client
    index_name=env.COLLECTION_NAME,  # collection name
    text_key="text",  # name of the text field
    embedding=embeddings,  # embedding model; can also be passed to individual methods
    use_multi_tenancy=True,  # multi-tenancy must be enabled when the collection is created
)
# Write one tenant's document data (write the split chunks into the collection)
async def add_documents(docs, tenant):
    ids = await store.aadd_documents(
        documents=docs,  # document chunks
        tenant=tenant,  # tenant name
    )
    return ids
# Delete a single record
async def delete_one_info(id: str, tenant: str):
    # Delete via langchain_weaviate
    await store.adelete(ids=[id], tenant=tenant)
# Delete one document (by doc_id) for a tenant, using the native Weaviate client
# (langchain_weaviate would first query all UUIDs and then delete them, which seems inefficient)
async def delete_documents(doc_id: str, tenant: str, class_name: str):
    coll = async_client.collections.get(class_name).with_tenant(tenant)  # get the tenant's collection
    # Build the filter; first read one record to get the source path
    where = Filter.by_property("doc_id").equal(doc_id)
    rep = await coll.query.fetch_objects(
        filters=where,
        limit=1,
        return_properties=["source"],
    )
    source = None
    if getattr(rep, "objects", None):
        try:
            source = rep.objects[0].properties.get("source")
        except Exception:
            source = None
    # Execute the delete first
    del_res = await coll.data.delete_many(where=where)
    # Then try to remove the local source file (best effort)
    if source and os.path.exists(source):
        try:
            os.remove(source)
        except Exception:
            pass
    return del_res
# List all of a tenant's documents, deduplicated by doc_id, with file name and
# creation time filled in, using the native Weaviate client
async def list_tenant_files(class_name: str, tenant: str):
    # Inner helper: normalize to uuid.UUID
    def _to_uuid(v):
        """Normalize doc_id to uuid.UUID (input may be str or uuid.UUID)."""
        if isinstance(v, uuid.UUID):
            return v  # already a UUID
        if isinstance(v, str):
            return uuid.UUID(v)  # str -> UUID; raises ValueError on malformed input
        # Some Weaviate return types wrap UUIDs; go through str first
        return uuid.UUID(str(v))
    # List the tenant's documents (deduplicated by doc_id, not individual records), native client
    async def list_files_groupby(class_name: str, tenant: str):
        coll = async_client.collections.use(class_name).with_tenant(tenant)  # bind the tenant
        resp = await coll.aggregate.over_all(
            group_by=GroupByAggregate(prop="doc_id")  # group by doc_id
        )
        rows = []
        for g in resp.groups:  # each group is one file
            did = _to_uuid(g.grouped_by.value)  # normalize the group value to uuid.UUID
            rows.append({
                "doc_id": did,  # keep as UUID to avoid type mismatches later
                "chunks": g.total_count,  # number of chunks for this file
            })
        return rows
    def _or_filters(filters):
        # Building a filter is pure computation, so this needs no async
        cur = filters[0]
        for f in filters[1:]:
            cur = cur | f  # v4 client: filters combine with the | (OR) operator
        return cur
    # Fill in the document-list data
    async def enrich_with_filename(class_name: str, tenant: str, rows: list[dict], batch_size: int = 64):
        """
        Batch-fill file_name/created_at for [{'doc_id' (UUID), 'chunks'}, ...]; doc_id is converted to str on return.
        """
        coll = async_client.collections.use(class_name).with_tenant(tenant)
        out = {r["doc_id"]: {**r} for r in rows}  # key: uuid.UUID
        doc_ids = list(out.keys())
        for i in range(0, len(doc_ids), batch_size):
            part = doc_ids[i:i+batch_size]  # UUIDs in this batch
            where = _or_filters([Filter.by_property("doc_id").equal(d) for d in part])  # equality filter on the UUIDs directly
            # Inflate the limit so hits concentrated on a few doc_ids don't crowd out the rest
            res = await coll.query.fetch_objects(
                filters=where,
                limit=max(1, len(part) * 3),  # slightly inflated; anything still missing is fetched individually below
                return_properties=["doc_id", "file_name", "created_at"],
            )
# print(f"批量补齐返回: {res}") # 如需调试可打开 # 中文注释
for o in (res.objects or []):
p = o.properties or {}
did = _to_uuid(p.get("doc_id")) # 统一为 UUID 再索引 # 中文注释
if did in out and "file_name" not in out[did]:
out[did]["file_name"] = p.get("file_name")
out[did]["created_at"] = p.get("created_at")
# 本批仍然缺失的 doc_id,兜底逐个拉 1 条代表记录 # 中文注释
missing = [d for d in part if "file_name" not in out[d]]
for did in missing:
flt = Filter.by_property("doc_id").equal(did) # did 是 UUID # 中文注释
r1 = await coll.query.fetch_objects(
filters=flt, limit=1,
return_properties=["doc_id", "file_name", "created_at"],
)
if r1.objects:
p = r1.objects[0].properties or {}
out[did]["file_name"] = p.get("file_name")
out[did]["created_at"] = p.get("created_at")
# 返回时把 UUID 友好化为字符串(如果你更想保留 UUID,就去掉 str(...)) # 中文注释
        return [
            {
                "doc_id": str(k),  # UUID -> str for display/serialization
                "chunks": v["chunks"],
                "file_name": v.get("file_name"),
                "created_at": v.get("created_at"),
            }
            for k, v in out.items()
        ]
    rows = await list_files_groupby(class_name, tenant)  # list the tenant's documents, native client
    res = await enrich_with_filename(class_name, tenant, rows)  # fill in the document-list data
    return res
# Count how many records a tenant has in total, using the native Weaviate client
async def list_tenant_num(class_name: str, tenant: str):
    coll = async_client.collections.use(class_name).with_tenant(tenant)  # bind the tenant
    ag = await coll.aggregate.over_all()
    # Field names differ between versions; stay compatible
    total = getattr(ag, "total_count", None) or getattr(ag, "totalCount", None)
    if isinstance(total, int):
        return total
    return 0
# Given a doc_id, return its chunks
async def list_doc_chunks(class_name: str,
                          tenant: str,
                          doc_id,
                          page_size: int = 1000) -> list[dict]:
    """
    List all chunks for a given doc_id under a tenant, fetched page by page.
    Returns a list of property dicts (the fields requested in return_properties).
    """
    coll = async_client.collections.use(class_name).with_tenant(tenant)
    flt = Filter.by_property("doc_id").equal(doc_id)
    items, offset = [], 0
    while True:
        # Note: Weaviate's cursor API (after=...) cannot be combined with filters,
        # so paginate with limit/offset instead
        res = await coll.query.fetch_objects(
            filters=flt,
            limit=page_size,
            offset=offset,
            return_properties=["text", "chunk_index"],  # fields to return
        )
        batch = res.objects or []
        for o in batch:
            props = o.properties or {}
            # Return the object UUID too, for later single-record deletes/lookups
            props["id"] = str(o.uuid)
            items.append(props)
        if len(batch) < page_size:
            break
        offset += page_size
    # # Sorting by chunk_index is cheaper on the client than on the server:
    # if items and ("chunk_index" in items[0]):
    #     items.sort(key=lambda x: (x.get("chunk_index") is None, x.get("chunk_index")))
    return items
# Query a tenant's data (vector search, async)
async def similarity_search(query: str, k: int, tenant: str):
    """Async vector search via the WeaviateVectorStore async interface.
    Args:
        query: query text
        k: number of results
        tenant: tenant name
    Returns:
        list[Document]: matching documents
    """
    return await store.asimilarity_search(
        query=query,
        k=k,
        tenant=tenant,
    )
# Search with a score threshold
async def similarity_search_threshold(query: str, k: int, tenant: str, score_threshold: float = 0.6):
    """Async thresholded search.
    Returns:
        list[Document]
    """
    docs_and_scores = await store.asimilarity_search_with_relevance_scores(
        query=query,
        k=k,
        score_threshold=score_threshold,
        tenant=tenant,
    )
    # Some implementations return [(Document, score)]; unpack if so
    return [doc for doc, _score in docs_and_scores] if docs_and_scores and isinstance(docs_and_scores[0], tuple) else docs_and_scores
# List a document's chunks
async def get_document_chunks(tenant: str, doc_id: str) -> list[Document]:
    """Fetch the chunks of the given document ID.
    Args:
        tenant: tenant name
        doc_id: document ID
    """
    collection = async_client.collections.get(env.COLLECTION_NAME).with_tenant(tenant)
    # Filter by doc_id within the tenant; fetch the text plus common metadata fields
    result = await collection.query.fetch_objects(
        limit=10000,
        filters=(
            Filter.by_property("doc_id").equal(doc_id)
        ),
        return_properties=[
            # Only request properties that exist in the current schema, to avoid gRPC errors
            "text",
            "source",
            "file_name",
            "doc_id",
            "created_at",
            "chunk_index",
        ],
    )
    print(f"chunks for the given doc_id: {result}")
    objs = getattr(result, "objects", []) or []
    docs: list[Document] = []
    for obj in objs:
        props = getattr(obj, "properties", {}) or {}
        text = props.get("text") or ""
        # Everything except the text goes into the Document metadata
        metadata = {k: v for k, v in props.items() if k != "text"}
        docs.append(Document(page_content=text, metadata=metadata))
    # Sort ascending by chunk_index; values that cannot be parsed as int count as 0
    def _to_int(val: Any) -> int:
        try:
            return int(val)  # tolerate numeric strings
        except Exception:
            return 0
    docs.sort(key=lambda d: _to_int(d.metadata.get("chunk_index", 0)))
    return docs
# Unified shutdown (async + sync)
async def vector_close():
    """Close the Weaviate client connections to avoid leaking resources."""
    try:
        await async_client.close()
    except Exception:
        pass
    try:
        sync_client.close()
    except Exception:
        pass
if __name__ == "__main__":
    import asyncio
    async def _demo():
        # Connect the async client first, to avoid a ClosedClient error
        await async_client.connect()
        try:
            res = await get_document_chunks(tenant="user1", doc_id="bbd30c61-5f96-4b12-bf51-ce03b0c2c5c1")
            print(res)
        finally:
            # Make sure the connections are closed to avoid leaking resources
            await vector_close()
    asyncio.run(_demo())
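To tie the async client to a FastAPI application's lifetime (as the closing comment in the sync example suggests), one option is a lifespan handler; a minimal sketch, where the FastAPI app itself is an assumption on top of the module above:
from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    await async_client.connect()  # open the async connection at startup
    yield
    await vector_close()          # close both clients at shutdown

app = FastAPI(lifespan=lifespan)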