flash_attn安装

我们在安装flash_attn库的时候,如果直接是使用pip install,老是容易出错卡住,等多久都没有成功。

$ pip install flash-attn==2.7.4.post1
Looking in indexes: https://mirrors.ustc.edu.cn/pypi/web/simple
Collecting flash-attn==2.7.4.post1
  Using cached https://mirrors.ustc.edu.cn/pypi/packages/11/34/9bf60e736ed7bbe15055ac2dab48ec67d9dbd088d2b4ae318fd77190ab4e/flash_attn-2.7.4.post1.tar.gz (6.0 MB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... error
  error: subprocess-exited-with-error
  
  × Getting requirements to build wheel did not run successfully.
  │ exit code: 1
  ╰─> [20 lines of output]
      Traceback (most recent call last):
        File "/home/lewis/miniconda3/envs/infinitevl/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 389, in <module>
          main()
        File "/home/lewis/miniconda3/envs/infinitevl/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 373, in main
          json_out["return_val"] = hook(**hook_input["kwargs"])
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/lewis/miniconda3/envs/infinitevl/lib/python3.11/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 143, in get_requires_for_build_wheel
          return hook(config_settings)
                 ^^^^^^^^^^^^^^^^^^^^^
        File "/tmp/pip-build-env-ohl53xgr/overlay/lib/python3.11/site-packages/setuptools/build_meta.py", line 331, in get_requires_for_build_wheel
          return self._get_build_requires(config_settings, requirements=[])
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/tmp/pip-build-env-ohl53xgr/overlay/lib/python3.11/site-packages/setuptools/build_meta.py", line 301, in _get_build_requires
          self.run_setup()
        File "/tmp/pip-build-env-ohl53xgr/overlay/lib/python3.11/site-packages/setuptools/build_meta.py", line 512, in run_setup
          super().run_setup(setup_script=setup_script)
        File "/tmp/pip-build-env-ohl53xgr/overlay/lib/python3.11/site-packages/setuptools/build_meta.py", line 317, in run_setup
          exec(code, locals())
        File "<string>", line 22, in <module>
      ModuleNotFoundError: No module named 'torch'
      [end of output]

这里给出两种安装方式

第一种安装方式:

访问该网站:

https://github.com/Dao-AILab/flash-attention/releases/

找到对应torch、python、cuda版本的flash_attn进行下载,并上传到服务器

# 例如python3.10 torch2.4 cuda12
pip install flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

可以在环境中运行如下指令判断是选择TRUE还是FALSE。

python  -c "import torch; print(torch._C._GLIBCXX_USE_CXX11_ABI)"

第二种安装方式:

同样是要访问网站:

https://github.com/Dao-AILab/flash-attention/releases/

查看要安装的版本:一般是选择最新的版本:比如是2.7.4.post1。将下面脚本中的版本号修改一下:

flash_attn_version=“2.7.4.post1”

import platform
import sys
 
import torch
 
 
def get_cuda_version():
    if torch.cuda.is_available():
        cuda_version = torch.version.cuda
        return f"cu{cuda_version.replace('.', '')[:2]}"  # 例如:cu121
    return "cpu"
 
 
def get_torch_version():
    return f"torch{torch.__version__.split('+')[0]}"[:-2]  # 例如:torch2.2
 
 
def get_python_version():
    version = sys.version_info
    return f"cp{version.major}{version.minor}"  # 例如:cp310
 
 
def get_abi_flag():
    return "abiTRUE" if torch._C._GLIBCXX_USE_CXX11_ABI else "abiFALSE"
 
 
def get_platform():
    system = platform.system().lower()
    machine = platform.machine().lower()
    if system == "linux" and machine == "x86_64":
        return "linux_x86_64"
    elif system == "windows" and machine == "amd64":
        return "win_amd64"
    elif system == "darwin" and machine == "x86_64":
        return "macosx_x86_64"
    else:
        raise ValueError(f"Unsupported platform: {system}_{machine}")
 
 
def generate_flash_attn_filename(flash_attn_version="2.7.4.post1"):
    cuda_version = get_cuda_version()
    torch_version = get_torch_version()
    python_version = get_python_version()
    abi_flag = get_abi_flag()
    platform_tag = get_platform()
 
    filename = (
        f"flash_attn-{flash_attn_version}+{cuda_version}{torch_version}cxx11{abi_flag}-"
        f"{python_version}-{python_version}-{platform_tag}.whl"
    )
    return filename
 
 
if __name__ == "__main__":
    try:
        filename = generate_flash_attn_filename()
        print("Generated filename:", filename)
    except Exception as e:
        print("Error generating filename:", e)

然后在环境中运行这个脚本会自动得到你的环境对应的版本,例如:

$ python test.py 
Generated filename: flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl

然后访问网站(https://github.com/Dao-AILab/flash-attention/releases/)下载对应的版本名字即可。

上传服务器,安装方式同一。

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