LangGraph (Part 4): Self-Reflective RAG

1. Self-Reflective RAG

1.1 Motivation

Because most large language models (LLMs) are only periodically trained on large corpora of public data, they lack up-to-date information and/or private data that was unavailable for training. Retrieval-augmented generation (RAG) is a core paradigm in LLM application development that addresses this by connecting the LLM to external data sources (see our video series and blog posts). The basic RAG flow embeds the user query, retrieves documents relevant to the query, and passes those documents to the LLM so it can generate an answer grounded in the retrieved context.

The basic RAG flow
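
As a concrete reference, here is a minimal, self-contained sketch of that basic flow. The document content, model name, and prompt are illustrative and are not the code used later in this post; an OpenAI API key is assumed to be configured.

from langchain.schema import Document
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# Index a toy document (embedding happens inside the vector store)
vectorstore = Chroma.from_documents(
    documents=[Document(page_content="LangGraph lets you build stateful, multi-step LLM applications.")],
    embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()

# Retrieve documents relevant to the user query
question = "What is LangGraph for?"
docs = retriever.get_relevant_documents(question)
context = "\n\n".join(d.page_content for d in docs)

# Generate an answer grounded in the retrieved context
prompt = ChatPromptTemplate.from_template(
    "Answer the question using only the context below.\n\nContext:\n{context}\n\nQuestion: {question}"
)
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
answer = (prompt | llm | StrOutputParser()).invoke({"context": context, "question": question})
print(answer)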

1.2 Self-Reflective RAG

Implementing RAG usually requires reasoning about questions such as: when to retrieve (based on the question and the composition of the index), when to rewrite the question to improve retrieval, and when to discard irrelevant retrieved documents and retry retrieval. The idea behind Self-Reflective RAG is to use the LLM itself to correct poor-quality retrieval and generation.

In the basic RAG flow, the LLM decides what to generate based on the retrieved documents. Some RAG flows add routing, where the LLM chooses among different retrievers based on the question. Self-Reflective RAG, however, usually needs some kind of feedback to re-generate the question or re-retrieve documents. State machines are a third kind of cognitive architecture, one that supports cycles: a state machine simply lets us define a set of steps (for example retrieve, grade documents, rewrite the query) and configure the transitions between them. For instance, if the retrieved documents are irrelevant, the query can be rewritten and new documents retrieved.

1.3 Self-Reflective RAG with LangGraph

LangGraph is a convenient way to implement an LLM state machine. It gives us a lot of flexibility in how different RAG flows are laid out, and it supports more general flows with specific decision points (such as document grading) and cycles (such as retrying retrieval).

Below, two RAG variants, CRAG and Self-RAG, are used to illustrate this flexibility of LangGraph.

2. Corrective RAG (CRAG)

Corrective RAG (CRAG) introduces a few interesting ideas:

  • Use a lightweight retrieval evaluator to assess the overall quality of the documents retrieved for a query, returning a confidence score for each document.
  • If vector-store retrieval is judged ambiguous or irrelevant to the user query, perform web-based document retrieval to supplement the context.
  • Perform knowledge refinement on the retrieved documents by partitioning them into "knowledge strips", grading each strip, and filtering out the irrelevant ones (a hedged sketch of this step follows this list).

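Since knowledge refinement is the step simplified away in the walkthrough below, here is a rough, hedged sketch of what it could look like. This is an illustration under assumptions rather than the paper's exact recipe: it reuses the binary relevance grader (retrieval_grader, defined later in this post) to grade each strip, and the strip size is an arbitrary choice.

from langchain.text_splitter import RecursiveCharacterTextSplitter

# Split documents into small "knowledge strips" (strip size is illustrative)
strip_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=0)

def refine_documents(question, documents):
    """Keep only the strips that the grader marks as relevant to the question."""
    refined = []
    for doc in documents:
        for strip in strip_splitter.split_text(doc.page_content):
            score = retrieval_grader.invoke({"question": question, "document": strip})
            if score.binary_score == "yes":
                refined.append(strip)
    return refined
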
The CRAG workflow below is laid out with LangGraph, with a few simplifications:

  • The knowledge refinement stage is skipped in the LangGraph implementation, since it is not needed to understand the workflow layout.
  • If the retrieved documents are irrelevant, retrieval is supplemented with a web search via the Tavily Search API.
  • Query re-writing is used to optimize the query for web search.
  • For binary decisions, Pydantic is used to model the output so that every LLM call returns consistent binary logic.

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings

urls = [
"https://lilianweng.github.io/posts/2023-06-23-agent/",
"https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
"https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)

# Add to vectorDB
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()

LLMs

Retrieval Grader

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field

# Data model
class GradeDocuments(BaseModel):
"""Binary score for relevance check on retrieved documents."""

binary_score: str = Field(description="Documents are relevant to the question, 'yes' or 'no'")

# LLM with function call
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)

# Prompt
system = """You are a grader assessing relevance of a retrieved document to a user question. \n
If the document contains keyword(s) or semantic meaning related to the question, grade it as relevant. \n
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
    ]
)

retrieval_grader = grade_prompt | structured_llm_grader
question = "agent memory"
docs = retriever.get_relevant_documents(question)
doc_txt = docs[1].page_content
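
As a quick sanity check, the grader can be invoked directly on one of the retrieved chunks (the actual score depends on what was retrieved):

score = retrieval_grader.invoke({"question": question, "document": doc_txt})
print(score.binary_score)  # 'yes' or 'no'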

Generate

### Generate

from langchain import hub
from langchain_core.output_parsers import StrOutputParser

# Prompt
prompt = hub.pull("rlm/rag-prompt")

# LLM
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)

# Post-processing
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)

# Chain
rag_chain = prompt | llm | StrOutputParser()
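
A usage sketch (reusing question and docs from the grader step above): format_docs joins the retrieved chunks into a single context string before the prompt is filled in.

generation = rag_chain.invoke({"context": format_docs(docs), "question": question})
print(generation)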

Question Re-writer

### Question Re-writer

# LLM
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)

# Prompt
system = """You a question re-writer that converts an input question to a better version that is optimized \n
for web search. Look at the input and try to reason about the underlying sematic intent / meaning."""
re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Here is the initial question: \n\n {question} \n Formulate an improved question."),
    ]
)

question_rewriter = re_write_prompt | llm | StrOutputParser()

### Search

from langchain_community.tools.tavily_search import TavilySearchResults
web_search_tool = TavilySearchResults(k=3)
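
A usage sketch of the two pieces together (the example question is illustrative): the re-writer produces a search-friendly question, which is then sent to the Tavily tool.

better_question = question_rewriter.invoke({"question": "agent memory"})
results = web_search_tool.invoke({"query": better_question})  # list of dicts with "url" and "content"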

Graph State

from typing_extensions import TypedDict
from typing import List

class GraphState(TypedDict):
"""
Represents the state of our graph.

Attributes:
question: question
generation: LLM generation
web_search: whether to add search
documents: list of documents
"""
question : str
generation : str
web_search : str
documents : List[str]

Nodes

from langchain.schema import Document

def retrieve(state):
"""
Retrieve documents

Args:
state (dict): The current graph state

Returns:
state (dict): New key added to state, documents, that contains retrieved documents
"""
print("---RETRIEVE---")
question = state["question"]

# Retrieval
documents = retriever.get_relevant_documents(question)
return {"documents": documents, "question": question}

def generate(state):
"""
Generate answer

Args:
state (dict): The current graph state

Returns:
state (dict): New key added to state, generation, that contains LLM generation
"""
print("---GENERATE---")
question = state["question"]
documents = state["documents"]

# RAG generation
generation = rag_chain.invoke({"context": documents, "question": question})
return {"documents": documents, "question": question, "generation": generation}

def grade_documents(state):
"""
Determines whether the retrieved documents are relevant to the question.

Args:
state (dict): The current graph state

Returns:
state (dict): Updates documents key with only filtered relevant documents
"""

print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
question = state["question"]
documents = state["documents"]

# Score each doc
filtered_docs = []
web_search = "No"
for d in documents:
score = retrieval_grader.invoke({"question": question, "document": d.page_content})
grade = score.binary_score
if grade == "yes":
print("---GRADE: DOCUMENT RELEVANT---")
filtered_docs.append(d)
else:
print("---GRADE: DOCUMENT NOT RELEVANT---")
web_search = "Yes"
continue
return {"documents": filtered_docs, "question": question, "web_search": web_search}

def transform_query(state):
"""
Transform the query to produce a better question.

Args:
state (dict): The current graph state

Returns:
state (dict): Updates question key with a re-phrased question
"""

print("---TRANSFORM QUERY---")
question = state["question"]
documents = state["documents"]

# Re-write question
better_question = question_rewriter.invoke({"question": question})
return {"documents": documents, "question": better_question}

def web_search(state):
"""
Web search based on the re-phrased question.

Args:
state (dict): The current graph state

Returns:
state (dict): Updates documents key with appended web results
"""

print("---WEB SEARCH---")
question = state["question"]
documents = state["documents"]

# Web search
docs = web_search_tool.invoke({"query": question})
web_results = "\n".join([d["content"] for d in docs])
web_results = Document(page_content=web_results)
documents.append(web_results)

return {"documents": documents, "question": question}

### Edges

def decide_to_generate(state):
"""
Determines whether to generate an answer, or re-generate a question.

Args:
state (dict): The current graph state

Returns:
str: Binary decision for next node to call
"""

print("---ASSESS GRADED DOCUMENTS---")
question = state["question"]
web_search = state["web_search"]
filtered_documents = state["documents"]

if web_search == "Yes":
# All documents have been filtered check_relevance
# We will re-generate a new query
print("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---")
return "transform_query"
else:
# We have relevant documents, so generate answer
print("---DECISION: GENERATE---")
return "generate"

Build Graph

from langgraph.graph import END, StateGraph

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", retrieve) # retrieve
workflow.add_node("grade_documents", grade_documents) # grade documents
workflow.add_node("generate", generate) # generatae
workflow.add_node("transform_query", transform_query) # transform_query
workflow.add_node("web_search_node", web_search) # web search

# Build graph
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
"grade_documents",
decide_to_generate,
{
"transform_query": "transform_query",
"generate": "generate",
},
)
workflow.add_edge("transform_query", "web_search_node")
workflow.add_edge("web_search_node", "generate")
workflow.add_edge("generate", END)

# Compile
app = workflow.compile()
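
As with the Self-RAG graph later in this post, the compiled graph can be rendered to inspect the layout (this assumes pygraphviz and IPython are installed):

from IPython.display import Image

Image(app.get_graph().draw_png())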

Asking a question covered by the index

from pprint import pprint

# Run
inputs = {"question": "What are the types of agent memory?"}
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        pprint(f"Node '{key}':")
        # Optional: print full state at each node
        # pprint.pprint(value["keys"], indent=2, width=80, depth=None)
    pprint("\n---\n")

# Final generation
pprint(value["generation"])

Output

---RETRIEVE---
"Node 'retrieve':"
'\n---\n'
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---ASSESS GRADED DOCUMENTS---
---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---
"Node 'grade_documents':"
'\n---\n'
---TRANSFORM QUERY---
"Node 'transform_query':"
'\n---\n'
---WEB SEARCH---
"Node 'web_search_node':"
'\n---\n'
---GENERATE---
"Node 'generate':"
'\n---\n'
('The different categories of memory in agents are short-term memory and '
'long-term memory. Short-term memory is used for in-context learning, while '
'long-term memory allows agents to retain and recall information over '
'extended periods by leveraging external storage. Agents can also utilize '
'tool use, such as calling external APIs for additional information.')

LangSmith trace:

Asking a question not covered by the index

from pprint import pprint

# Run
inputs = {"question": "How does the AlphaCodium paper work?"}
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        pprint(f"Node '{key}':")
        # Optional: print full state at each node
        # pprint.pprint(value["keys"], indent=2, width=80, depth=None)
    pprint("\n---\n")

# Final generation
pprint(value["generation"])

Output

---RETRIEVE---
"Node 'retrieve':"
'\n---\n'
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---ASSESS GRADED DOCUMENTS---
---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---
"Node 'grade_documents':"
'\n---\n'
---TRANSFORM QUERY---
"Node 'transform_query':"
'\n---\n'
---WEB SEARCH---
"Node 'web_search_node':"
'\n---\n'
---GENERATE---
"Node 'generate':"
'\n---\n'
('The AlphaCodium paper functions by reasoning about problems in natural '
'language during a pre-processing phase and generating, running, and fixing '
'code solutions in an iterative phase. It improves the performance of Large '
'Language Models on code generation tasks by using a test-based, multi-stage, '
'code-oriented iterative flow. The key mechanisms involve reasoning about '
'problems in natural language, generating code solutions, and running and '
'fixing the solutions against public and AI-generated tests.')

LangSmith trace:

3. Self-RAG

Self-RAG is a related approach with several other interesting RAG ideas (paper). The framework trains an LLM to generate self-reflection tokens that govern the various stages of the RAG process. The basic flow is as follows:

  • The Retrieve token decides whether to retrieve D chunks, using either x (question) or x (question), y (generation) as input. Output: yes, no, continue
  • The ISREL token decides whether D is relevant to x by looping over the input (x (question), d (chunk)) for d in D. Output: relevant, irrelevant
  • The ISSUP token decides whether the LLM generation from each chunk in D is relevant to that chunk. Input: x, d, y for d in D. It verifies that all of the verification-worthy statements in y (generation) are supported by d (chunk). Output: fully supported, partially supported, no support
  • The ISUSE token decides whether the generation from each chunk in D is a useful response to x. Input: x, y for d in D. Output: {5, 4, 3, 2, 1}
The four kinds of tokens used by Self-RAG

The information flow can be understood from the diagram below:

Schematic of the flow used in Self-RAG

The LangGraph implementation makes a few simplifications and adjustments (customize and extend as needed). First, each retrieved document is graded. If any document is relevant, we proceed to generation. If all documents are irrelevant, the query is transformed into an improved question and retrieval is retried. (The web-search idea from CRAG above could also be added as a supplementary node on this path.)

The paper generates once per chunk and grades each generation twice. The LangGraph implementation instead generates once from all of the relevant documents, and that single generation is then graded against the documents (for example, to guard against hallucination) and against the question. This reduces the number of LLM calls, improves latency, and allows more context to be folded into the generation. If finer control is needed, it is easy to switch back to generating and grading each chunk separately.

Self-RAG implemented with LangGraph:

Retriever

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings

urls = [
"https://lilianweng.github.io/posts/2023-06-23-agent/",
"https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
"https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)

# Add to vectorDB
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()

LLMs

Retrieval Grader

### Retrieval Grader 

from typing import Literal

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI


# Data model
class GradeDocuments(BaseModel):
"""Binary score for relevance check on retrieved documents."""

binary_score: str = Field(description="Documents are relevant to the question, 'yes' or 'no'")

# LLM with function call
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)

# Prompt
system = """You are a grader assessing relevance of a retrieved document to a user question. \n
It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
    ]
)

retrieval_grader = grade_prompt | structured_llm_grader

Generate

from langchain import hub
from langchain_core.output_parsers import StrOutputParser

# Prompt
prompt = hub.pull("rlm/rag-prompt")

# LLM
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)

# Post-processing
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)

# Chain
rag_chain = prompt | llm | StrOutputParser()

Hallucination Grader (generation vs. documents)

# Data model
class GradeHallucinations(BaseModel):
"""Binary score for hallucination present in generation answer."""

binary_score: str = Field(description="Answer is grounded in the facts, 'yes' or 'no'")

# LLM with function call
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeHallucinations)

# Prompt
system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
hallucination_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
    ]
)

hallucination_grader = hallucination_prompt | structured_llm_grader
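
A usage sketch of the hallucination grader, wired end to end with the retriever and RAG chain defined above (the question and the resulting score are illustrative):

question = "agent memory"
docs = retriever.get_relevant_documents(question)
generation = rag_chain.invoke({"context": format_docs(docs), "question": question})

score = hallucination_grader.invoke({"documents": format_docs(docs), "generation": generation})
print(score.binary_score)  # 'yes' means the generation is grounded in the retrieved facts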

Answer Grader (generation vs. question)

### Answer Grader 

# Data model
class GradeAnswer(BaseModel):
"""Binary score to assess answer addresses question."""

binary_score: str = Field(description="Answer addresses the question, 'yes' or 'no'")

# LLM with function call
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeAnswer)

# Prompt
system = """You are a grader assessing whether an answer addresses / resolves a question \n
Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question."""
answer_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
    ]
)

answer_grader = answer_prompt | structured_llm_grader
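
A matching sketch for the answer grader, reusing question and generation from the hallucination-grader sketch above:

score = answer_grader.invoke({"question": question, "generation": generation})
print(score.binary_score)  # 'yes' means the generation addresses the question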

Question Re-writer

### Question Re-writer

# LLM
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)

# Prompt
system = """You a question re-writer that converts an input question to a better version that is optimized \n
for vectorstore retrieval. Look at the input and try to reason about the underlying sematic intent / meaning."""
re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Here is the initial question: \n\n {question} \n Formulate an improved question."),
    ]
)

question_rewriter = re_write_prompt | llm | StrOutputParser()

Graph State

from typing_extensions import TypedDict
from typing import List

class GraphState(TypedDict):
"""
Represents the state of our graph.

Attributes:
question: question
generation: LLM generation
documents: list of documents
"""
question : str
generation : str
documents : List[str]

Nodes

### Nodes

from langchain.schema import Document

def retrieve(state):
"""
Retrieve documents

Args:
state (dict): The current graph state

Returns:
state (dict): New key added to state, documents, that contains retrieved documents
"""
print("---RETRIEVE---")
question = state["question"]

# Retrieval
documents = retriever.get_relevant_documents(question)
return {"documents": documents, "question": question}

def generate(state):
"""
Generate answer

Args:
state (dict): The current graph state

Returns:
state (dict): New key added to state, generation, that contains LLM generation
"""
print("---GENERATE---")
question = state["question"]
documents = state["documents"]

# RAG generation
generation = rag_chain.invoke({"context": documents, "question": question})
return {"documents": documents, "question": question, "generation": generation}

def grade_documents(state):
"""
Determines whether the retrieved documents are relevant to the question.

Args:
state (dict): The current graph state

Returns:
state (dict): Updates documents key with only filtered relevant documents
"""

print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
question = state["question"]
documents = state["documents"]

# Score each doc
filtered_docs = []
for d in documents:
score = retrieval_grader.invoke({"question": question, "document": d.page_content})
grade = score.binary_score
if grade == "yes":
print("---GRADE: DOCUMENT RELEVANT---")
filtered_docs.append(d)
else:
print("---GRADE: DOCUMENT NOT RELEVANT---")
continue
return {"documents": filtered_docs, "question": question}

def transform_query(state):
"""
Transform the query to produce a better question.

Args:
state (dict): The current graph state

Returns:
state (dict): Updates question key with a re-phrased question
"""

print("---TRANSFORM QUERY---")
question = state["question"]
documents = state["documents"]

# Re-write question
better_question = question_rewriter.invoke({"question": question})
return {"documents": documents, "question": better_question}

### Edges

def decide_to_generate(state):
"""
Determines whether to generate an answer, or re-generate a question.

Args:
state (dict): The current graph state

Returns:
str: Binary decision for next node to call
"""

print("---ASSESS GRADED DOCUMENTS---")
question = state["question"]
filtered_documents = state["documents"]

if not filtered_documents:
# All documents have been filtered check_relevance
# We will re-generate a new query
print("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---")
return "transform_query"
else:
# We have relevant documents, so generate answer
print("---DECISION: GENERATE---")
return "generate"

def grade_generation_v_documents_and_question(state):
"""
Determines whether the generation is grounded in the document and answers question.

Args:
state (dict): The current graph state

Returns:
str: Decision for next node to call
"""

print("---CHECK HALLUCINATIONS---")
question = state["question"]
documents = state["documents"]
generation = state["generation"]

score = hallucination_grader.invoke({"documents": documents, "generation": generation})
grade = score.binary_score

# Check hallucination
if grade == "yes":
print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
# Check question-answering
print("---GRADE GENERATION vs QUESTION---")
score = answer_grader.invoke({"question": question,"generation": generation})
grade = score.binary_score
if grade == "yes":
print("---DECISION: GENERATION ADDRESSES QUESTION---")
return "useful"
else:
print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
return "not useful"
else:
pprint("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
return "not supported"

Build Graph

from langgraph.graph import END, StateGraph

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", retrieve) # retrieve
workflow.add_node("grade_documents", grade_documents) # grade documents
workflow.add_node("generate", generate) # generatae
workflow.add_node("transform_query", transform_query) # transform_query

# Build graph
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
"grade_documents",
decide_to_generate,
{
"transform_query": "transform_query",
"generate": "generate",
},
)
workflow.add_edge("transform_query", "retrieve")
workflow.add_conditional_edges(
"generate",
grade_generation_v_documents_and_question,
{
"not supported": "generate",
"useful": END,
"not useful": "transform_query",
},
)

# Compile
app = workflow.compile()

Graph Preview

from IPython.display import Image

Image(app.get_graph().draw_png())
Graph

Execution

Example 1:

from pprint import pprint

# Run
inputs = {"question": "Explain how the different types of agent memory work?"}
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        pprint(f"Node '{key}':")
        # Optional: print full state at each node
        # pprint.pprint(value["keys"], indent=2, width=80, depth=None)
    pprint("\n---\n")

# Final generation
pprint(value["generation"])

Output

---RETRIEVE---
"Node 'retrieve':"
'\n---\n'
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---ASSESS GRADED DOCUMENTS---
---DECISION: GENERATE---
"Node 'grade_documents':"
'\n---\n'
---GENERATE---
---CHECK HALLUCINATIONS---
---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---
---GRADE GENERATION vs QUESTION---
---DECISION: GENERATION ADDRESSES QUESTION---
"Node 'generate':"
'\n---\n'
('Short-term memory is used for in-context learning and allows the model to '
'learn new information. It has a limited capacity and lasts for a short '
'duration. Long-term memory, on the other hand, can store information for a '
'long time and has unlimited storage capacity. It includes '
'explicit/declarative memory for facts and events, and implicit/procedural '
'memory for skills and routines.')

LangSmith trace:

The trace above shows the document filtering step.

Example 2:

inputs = {"question": "Explain how chain of thought prompting works?"}
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        pprint(f"Node '{key}':")
        # Optional: print full state at each node
        # pprint.pprint(value["keys"], indent=2, width=80, depth=None)
    pprint("\n---\n")

# Final generation
pprint(value["generation"])

Output

---RETRIEVE---
"Node 'retrieve':"
'\n---\n'
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---ASSESS GRADED DOCUMENTS---
---DECISION: GENERATE---
"Node 'grade_documents':"
'\n---\n'
---GENERATE---
---CHECK HALLUCINATIONS---
'---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---'
"Node 'generate':"
'\n---\n'
---GENERATE---
---CHECK HALLUCINATIONS---
---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---
---GRADE GENERATION vs QUESTION---
---DECISION: GENERATION ADDRESSES QUESTION---
"Node 'generate':"
'\n---\n'
('Chain of thought prompting involves guiding the behavior of autoregressive '
'language models by providing prompts or demonstrations that contain '
'high-quality reasoning chains. This can be done through methods such as '
'self-asking, interleaving retrieval with chain-of-thought reasoning, and '
'complexity-based prompting for multi-step reasoning. These techniques aim to '
"improve the model's ability to generate coherent and logical responses "
'without updating its weights.')

LangSmith trace:

The trace above shows a generation that failed the grading check and was regenerated.

Official resources

More detailed explanations of the code examples:

