うさラボ

お勉強と備忘録

【Langchain】エージェント経由でNetmikoを使ってネットワーク機器からログを取得する

概要

LLM(大規模言語モデル)を利用したアプリ開発フレームワークであるLangchainを使ってネットワーク機器からログを取得してみます

www.langchain.com

環境

Pythonライブラリ

langchain==0.1.16
langchain-core==0.1.46
langchain-cohere==0.1.4
python-dotenv==1.0.1
netmiko==4.3.0
typing==3.7.4.3

やれたこと

CohereのCommand R+モデルをつかってエージェントを作成し自作したツールを動かして質問に答えさせることができました。

Command R +はトライアルAPIキーを使っているため現時点(2024/4/29)では無料です cohere.com

コード

早速完成したコードを載せます。
StreamlitでWebApp化しているためGUI上で質問の入力と回答の確認をできるようにしました。

import streamlit as st

from langchain.agents import AgentExecutor
from langchain_cohere import ChatCohere
# from langchain_cohere.cohere_agent import create_cohere_tools_agent
from langchain_cohere.react_multi_hop.agent import create_cohere_react_agent

from langchain_core.tools import BaseTool, tool
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnablePassthrough


from dotenv import load_dotenv
load_dotenv()

from netmiko import ConnectHandler
from typing import Type, List, Optional
from langchain.pydantic_v1 import BaseModel

# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-
# データ構造定義
class NetworkDeviceInfo(BaseModel):
    device_type: str
    host: str
    username: str
    password: str
    secret: Optional[str]

class Command(BaseModel):
    command: str

class CommandList(BaseModel):
    commands: List[Command]

class GetLogNetworkDeviceInput(BaseModel):
    device: NetworkDeviceInfo
    commands: CommandList

# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-
# NOTE: NetmikoでSSH接続してログを取得するツール
class GetLogNetworkDevice(BaseTool):
    name = "get_command_nw_devcie"
    description = "Access NW devices and execute commands with Netmiko"
    # args_schema: Type[BaseModel] = GetLogNetworkDeviceInput
    # return_direct: bool = True

    def _run(self, param: GetLogNetworkDeviceInput) -> List[str]:
        """Returns network device logs."""
        output = []
        device_data = param['device']
        command_list = param['commands']
        with ConnectHandler(**device_data) as net_connect:
            if 'secret' in device_data:
                if device_data['secret']:
                    net_connect.enable()
            for command in command_list:
                output.append(net_connect.send_command(command['command']))
        return output

    async def _arun(self, param: GetLogNetworkDeviceInput) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("Calculator does not support async")

# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-
# NOTE: プロンプトテンプレート
command_template = """
・質問を解決するために実施すべきshowコマンドを出力してください
・回答はJSON形式で出力してください
    ・複数コマンドの変数サンプル:
        {{"commands": [{{"command": show run"}}, {{"command": show ip route"}}]}}
    ・単一コマンドの変数サンプル:
        {{"commands": [{{"command": show ip route"}}]}}

質問: {question}
"""

device_info_template = """
・質問を解決するためにログ取得が必要な対象機器情報を出力します
・回答はJSON形式で出力してください
    ```json
    {{ 
        "device": {{
            "device_type": "cisco_ix",
            "host": "192.0.2.0",
            "username": "test_user",
            "password": "password",
            "secret": "enable_password",  
        }}
    }}
    ```
・対応する値が存在しな場合も、空で出力してください
質問: {question}
"""

template = """
前提:
・日本語で回答してください
・必要に応じて、ツールを使い
  下記機器に接続して、情報取得をした結果を利用して回答してください
    - 機器情報 ----------------------
    {device_info}
    - 取得コマンド ----------------------
    {command_list}
    ----------------------------------
・実行コマンドは質問に対して最適なものを1から複数個生成して実行してください
・情報取得を実施した場合、実行したコマンドを出力に含めてください
・問題が特定できた場合は、解決方法について出力してください
・問題が解決できな場合、必要なアクションについて出力してください

質問: {question}
"""
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-

