【大模型】LangChain自定义 LLM 类
为便捷构建 LLM 应用,需要基于本地部署的LLM模型,自定义一个 LLM 类,将LLM接入到 LangChain 框架中。完成自定义 LLM 类之后,可以以完全一致的方式调用 LangChain 的接口,而无需考虑底层模型调用的不一致。
前言
为了便捷构建 LLM 应用,需要基于本地部署的LLM模型,自定义一个 LLM 类,将LLM接入到 LangChain 框架中。完成自定义 LLM 类之后,可以以完全一致的方式调用 LangChain 的接口,而无需考虑底层模型调用的不一致。
1.本地部署
1.1 本地部署代码准备
基于本地部署自定义 LLM 类并不复杂,我们只需从 LangChain.llms.base.LLM 类继承一个子类,并重写构造函数与 _call 函数即可。
from langchain.llms.base import LLM
from typing import Any, List, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, LlamaTokenizerFast
import torch
class Qwen2_LLM(LLM):
# 基于本地 Qwen2 自定义 LLM 类
tokenizer: AutoTokenizer = None
model: AutoModelForCausalLM = None
def __init__(self, mode_name_or_path :str):
super().__init__()
print("正在从本地加载模型...")
self.tokenizer = AutoTokenizer.from_pretrained(mode_name_or_path, use_fast=False)
self.model = AutoModelForCausalLM.from_pretrained(mode_name_or_path, torch_dtype=torch.bfloat16, device_map="auto")
self.model.generation_config = GenerationConfig.from_pretrained(mode_name_or_path)
print("完成本地模型的加载")
def _call(self, prompt : str, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any):
messages = [{"role": "user", "content": prompt }]
input_ids = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = self.tokenizer([input_ids], return_tensors="pt").to('cuda')
generated_ids = self.model.generate(model_inputs.input_ids,max_new_tokens=512)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
return response
@property
def _llm_type(self) -> str:
return "Qwen2_LLM"
在上述类定义中,我们分别重写了构造函数和 _call 函数:对于构造函数,我们在对象实例化的一开始加载本地部署的 Qwen2 模型,从而避免每一次调用都需要重新加载模型带来的时间过长;_call 函数是 LLM 类的核心函数,LangChain 会调用该函数来调用 LLM,在该函数中,我们调用已实例化模型的 chat 方法,从而实现对模型的调用并返回调用结果。
在整体项目中,我们将上述代码封装为 LLM.py,后续将直接从该文件中引入自定义的 LLM 类。
1.2 调用
像使用任何其他的langchain大模型功能一样使用即可。
from LLM import Qwen2_LLM
llm = Qwen2_LLM(mode_name_or_path = "/root/autodl-tmp/qwen/Qwen1.5-7B-Chat")
llm("你是谁")
2. 远程调用代码准备
远程调用是指在服务器上部署了大模型,并通过Flask、FastAPI的方式给出了访问接口。我们需要将远程调用的代码和注册模型的代码融合。
2.1 原始访问代码
from openai import OpenAI
import requests
import json
def simple_chat(messages):
#print(messages[0])
response = client.chat.completions.create(
model="glm-4",
messages=messages,
stream=False,
max_tokens=256,
temperature=0.4,
presence_penalty=1.2,
top_p=0.8,
)
if response:
return json.loads(response.model_dump_json())['choices'][0]['message']['content']
else:
return ""
def glm4_function(prompt,sentence):
result = ""
prompt_dict = [
{
"role": "system",
"content":prompt,
},
{
"role": "user",
"content": sentence
},
]
try:
result = simple_chat(prompt_dict)
except Exception as e:
print(e)
return result
base_url = "http://..../v1/chatglm4"
client = OpenAI(api_key="EMPTY", base_url=base_url)
system_prompt = "you are a helpful assistant "
sentence = '你好,请介绍一下你自己'
temp_response = glm4_function(system_prompt,sentence)
print(temp_response)
2.2 融合后代码
from langchain.llms.base import LLM
from typing import Any, List, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
import requests
import json
from pydantic import BaseModel, Field
class Qwen2_LLM(LLM, BaseModel):
api_url: str = Field(..., description="远程API的URL")
@property
def _llm_type(self) -> str:
return "Qwen2_LLM_API"
def __init__(self, **data: Any):
super().__init__(**data)
print(f"API URL设置为: {self.api_url}")
def _call(self, prompt: str, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any) -> str:
"""
调用API并返回模型响应
"""
# 设置系统消息和用户输入
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt}
]
# 请求API
try:
response = requests.post(
f"{self.api_url}/chatglm4",
headers={"Authorization": "Bearer YOUR_API_KEY"}, # 替换为实际的API密钥
json={
"model": "glm-4",
"messages": messages,
"stream": False,
"max_tokens": 512,
"temperature": 0.4,
"presence_penalty": 1.2,
"top_p": 0.8
}
)
response.raise_for_status() # 检查HTTP错误
result = response.json()
# 提取生成的文本
return result['choices'][0]['message']['content']
except requests.exceptions.HTTPError as http_err:
print(f"HTTP错误发生: {http_err}")
except Exception as e:
print(f"请求API时发生错误: {e}")
return "" # 在出现错误时返回空字符串或其他适当的默认值
# 使用示例
if __name__ == "__main__":
llm = Qwen2_LLM(api_url="http://.../v1")
response = llm("你好,请介绍一下你自己")
print(response)
Reference:
1.self-llm/models/Qwen2 at master · datawhalechina/self-llm · GitHub
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐
所有评论(0)