Hello, PyTorch (1) | 环境配置
安装 Docker
Windows
Windows 系统不能直接安装使用 Docker,需要先安装 WSL2(Windows Subsystem Linux 2)。以下以 x86 架构的 Windows 10 系统为例,进行 WSL2 的安装流程说明。
检查系统版本(Windows 11 无需进行此步骤):打开设置→系统→关于页面,检查 Windows 版本号,应为 2004、20H2、21H1、21H2、22H2 中的任意一个,否则需要更新 Windows 版本。
安装 Windows Terminal(Windows 11 无需进行此步骤。此步骤非必须,但建议进行):访问 Windows Terminal | 微软应用商店 并下载 Windows Terminal 安装器,运行该安装器即可安装 Windows Terminal。
安装 WSL2:在开始菜单中找到上一步安装的 Windows Terminal(或「终端」),右键→更多→以管理员身份运行,然后在其中输入并按回车执行如下命令:
等待自动安装完成即可,过程中可能需要重启计算机。1
wsl --install
设置 WSL2:打开开始菜单,找到上一步安装的「Ubuntu」并打开,初次使用会要求设置用户名与密码(输入密码时,屏幕不会显示任何字符,这称为「盲人键入」,是完全正常的现象),密码建议妥善保管(如果忘记密码,可以在 PowerShell 中使用
wsl -u root
进入 root 账户,然后使用passwd <username>
命令为账户设置新密码。)。设置软件仓库镜像源:由于国内大陆地区网络问题,需要为 Ubuntu 的软件仓库设置国内镜像源。以清华镜像源为例,在 Ubuntu 中执行如下指令:
查看 Ubuntu 的版本号,在清华镜像站中找到并复制对应版本的代码。接着执行如下指令(注意 Ubuntu 24.04 版本及以上,需要将下面指令中的地址安装清华镜像源中的说明进行对应的替换):1
cat /etc/os-release
使用 vi 打开文件,直接输入 ggdG (区分大小写)删除全部内容,接着按 i 键进入插入模式,单击右键将复制的代码粘贴进去,随后按 ESC 键退出插入模式,再输入 :wq,按回车保存并退出 vi。镜像源配置完成后,在 Ubuntu 中执行如下命令更新软件。1
2sudo cp /etc/apt/sources.list /etc/apt/sources.list.bak # 备份 sudo vi /etc/apt/sources.list
1
2sudo apt update sudo apt upgrade
安装 Docker:按照 Linux 系统的安装步骤,在 WSL2 中安装 Docker 即可。
Linux
Linux 系统可以直接安装使用 docker-ce,由于国内大陆地区网络问题,需要通过国内镜像源安装。下面以清华镜像源与 Ubuntu 系统(其他 Linux 发行版可参考 docker-ce | 清华大学开源软件镜像站进行相关配置)为例,进行安装流程说明。
- 卸载旧版本(如有):在 bash 中输入并执行如下命令,下同:
1
sudo apt-get remove docker.io docker-doc docker-compose podman-docker containerd runc
- 安装依赖:
1
2sudo apt-get update sudo apt-get install ca-certificates curl gnupg
- 下载 GPG key:
1
2sudo mkdir -m 0755 -p /etc/apt/keyrings curl -fsSL https://mirrors.tuna.tsinghua.edu.cn/docker-ce/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
- 添加镜像源软件仓库:
1
echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://mirrors.tuna.tsinghua.edu.cn/docker-ce/linux/ubuntu"$(. /etc/os-release && echo "$VERSION_CODENAME")" stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
- 安装 docker-ce:
1
2
3sudo apt-get update sudo apt-get install docker-ce docker-ce-cli containerd.io \ docker-buildx-plugin docker-compose-plugin
- 配置 Docker 镜像仓库的镜像:由于国内大陆地区网络问题,需要为 Docker Hub 配置镜像或代理。目前较正规的镜像(即各大高校或腾讯、阿里、字节等大企业建立的镜像源)均已关停,需要自行寻找可用镜像源,此处使用目前可用的一个镜像 https://docker.1panel.live/ 讲解如何配置 Docker Hub 镜像。对于其他来源的镜像,如 nvcr.io、ghcr.io 等来源的镜像,目前南京大学仍有可用镜像源。同时也可参考附录自建 Docker 镜像使用。同样使用 vi 打开 docker 的配置文件 /etc/docker/daemon.json 通过方向键将光标移动到最外层的大括号对(
1
sudo vi /etc/docker/daemon.json
{}
)之间,按 i 键进入插入模式,输入如下内容(如果已有"registry-mirrors"
,则换成修改):随后按 ESC 键退出插入模式,再输入 :wq,按回车保存并退出。1
"registry-mirrors": ["https://docker.1panel.live"]
安装 PyTorch
安装完成 Docker 后,就可以开始配置 PyTorch 环境了。创建一个合适的目录用于存放 Docker 镜像,以下以 ~/workspace/pinn 作为示例目录进行讲解。
在 Docker 中进行 Python 开发需要一个可以连接到容器内进行开发的集成开发环境(IDE),如 VSCode(Visual Studio Code) 或 Pycharm。如果不想安装集成开发环境,也可使用 Jupyter Lab 在浏览器中进行开发(但建议使用集成开发环境,而非 Jupyter)。
安装 CUDA 与 CuDNN
本节仅 Nvidia 显卡需要进行,默认系统已安装 Nvidia 显卡驱动。在终端中执行
1 | nvidia-smi |
如果 CUDA 为最新版(目前最新版为 12.6),可以直接在 cuDNN Downloads | NVIDIA Developer 下载最新版的 CuDNN,否则需要在 cuDNN Archive | NVIDIA Developer 处下载对应 CUDA 版本的 CuDNN,注意需要注册一个账号才能下载。与 CUDA 相同,下载页面也提供了不同系统的不同版本安装下载方式。
对 Windows 系统来说,此时在 WSL 中就已经可以使用 nvidia-smi
命令显示穿透到 WSL 中的显卡信息了。
构建镜像
新建一个文件 ~/workspace/pytorch/dockerfile,由于此时不再有 Linux 系统权限限制,Windows 系统可以使用
1 | notepad.exe ~/workspace/pytorch/dockerfile |
notepad.exe
改为 code
);Linux 系统仍可使用 vi 或 vim 进行编辑。 由于不同的电脑配置需要使用不同的镜像,因此以下将分类讲解 Nvidia 显卡、AMD 显卡与纯 CPU 三种不同配置下构建镜像的方式。此处使用前人已构建好的镜像。
- Nvidia 显卡:Nvidia 显卡可以直接使用 PyTorch 提供的 docker 镜像作为基础进行开发。在 dockerfile 文件中输入如下内容:
1
2
3
4
5
6FROM cnstark/pytorch:2.3.1-py3.10.15-cuda12.1.0-ubuntu22.04 WORKDIR /workspace COPY requirements.txt requirements.txt RUN pip install -r requirements.txt -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple RUN rm -r /temp && rm -r /root/.cache/pip
- AMD 显卡:AMD 显卡可以使用 ROCm 来进行 PyTorch 开发(需要参考 WSL How to guide - Use ROCm on Radeon GPUs — Use ROCm on Radeon GPUs 进行相关配置)。对上文中 dockerfile 文件中
FROM
部分修改为:1
FROM rocm/pytorch:rocm6.1.3_ubuntu22.04_py3.10_pytorch_release-2.1.2
- 纯 CPU:将 dockerfile 文件中 FROM 修改为:
1
FROM cnstark/pytorch:2.3.1-py3.10.15-ubuntu22.04
如果使用 Jupyter Lab,则需要在 dockerfile 末尾添加两行:
1 | EXPOSE 8888 ENTRYPOINT ["jupyter","lab","--ip=0.0.0.0","--allow-root","--no-browser"] |
接着在同一目录下新建 requirements.txt 文件,这个文件里是其他需要的 Python 包,例如:
1 | # Jupyter Lab 配置 # 使用 Pycharm/VSCode/其他 IDE 编辑代码,这部分可以全部注释掉 jupyterlab # Jupyter Lab 本体 jupyterlab-language-pack-zh-CN # Jupyter Lab 中文语言包 jupyterlab-lsp # Jupyter Lab 语言服务器(LSP)支持 jedi-language-server # Jedi 语言服务器 # 其他需要的包 ipykernel # 运行 Jupyter Notebook 的核心包,使用 VSCode 时需要 scipy # 提供一些实用函数 pandas # 数据处理包,可以注释掉 matplotlib # 绘图包,如果需要导出数据用其他软件绘图,可以注释掉 |
最后再在统一目录下创建 docker-compose.yaml,根据创建 dockerfile 时的不同,分别填入如下内容:
- Nvidia 显卡:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15services: nvidia-pytorch: build: . ipc: host volumes: - ./data:/workspace deploy: resources: reservations: devices: - driver: nvidia count: 1 capabilities: [gpu] tty: true stdin_open: true
- AMD 显卡: 需要注意的是以上为 Windows 系统下的 docker-compose.ymal 文件,如果是 Linux,则需删除 volumes 中 workspace 以外的两项,同时将 devices 修改为:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17services: rocm-pytorch: build: . cap_add: - SYS_PTRACE security_opt: - seccomp=unconfined ipc: host shm_size: 8G devices: - /dev/dxg volumes: - ./data:/workspace - /usr/lib/wsl/lib/libdxcore.so:/usr/lib/libdxcore.so - /opt/rocm/lib/libhsa-runtime64.so.1:/opt/rocm/lib/libhsa-runtime64.so.1 tty: true stdin_open: true
1
2
3devices: - /dev/kfd - /dev/dri
- 纯 CPU:
1
2
3
4
5
6
7services: pytorch: build: . volumes: - ./data:/workspace tty: true stdin_open: true
如果使用 Jupyter Lab,则需要在 volumes 前新增两行:
1 | ports: - "8888:8888" |
最后在 Linux/WSL 终端中执行
1 | docker-compose up -d |
附录
自建 Docker 镜像
Cloudflare Worker 搭建 Github 与 Docker 加速,来自用GPT融了一个Cloudflare Workers的github下载 + Docke pull加速 - 开发调优 - LINUX DO。建议自行修改部分代码(比如用 LLM 洗一遍),以防有大量相似代码的 Worker 导致被 Cloudflare 认定为滥用 Worker。
1 | 'use strict'; const HUB_HOST = 'registry-1.docker.io'; const AUTH_URL = 'https://auth.docker.io'; const WORKERS_URL = 'https://你的域名'; const ASSET_URL = 'https://hunshcn.github.io/gh-proxy/'; const PREFIX = '/'; const Config = { jsdelivr: 0 }; const whiteList = []; const exp1 = /^(?:https?:\/\/)?github\.com\/.+?\/.+?\/(?:releases|archive)\/.*$/i; const exp2 = /^(?:https?:\/\/)?github\.com\/.+?\/.+?\/(?:blob|raw)\/.*$/i; const exp3 = /^(?:https?:\/\/)?github\.com\/.+?\/.+?\/(?:info|git-).*$/i; const exp4 = /^(?:https?:\/\/)?raw\.(?:githubusercontent|github)\.com\/.+?\/.+?\/.+?\/.+$/i; const exp5 = /^(?:https?:\/\/)?gist\.(?:githubusercontent|github)\.com\/.+?\/.+?\/.+$/i; const exp6 = /^(?:https?:\/\/)?github\.com\/.+?\/.+?\/tags.*$/i; /** @type {RequestInit} */ const PREFLIGHT_INIT = { // @ts-ignore status: 204, headers: new Headers({ 'access-control-allow-origin': '*', 'access-control-allow-methods': 'GET, POST, PUT, PATCH, TRACE, DELETE, HEAD, OPTIONS', 'access-control-max-age': '1728000', }), }; /** * Create a new response. * @param {any} body * @param {number} [status=200] * @param {Object<string, string>} headers * @returns {Response} */ function makeResponse(body, status = 200, headers = {}) { headers['access-control-allow-origin'] = '*'; return new Response(body, { status, headers }); } /** * Create a new URL object. * @param {string} urlStr * @returns {URL|null} */ function createURL(urlStr) { try { return new URL(urlStr); } catch (err) { return null; } } addEventListener('fetch', (event) => { event.respondWith(handleFetchEvent(event).catch(err => makeResponse(`cfworker error:\n${err.stack}`, 502))); }); /** * Handle the fetch event. * @param {FetchEvent} event * @returns {Promise<Response>} */ async function handleFetchEvent(event) { const req = event.request; const url = new URL(req.url); if (url.pathname.startsWith('/token') || url.pathname.startsWith('/v2')) { return handleDockerProxy(req, url); } if (url.pathname.startsWith(PREFIX)) { return handleGitHubProxy(req, url); } return makeResponse('Not Found', 404); } /** * Handle token requests and Docker proxy. * @param {Request} req * @param {URL} url * @returns {Promise<Response>} */ async function handleDockerProxy(req, url) { if (url.pathname === '/token') { const tokenURL = AUTH_URL + url.pathname + url.search; const headers = new Headers({ 'Host': 'auth.docker.io', 'User-Agent': req.headers.get('User-Agent'), 'Accept': req.headers.get('Accept'), 'Accept-Language': req.headers.get('Accept-Language'), 'Accept-Encoding': req.headers.get('Accept-Encoding'), 'Connection': 'keep-alive', 'Cache-Control': 'max-age=0' }); return fetch(new Request(tokenURL, req), { headers }); } url.hostname = HUB_HOST; const headers = new Headers({ 'Host': HUB_HOST, 'User-Agent': req.headers.get('User-Agent'), 'Accept': req.headers.get('Accept'), 'Accept-Language': req.headers.get('Accept-Language'), 'Accept-Encoding': req.headers.get('Accept-Encoding'), 'Connection': 'keep-alive', 'Cache-Control': 'max-age=0' }); if (req.headers.has('Authorization')) { headers.set('Authorization', req.headers.get('Authorization')); } const response = await fetch(new Request(url, req), { headers }); const responseHeaders = new Headers(response.headers); const status = response.status; if (responseHeaders.get('Www-Authenticate')) { const authHeader = responseHeaders.get('Www-Authenticate'); const re = new RegExp(AUTH_URL, 'g'); responseHeaders.set('Www-Authenticate', authHeader.replace(re, WORKERS_URL)); } if (responseHeaders.get('Location')) { return handleHttpRedirect(req, responseHeaders.get('Location')); } responseHeaders.set('access-control-expose-headers', '*'); responseHeaders.set('access-control-allow-origin', '*'); responseHeaders.set('Cache-Control', 'max-age=1500'); responseHeaders.delete('Content-Security-Policy'); responseHeaders.delete('Content-Security-Policy-Report-Only'); responseHeaders.delete('Clear-Site-Data'); return new Response(response.body, { status, headers: responseHeaders }); } /** * Handle GitHub proxy requests. * @param {Request} req * @param {URL} url * @returns {Promise<Response>} */ async function handleGitHubProxy(req, url) { let path = url.searchParams.get('q'); if (path) { return Response.redirect('https://' + url.host + PREFIX + path, 301); } path = url.href.substr(url.origin.length + PREFIX.length).replace(/^https?:\/+/, 'https://'); if (checkUrl(path)) { return httpHandler(req, path); } else if (path.search(exp2) === 0) { if (Config.jsdelivr) { const newUrl = path.replace('/blob/', '@').replace(/^(?:https?:\/\/)?github\.com/, 'https://cdn.jsdelivr.net/gh'); return Response.redirect(newUrl, 302); } else { path = path.replace('/blob/', '/raw/'); return httpHandler(req, path); } } else if (path.search(exp4) === 0) { const newUrl = path.replace(/(?<=com\/.+?\/.+?)\/(.+?\/)/, '@$1').replace(/^(?:https?:\/\/)?raw\.(?:githubusercontent|github)\.com/, 'https://cdn.jsdelivr.net/gh'); return Response.redirect(newUrl, 302); } else { return fetch(ASSET_URL + path); } } /** * Check if the URL matches GitHub patterns. * @param {string} url * @returns {boolean} */ function checkUrl(url) { return [exp1, exp2, exp3, exp4, exp5, exp6].some(exp => url.search(exp) === 0); } /** * Handle HTTP redirects. * @param {Request} req * @param {string} location * @returns {Promise<Response>} */ async function handleHttpRedirect(req, location) { const url = createURL(location); if (!url) { return makeResponse('Invalid URL', 400); } return proxyRequest(url, req); } /** * Handle HTTP requests. * @param {Request} req * @param {string} pathname * @returns {Promise<Response>} */ async function httpHandler(req, pathname) { if (req.method === 'OPTIONS' && req.headers.has('access-control-request-headers')) { return new Response(null, PREFLIGHT_INIT); } const headers = new Headers(req.headers); let flag = !whiteList.length; for (const i of whiteList) { if (pathname.includes(i)) { flag = true; break; } } if (!flag) { return new Response('blocked', { status: 403 }); } if (pathname.search(/^https?:\/\//) !== 0) { pathname = 'https://' + pathname; } const url = createURL(pathname); return proxyRequest(url, { method: req.method, headers, body: req.body }); } /** * Proxy a request. * @param {URL} url * @param {RequestInit} reqInit * @returns {Promise<Response>} */ async function proxyRequest(url, reqInit) { const response = await fetch(url.href, reqInit); const responseHeaders = new Headers(response.headers); if (responseHeaders.has('location')) { const location = responseHeaders.get('location'); if (checkUrl(location)) { responseHeaders.set('location', PREFIX + location); } else { reqInit.redirect = 'follow'; return proxyRequest(createURL(location), reqInit); } } responseHeaders.set('access-control-expose-headers', '*'); responseHeaders.set('access-control-allow-origin', '*'); responseHeaders.delete('content-security-policy'); responseHeaders.delete('content-security-policy-report-only'); responseHeaders.delete('clear-site-data'); return new Response(response.body, { status: response.status, headers: responseHeaders, }); } |