try:
    st.set_page_config(layout="wide")
    st.title("Cohere Agent Custom Tool")

    # NOTE: ツール定義
    tools = [GetLogNetworkDevice()]

    # NOTE: プロンプト定義
    prompt = ChatPromptTemplate.from_template(template)
    device_info_template_prompt = ChatPromptTemplate.from_template(device_info_template)
    command_prompt = ChatPromptTemplate.from_template(command_template)

    # NOTE: アウトプットパーサー定義
    output = StrOutputParser()
    command_parser = JsonOutputParser(pydantic_object=CommandList)
    dev_info_parser = JsonOutputParser(pydantic_object=NetworkDeviceInfo)

    # NOTE: LLM定義
    command_r = ChatCohere(model="command-r-plus", temperature=0)
    command_gen = command_prompt | command_r | command_parser
    dev_info_gen = device_info_template_prompt | command_r | dev_info_parser
    
    # NOTE: エージェント定義
    agent = create_cohere_react_agent(llm=command_r, tools=tools, prompt=prompt)
    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

    chain = (
        {"device_info": dev_info_gen, "command_list": command_gen, "question": RunnablePassthrough()} 
        | agent_executor 
        )

    question = st.text_area("質問", value="ルートの統計情報教えて")
    if st.button("実行"):
        with st.spinner("生成中...."):
            result = chain.invoke({"question": question})
            st.write(result['output'])

except Exception as e:
    st.error(f"An error occurred: {e}")
    st.error("えらーっぽい")
"""

Streamlit起動

python -m streamlit run cohere_tool.py

質問実行

応答

エージェントがツールを実行している様子

コードの詳細

1. 各種ライブラリのインポート

初めに、必要なライブラリをインポートします
- langchainのライブラリ群 - 認証情報を扱うためのdotenv - SSHしてログをとるためのNetmiko

import streamlit as st

from langchain.agents import AgentExecutor
from langchain_cohere import ChatCohere
# from langchain_cohere.cohere_agent import create_cohere_tools_agent
from langchain_cohere.react_multi_hop.agent import create_cohere_react_agent

from langchain_core.tools import BaseTool, tool
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnablePassthrough


from dotenv import load_dotenv
load_dotenv()

from netmiko import ConnectHandler
from typing import Type, List, Optional
from langchain.pydantic_v1 import BaseModel

2.自作ツールで利用するデータ構造の定義

自作ツールを動かすために、LLMにどんな変数構造で情報を引き渡す必要があるかを伝える必要があるようです、そのため引数はどんな型であるかpydanticを使い定義します。

# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-
# データ構造定義
class NetworkDeviceInfo(BaseModel):
    device_type: str
    host: str
    username: str
    password: str
    secret: Optional[str]

class Command(BaseModel):
    command: str

class CommandList(BaseModel):
    commands: List[Command]

class GetLogNetworkDeviceInput(BaseModel):
    device: NetworkDeviceInfo
    commands: CommandList

3.自作ツール部分

Netmikoを使い、ログを取得する処理部分を定義します。
リストを作成しコマンドの応答を要素として追加して戻すシンプルなものにしました。
args_schemaを有効化するとうまくいかず。。(pydantic周りの理解が足りていない..)今回は無効化することでなんと動きました。。

# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-
# NOTE: NetmikoでSSH接続してログを取得するツール
class GetLogNetworkDevice(BaseTool):
    name = "get_command_nw_devcie"
    description = "Access NW devices and execute commands with Netmiko"
    # args_schema: Type[BaseModel] = GetLogNetworkDeviceInput
    # return_direct: bool = True

    def _run(self, param: GetLogNetworkDeviceInput) -> List[str]:
        """Returns network device logs."""
        output = []
        device_data = param['device']
        command_list = param['commands']
        with ConnectHandler(**device_data) as net_connect:
            if 'secret' in device_data:
                if device_data['secret']:
                    net_connect.enable()
            for command in command_list:
                output.append(net_connect.send_command(command['command']))
        return output

    async def _arun(self, param: GetLogNetworkDeviceInput) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("Calculator does not support async")

4.プロンプトの定義

LLMに質問をする際のプロンプトのテンプレートを定義しています。
今回は、以下3つの質問をすることでツールを動かす内容を決めるようなイメージにしています

  1. 実行コマンドの生成
  2. 認証情報の生成(今回はテキストのインプットから解析する形にした)
  3. ツール実行結果から質問の回答の生成
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-
# NOTE: プロンプトテンプレート
command_template = """
・質問を解決するために実施すべきshowコマンドを出力してください
・回答はJSON形式で出力してください
    ・複数コマンドの変数サンプル:
        {{"commands": [{{"command": show run"}}, {{"command": show ip route"}}]}}
    ・単一コマンドの変数サンプル:
        {{"commands": [{{"command": show ip route"}}]}}

質問: {question}
"""

device_info_template = """
・質問を解決するためにログ取得が必要な対象機器情報を出力します
・回答はJSON形式で出力してください
    ```json
    {{ 
        "device": {{
            "device_type": "cisco_ix",
            "host": "192.0.2.1",
            "username": "test_user",
            "password": "password",
            "secret": "enable_password",  
        }}
    }}
    ```
・対応する値が存在しな場合も、空で出力してください
質問: {question}
"""

template = """
前提:
・日本語で回答してください
・必要に応じて、ツールを使い
  下記機器に接続して、情報取得をした結果を利用して回答してください
    - 機器情報 ----------------------
    {device_info}
    - 取得コマンド ----------------------
    {command_list}
    ----------------------------------
・実行コマンドは質問に対して最適なものを1から複数個生成して実行してください
・情報取得を実施した場合、実行したコマンドを出力に含めてください
・問題が特定できた場合は、解決方法について出力してください
・問題が解決できな場合、必要なアクションについて出力してください

質問: {question}
"""

5.メイン処理部分

ツールやプロンプトなどを定義していきます
認証情報とコマンドは質問に対する回答をJSON形式にするためJsonOutputParserを利用しています。

chainで処理定義しており、以下のような順序で動作します
1. 認証情報,実行コマンドの生成
2. エージェント起動

エージェント内で利用するためのツールはtoolにリストで定義します。
エージェントはモデルごとに専用のものがあるようで、Cohereではcreate_cohere_react_agent を利用します(create_cohere_tools_agentといったものもあるようですが、うまく動かずでした。。)

try:
    st.set_page_config(layout="wide")
    st.title("Cohere Agent Custom Tool")

    # NOTE: ツール定義
    tools = [GetLogNetworkDevice()]

    # NOTE: プロンプト定義
    prompt = ChatPromptTemplate.from_template(template)
    device_info_template_prompt = ChatPromptTemplate.from_template(device_info_template)
    command_prompt = ChatPromptTemplate.from_template(command_template)

    # NOTE: アウトプットパーサー定義
    output = StrOutputParser()
    command_parser = JsonOutputParser(pydantic_object=CommandList)
    dev_info_parser = JsonOutputParser(pydantic_object=NetworkDeviceInfo)

    # NOTE: LLM定義
    command_r = ChatCohere(model="command-r-plus", temperature=0)
    command_gen = command_prompt | command_r | command_parser
    dev_info_gen = device_info_template_prompt | command_r | dev_info_parser
    
    # NOTE: エージェント定義
    agent = create_cohere_react_agent(llm=command_r, tools=tools, prompt=prompt)
    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

    chain = (
        {"device_info": dev_info_gen, "command_list": command_gen, "question": RunnablePassthrough()} 
        | agent_executor 
        )

    question = st.text_area("質問", value="ルートの統計情報教えて")
    if st.button("実行"):
        with st.spinner("生成中...."):
            result = chain.invoke({"question": question})
            st.write(result['output'])

except Exception as e:
    st.error(f"An error occurred: {e}")
    st.error("えらーっぽい")

認証情報

認証情報はdotenvを利用したためコードには登場しませんが、スクリプトと同じ階層に.envファイルを作成しAPI KEY情報などを定義します

.envファイル

COHERE_API_KEY = "<API KEY>"

感想

何とか自作ツールを動かすことができました、ツール実行のために定義した pydantic周りの理解が足りたいように感じました。。

いろいろなツールを作って、高度な運用補助ツールみたいなもの作りたいなーってモチベーションで頑張ります

NetBoxやServiceNowから機器情報を引っ張ってきてSSHして原因報告する。とかとか想像すると楽しいですね

ただ、エージェント難しい・・・ なんもわからん。。。

理解が誤ってる箇所などあれば(優しく)ご指摘いただけたら嬉しいです。