LangGraph (Part 7): Code Generation

1. LangGraph for Code Generation

1.1 Motivation

Code generation and analysis is one of the most important applications of large language models (LLMs), as shown by the ubiquity of products like GitHub Copilot and the popularity of projects like GPT-engineer. Recent work on AlphaCodium has demonstrated that code generation can be improved by using a flow paradigm rather than a simple prompt:answer pattern: the answer is iteratively improved by (1) testing it and (2) reflecting on the results of those tests.

Using LangGraph, we build two different approaches to code generation:

  1. Code generation via prompting and context stuffing.
  2. A code generation flow that checks and runs the code and, if errors occur, passes them back for self-correction.

Let's see how much the second approach actually helps. (The result: from 55% to 81%.)

1.2 Evaluation Set

The LangChain team did the following:

  1. Selected the subset of the LangChain docs covering the LangChain Expression Language (LCEL), which is bounded in scope (roughly 60k tokens).
  2. Mined 30 days of chat-langchain for LCEL-related questions (code): filtering for questions that mention LCEL yielded roughly 500 questions out of more than 60k chats.
  3. Clustered those ~500 questions and used an LLM (GPT-4, 128k) to summarize the clusters, giving one representative question per cluster.
  4. Manually reviewed each question and wrote a ground-truth answer for it (the evaluation set contains 20 questions, data).
  5. Added the dataset to LangSmith.
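
As an illustration of the clustering-and-summarization step, here is a minimal, hedged sketch (not the LangChain team's actual pipeline; the embedding model, cluster count, and summary prompt are assumptions for illustration):

# Hypothetical sketch: cluster mined questions and summarize each cluster with an LLM.
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from sklearn.cluster import KMeans

questions = ["How do I stream output from an LCEL chain?", "..."]  # ~500 mined questions in practice

vectors = OpenAIEmbeddings().embed_documents(questions)
n_clusters = min(20, len(questions))
labels = KMeans(n_clusters=n_clusters, random_state=0).fit_predict(vectors)

llm = ChatOpenAI(model="gpt-4-turbo", temperature=0)
for cluster_id in range(n_clusters):
    members = [q for q, label in zip(questions, labels) if label == cluster_id]
    summary = llm.invoke("Summarize these related LCEL questions into one representative question:\n" + "\n".join(members))
    print(cluster_id, summary.content)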

2. Implementing Code Generation with Reflection Using LangGraph

We use LangGraph to implement a code generator with reflection. The code generation flow contains the following components:

  1. Context stuffing: inspired by the recent trend toward long-context LLMs, the roughly 60k tokens of LCEL documentation are stuffed into GPT-4's 128k-token context window. Questions about LCEL are passed to this context-stuffed LCEL chain to generate an initial answer.

  2. Output parsing: OpenAI tool calling parses the output into a Pydantic object with three parts: (1) a prefix describing the problem, (2) the import block, and (3) the code.

  3. Import check: the imports are executed first, because hallucinations (errors or inaccuracies in the generated code) easily creep into import statements.

  4. Code executability check: if the import check passes, the code itself is then executed. The generation prompt instructs the LLM not to use pseudo-code or undefined variables, which should yield executable code.

  5. Error feedback and retry: if either check fails, the stack trace (the detailed record of the runtime error) is passed back to the generation node for reflection, together with the previous answer. The process is allowed 3 retries (just the default) and can be extended as needed; a minimal sketch of capturing such a stack trace follows the summary below.

This flow is a self-correcting code generation system: it tests the generated code and iterates on the test results, which improves the performance of code generation.
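
The node implementation later in this post feeds only the exception message back to the model. As a hedged sketch, the full stack trace could be captured instead like this (run_with_feedback is a hypothetical helper, not part of the graph below):

import traceback

def run_with_feedback(imports: str, code: str):
    # Hypothetical helper: execute a candidate solution and, on failure,
    # return the full stack trace as a user message for the next generation attempt.
    try:
        exec(imports + "\n" + code)
        return None
    except Exception:
        return ("user", f"Your solution raised an error:\n{traceback.format_exc()}\nPlease fix it and try again.")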

DOCS

from bs4 import BeautifulSoup as Soup
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader

# LCEL docs
url = "https://python.langchain.com/docs/expression_language/"
loader = RecursiveUrlLoader(
    url=url, max_depth=20, extractor=lambda x: Soup(x, "html.parser").text
)
docs = loader.load()

# Sort the list based on the URLs and get the text
d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
d_reversed = list(reversed(d_sorted))
concatenated_content = "\n\n\n --- \n\n\n".join(
    [doc.page_content for doc in d_reversed]
)
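
As a quick, hedged sanity check on the roughly 60k-token figure cited above (tiktoken and the cl100k_base encoding are assumptions, not part of the original post):

import tiktoken

# Approximate token count of the concatenated LCEL docs.
enc = tiktoken.get_encoding("cl100k_base")
print(f"LCEL docs: ~{len(enc.encode(concatenated_content))} tokens")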

LLMs

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field

### OpenAI

# Code generation prompt
code_gen_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", """You are a coding assistant with expertise in LCEL, LangChain expression language. \n
Here is a full set of LCEL documentation: \n ------- \n {context} \n ------- \n Answer the user
question based on the above provided documentation. Ensure any code you provide can be executed \n
with all required imports and variables defined. Structure your answer with a description of the code solution. \n
Then list the imports. And finally list the functioning code block. Here is the user question:"""),
        ("placeholder", "{messages}"),
    ]
)

# Data model
class code(BaseModel):
    """Code output"""

    prefix: str = Field(description="Description of the problem and approach")
    imports: str = Field(description="Code block import statements")
    code: str = Field(description="Code block not including import statements")
    description = "Schema for code solutions to questions about LCEL."

expt_llm = "gpt-4-turbo"
llm = ChatOpenAI(temperature=0, model=expt_llm)
code_gen_chain = code_gen_prompt | llm.with_structured_output(code)
question = "How do I build a RAG chain in LCEL?"
# solution = code_gen_chain.invoke({"context": concatenated_content, "messages": [("user", question)]})
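
To try the chain, the commented-out invoke above can be run and the structured fields inspected, for example:

# Invoke the context-stuffed chain and print the three structured fields.
solution = code_gen_chain.invoke({"context": concatenated_content, "messages": [("user", question)]})
print(solution.prefix)
print(solution.imports)
print(solution.code)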

Graph State

from typing import Dict, TypedDict, List

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        error : Binary flag for control flow to indicate whether test error was tripped
        messages : With user question, error messages, reasoning
        generation : Code solution
        iterations : Number of tries
    """

    error: str
    messages: List
    generation: str
    iterations: int

Nodes

from operator import itemgetter
from langchain.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough

### Parameter

# Max tries
max_iterations = 3
# Reflect
# flag = 'reflect'
flag = 'do not reflect'

### Nodes

def generate(state: GraphState):
    """
    Generate a code solution

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation
    """

    print("---GENERATING CODE SOLUTION---")

    # State
    messages = state["messages"]
    iterations = state["iterations"]
    error = state["error"]

    # We have been routed back to generation with an error
    if error == "yes":
        messages += [("user", "Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:")]

    # Solution
    code_solution = code_gen_chain.invoke({"context": concatenated_content, "messages": messages})
    messages += [("assistant", f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}")]

    # Increment
    iterations = iterations + 1
    return {"generation": code_solution, "messages": messages, "iterations": iterations}

def code_check(state: GraphState):
    """
    Check code

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, error
    """

    print("---CHECKING CODE---")

    # State
    messages = state["messages"]
    code_solution = state["generation"]
    iterations = state["iterations"]

    # Get solution components
    prefix = code_solution.prefix
    imports = code_solution.imports
    code = code_solution.code

    # Check imports
    try:
        exec(imports)
    except Exception as e:
        print("---CODE IMPORT CHECK: FAILED---")
        error_message = [("user", f"Your solution failed the import test: {e}")]
        messages += error_message
        return {"generation": code_solution, "messages": messages, "iterations": iterations, "error": "yes"}

    # Check execution
    try:
        exec(imports + "\n" + code)
    except Exception as e:
        print("---CODE BLOCK CHECK: FAILED---")
        error_message = [("user", f"Your solution failed the code execution test: {e}")]
        messages += error_message
        return {"generation": code_solution, "messages": messages, "iterations": iterations, "error": "yes"}

    # No errors
    print("---NO CODE TEST FAILURES---")
    return {"generation": code_solution, "messages": messages, "iterations": iterations, "error": "no"}

def reflect(state: GraphState):
    """
    Reflect on errors

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation
    """

    print("---REFLECTING ON ERRORS---")

    # State
    messages = state["messages"]
    iterations = state["iterations"]
    code_solution = state["generation"]

    # Prompt reflection: append the reflection instruction so the model actually sees it
    reflection_message = [("user", """You tried to solve this problem and failed a unit test. Reflect on this failure
    given the provided documentation. Write a few key suggestions based on the
    documentation to avoid making this mistake again.""")]
    messages += reflection_message

    # Add reflection
    reflections = code_gen_chain.invoke({"context": concatenated_content, "messages": messages})
    messages += [("assistant", f"Here are reflections on the error: {reflections}")]
    return {"generation": code_solution, "messages": messages, "iterations": iterations}

### Edges

def decide_to_finish(state: GraphState):
    """
    Determines whether to finish.

    Args:
        state (dict): The current graph state

    Returns:
        str: Next node to call
    """
    error = state["error"]
    iterations = state["iterations"]

    if error == "no" or iterations == max_iterations:
        print("---DECISION: FINISH---")
        return "end"
    else:
        print("---DECISION: RE-TRY SOLUTION---")
        if flag == 'reflect':
            return "reflect"
        else:
            return "generate"

Graph

from langgraph.graph import END, StateGraph

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("generate", generate) # generation solution
workflow.add_node("check_code", code_check) # check code
workflow.add_node("reflect", reflect) # reflect

# Build graph
workflow.set_entry_point("generate")
workflow.add_edge("generate", "check_code")
workflow.add_conditional_edges(
    "check_code",
    decide_to_finish,
    {
        "end": END,
        "reflect": "reflect",
        "generate": "generate",
    },
)
workflow.add_edge("reflect", "generate")
app = workflow.compile()

Graph Preview

from IPython.display import Image

Image(app.get_graph().draw_png())
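
draw_png requires pygraphviz. If it is not installed, recent langgraph versions can render the graph via Mermaid instead (assuming your installed version supports it):

# Alternative preview that does not require pygraphviz.
Image(app.get_graph().draw_mermaid_png())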

Run

question = "How can I directly pass a string to a runnable and use it to construct the input needed for my prompt?"
app.invoke({"messages":[("user",question)],"iterations":0})

Output

---GENERATING CODE SOLUTION---
---CHECKING CODE---
messages=[HumanMessage(content='Process this input: Hello, this is a direct string input!')]
---NO CODE TEST FAILURES---
---DECISION: FINISH---
{'error': 'no',
'messages': [('user',
'How can I directly pass a string to a runnable and use it to construct the input needed for my prompt?'),
('assistant',
'Directly passing a string to a runnable for constructing input for a prompt in LCEL \n Imports: from langchain_core.runnables import RunnableLambda\nfrom langchain_core.prompts import ChatPromptTemplate \n Code: # Define the prompt template\nprompt = ChatPromptTemplate.from_template("Process this input: {input}")\n\n# Define a runnable that constructs the input for the prompt\nconstruct_input = RunnableLambda(lambda x: {\'input\': x})\n\n# Chain the input constructor with the prompt\nchain = construct_input | prompt\n\n# Example usage\nresult = chain.invoke(\'Hello, this is a direct string input!\')\nprint(result)')],
'generation': code(prefix='Directly passing a string to a runnable for constructing input for a prompt in LCEL', imports='from langchain_core.runnables import RunnableLambda\nfrom langchain_core.prompts import ChatPromptTemplate', code='# Define the prompt template\nprompt = ChatPromptTemplate.from_template("Process this input: {input}")\n\n# Define a runnable that constructs the input for the prompt\nconstruct_input = RunnableLambda(lambda x: {\'input\': x})\n\n# Chain the input constructor with the prompt\nchain = construct_input | prompt\n\n# Example usage\nresult = chain.invoke(\'Hello, this is a direct string input!\')\nprint(result)', description='Schema for code solutions to questions about LCEL.'),
'iterations': 1}

LangSmith trace of the flow:

3. Evaluation with LangSmith

As a baseline, we first implemented context stuffing without LangGraph. This is just the first node of the graph, with no checks or feedback: the roughly 60k tokens of LCEL documentation are stuffed into GPT-4's 128k-token context window, and LCEL questions are passed to this context-stuffed chain to generate answers.

We implemented custom evaluators in LangSmith for two checks:

  1. Import check
  2. Code execution

We ran context stuffing four times against the 20-question evaluation set. With context stuffing, roughly 98% of import tests passed, but only about 55% of code execution tests passed (N=79 successful trials). LangSmith makes it easy to inspect the failures. In one example, the model did not realize that the input to its RunnableLambda would be a dict and wrongly treated it as a string, producing: AttributeError: 'dict' object has no attribute 'upper'
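
The post does not show the exact generated code, but the failure mode can be sketched roughly like this (an illustrative reconstruction, not the actual trace):

from langchain_core.runnables import RunnableLambda

# Failing pattern: the runnable receives a dict, not a string.
# RunnableLambda(lambda x: x.upper()).invoke({"input": "hello"})  # AttributeError: 'dict' object has no attribute 'upper'

# Corrected pattern: index into the dict before operating on the string.
chain = RunnableLambda(lambda x: x["input"].upper())
print(chain.invoke({"input": "hello"}))  # HELLO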

We then tested context stuffing + LangGraph in order to:

  1. run the import and code execution checks, catching errors of this kind;
  2. reflect on any errors when regenerating the answer.

On the same evaluation set, 100% of import tests passed and roughly 81% of code execution tests passed (N=78 trials).

Revisiting the failure case above shows why: the full trace reveals that the second attempt at this question hit the same error. That error was fed into the reflection step, and both the previous solution and the resulting error were included in the prompt used to generate the final (correct) answer:

You previously tried to solve this problem. 
...
--- Most recent run error ---
Execution error: 'dict' object has no attribute 'upper'
...
Please re-try to answer this.
...

The final generated code handles the input dict to the RunnableLambda correctly, avoiding the error observed in the context-stuffing baseline. Overall, adding this simple reflection step and retrying with LangGraph improves code execution performance substantially, by roughly 47% relative (55% -> 81%).

Download DataSet

import langsmith
client = langsmith.Client()

# Clone the dataset to your tenant to use it
public_dataset = ("https://smith.langchain.com/public/326674a6-62bd-462d-88ae-eea49d503f9d/d")
client.clone_public_dataset(public_dataset)

Custom evals

from langsmith.schemas import Example, Run

def check_import(run: Run, example: Example) -> dict:
    imports = run.outputs.get("imports")
    try:
        exec(imports)
        return {"key": "import_check", "score": 1}
    except:
        return {"key": "import_check", "score": 0}

def check_execution(run: Run, example: Example) -> dict:
    imports = run.outputs.get("imports")
    code = run.outputs.get("code")
    try:
        exec(imports + "\n" + code)
        return {"key": "code_execution_check", "score": 1}
    except:
        return {"key": "code_execution_check", "score": 0}

Compare LangGraph to Context Stuffing

def predict_base_case(example: dict):
    """Context stuffing"""
    # code_gen_chain already returns a structured `code` object, so its fields can be read directly
    solution = code_gen_chain.invoke({"context": concatenated_content, "messages": [("user", example["question"])]})
    return {"imports": solution.imports, "code": solution.code}

def predict_langgraph(example: dict):
    """LangGraph"""
    graph = app.invoke({"messages": [("user", example["question"])], "iterations": 0})
    solution = graph["generation"]
    return {"imports": solution.imports, "code": solution.code}

Eval

from langsmith.evaluation import evaluate

# Evaluator
code_evaluator = [check_import, check_execution]

# Dataset
dataset_name = "test-LCEL-code-gen"

Base Case

# Run base case
experiment_results_ = evaluate(
    predict_base_case,
    data=dataset_name,
    evaluators=code_evaluator,
    experiment_prefix=f"test-without-langgraph-{expt_llm}",
    max_concurrency=2,
    metadata={
        "llm": expt_llm,
    },
)

LangGraph

# Run with langgraph
experiment_results = evaluate(
    predict_langgraph,
    data=dataset_name,
    evaluators=code_evaluator,
    experiment_prefix=f"test-with-langgraph-{expt_llm}-{flag}",
    max_concurrency=2,
    metadata={
        "llm": expt_llm,
        "feedback": flag,
    },
)

Evaluation results in LangSmith: https://smith.langchain.com/public/78a3d858-c811-4e46-91cb-0f10ef56260b/d

Official Resources

More detailed explanation of the code example:

