今天執行基於 PyTorch 的圖像分類算法程序時,觸發了自己寫的斷言錯誤。而斷言的細節,就是判斷用戶輸入的 GPU 編號是否合法。
調試打開,發現 torch.cuda.device_count()
返回的是 1。而我機器上明明是兩張卡。
一臉懵逼。
查閱 PyTorch 官網后,發現是使用問題。我在調用 device_count 之前,已經設置過了環境變量 CUDA_VISIBLE_DEVICES
。
通過在 os.environ["CUDA_VISIBLE_DEVICES"]
代碼之前執行 device_count, 發現返回的是 2。至此,問題已定位。
PS. 官方推進使用 os.environ["CUDA_VISIBLE_DEVICES"]
的形式來設定使用的 GPU 顯卡。