mirror of
https://github.com/deepfakes/faceswap
synced 2025-06-07 10:43:27 -04:00
* Remove tensorflow_probability requirement * setup.py - fix progress bars * requirements.txt: Remove pre python 3.9 packages * update apple requirements.txt * update INSTALL.md * Remove python<3.9 code * setup.py - fix Windows Installer * typing: python3.9 compliant * Update pytest and readthedocs python versions * typing fixes * Python Version updates - Reduce max version to 3.10 - Default to 3.10 in installers - Remove incompatible 3.11 tests * Update dependencies * Downgrade imageio dep for Windows * typing: merge optional unions and fixes * Updates - min python version 3.10 - typing to python 3.10 spec - remove pre-tf2.10 code - Add conda tests * train: re-enable optimizer saving * Update dockerfiles * Update setup.py - Apple Conda deps to setup.py - Better Cuda + dependency handling * bugfix: Patch logging to prevent Autograph errors * Update dockerfiles * Setup.py - Setup.py - stdout to utf-8 * Add more OSes to github Actions * suppress mac-os end to end test
626 lines
23 KiB
Python
626 lines
23 KiB
Python
#!/usr/bin/env python3
|
|
""" Collects and returns Information on DirectX 12 hardware devices for DirectML. """
|
|
from __future__ import annotations
|
|
import os
|
|
import sys
|
|
import typing as T
|
|
assert sys.platform == "win32"
|
|
|
|
import ctypes
|
|
from ctypes import POINTER, Structure, windll
|
|
from dataclasses import dataclass
|
|
from enum import Enum, IntEnum
|
|
|
|
from comtypes import COMError, IUnknown, GUID, STDMETHOD, HRESULT # pylint:disable=import-error
|
|
|
|
from ._base import _GPUStats
|
|
|
|
if T.TYPE_CHECKING:
|
|
from collections.abc import Callable
|
|
|
|
# Give every Enum a ``ctype`` attribute naming the C type (unsigned 32 bit) that its members are
# marshalled as, which makes the ctypes declarations below easier to follow.
# The attribute is attached to the Enum base class after the fact because a name assigned inside
# an Enum body would be collected as a member of the enumeration. Type checkers do not know about
# the attribute, hence the ignores wherever it is used.
Enum.ctype = ctypes.c_uint32  # type:ignore[attr-defined]
|
|
|
|
|
|
#############################
|
|
# CTYPES SUPPORTING OBJECTS #
|
|
#############################
|
|
# GUIDs
|
|
@dataclass
class LookupGUID:
    """ Interface identifiers (GUIDs) needed when creating COM objects that are used once and
    then discarded.

    Reference
    ---------
    https://learn.microsoft.com/en-us/windows/win32/api/d3d12/nn-d3d12-id3d12device2
    """
    # Neither name is annotated, so both are plain class attributes rather than dataclass fields
    IDXGIDevice = GUID("{54ec77fa-1377-44e6-8c32-88fd5f44c84c}")
    ID3D12Device = GUID("{189819f1-1db6-4b57-be54-1821339b85f7}")
|
|
|
|
|
|
# ENUMS
|
|
class DXGIGpuPreference(IntEnum):
    """ Ordering preference used when asking DXGI to enumerate GPUs for the application.

    Reference
    ---------
    https://learn.microsoft.com/en-us/windows/win32/api/dxgi1_6/ne-dxgi1_6-dxgi_gpu_preference
    """
    DXGI_GPU_PREFERENCE_UNSPECIFIED = 0x0
    DXGI_GPU_PREFERENCE_MINIMUM_POWER = 0x1
    DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE = 0x2
|
|
|
|
|
|
class DXGIAdapterFlag(IntEnum):
    """ Flag values describing what kind of DXGI adapter a description refers to.

    Reference
    ---------
    https://learn.microsoft.com/en-us/windows/win32/api/dxgi/ne-dxgi-dxgi_adapter_flag
    """
    DXGI_ADAPTER_FLAG_NONE = 0x0
    DXGI_ADAPTER_FLAG_REMOTE = 0x1
    DXGI_ADAPTER_FLAG_SOFTWARE = 0x2
    DXGI_ADAPTER_FLAG_FORCE_DWORD = 0xFFFFFFFF
|
|
|
|
|
|
class DXGIMemorySegmentGroup(IntEnum):
    """ Selects which of an adapter's memory segment groups a query refers to.

    Reference
    ---------
    https://learn.microsoft.com/en-us/windows/win32/api/dxgi1_4/ne-dxgi1_4-dxgi_memory_segment_group
    """
    DXGI_MEMORY_SEGMENT_GROUP_LOCAL = 0x0
    DXGI_MEMORY_SEGMENT_GROUP_NON_LOCAL = 0x1
|
|
|
|
|
|
class D3DFeatureLevel(Enum):
    """ The feature sets that a Direct3D device can be asked to target.

    Reference
    ---------
    https://learn.microsoft.com/en-us/windows/win32/api/d3dcommon/ne-d3dcommon-d3d_feature_level
    """
    D3D_FEATURE_LEVEL_1_0_CORE = 0x1000
    D3D_FEATURE_LEVEL_9_1 = 0x9100
    D3D_FEATURE_LEVEL_9_2 = 0x9200
    D3D_FEATURE_LEVEL_9_3 = 0x9300
    D3D_FEATURE_LEVEL_10_0 = 0xA000
    D3D_FEATURE_LEVEL_10_1 = 0xA100
    D3D_FEATURE_LEVEL_11_0 = 0xB000
    D3D_FEATURE_LEVEL_11_1 = 0xB100
    D3D_FEATURE_LEVEL_12_0 = 0xC000
    D3D_FEATURE_LEVEL_12_1 = 0xC100
    D3D_FEATURE_LEVEL_12_2 = 0xC200
|
|
|
|
|
|
class VendorID(Enum):
    """ Vendor identifiers that can appear in the ``VendorId`` field of an adapter description """
    AMD = 0x1002
    NVIDIA = 0x10de
    MICROSOFT = 0x1414
    QUALCOMM = 0x4d4f4351
    INTEL = 0x8086
|
|
|
|
|
|
# STRUCTS
|
|
class StructureRepr(Structure):  # pylint:disable=too-few-public-methods
    """ A :class:`ctypes.Structure` with a ``repr`` that lists its contents, for use in logging """
    def __repr__(self) -> str:
        """ str: The class name followed by ``name=value`` for each entry in ``_fields_`` and
        then for each attribute held in the instance dictionary """
        pairs = [f"{entry[0]}={getattr(self, entry[0])!s}" for entry in self._fields_]
        # Anything set on the instance that is not a declared field lives in its __dict__
        pairs += [f"{name}={value!s}" for name, value in vars(self).items()]
        return f"{type(self).__name__}({', '.join(pairs)})"
|
|
|
|
|
|
class LUID(StructureRepr):  # pylint:disable=too-few-public-methods
    """ Locally unique identifier, used to identify an adapter

    Reference
    ---------
    https://learn.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-luid
    """
    _fields_ = [
        ("LowPart", ctypes.c_ulong),
        ("HighPart", ctypes.c_long)]
|
|
|
|
|
|
class DriverVersion(StructureRepr):  # pylint:disable=too-few-public-methods
    """ Structure (based off LARGE_INTEGER) to hold the driver version as four unsigned 16 bit
    parts

    Reference
    ---------
    https://docs.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-large_integer-r1
    """
    _fields_ = [
        ("parts_a", ctypes.c_uint16),
        ("parts_b", ctypes.c_uint16),
        ("parts_c", ctypes.c_uint16),
        ("parts_d", ctypes.c_uint16)]
|
|
|
|
|
|
class DXGIAdapterDesc1(StructureRepr):  # pylint:disable=too-few-public-methods
    """ Description of an adapter (or video card) as reported by DXGI 1.1

    Reference
    ---------
    https://learn.microsoft.com/en-us/windows/win32/api/dxgi/ns-dxgi-dxgi_adapter_desc1
    """
    # Field order and types define the memory layout that DXGI writes into, so must not change
    _fields_ = [("Description", ctypes.c_wchar * 128),
                ("VendorId", ctypes.c_uint),
                ("DeviceId", ctypes.c_uint),
                ("SubSysId", ctypes.c_uint),
                ("Revision", ctypes.c_uint),
                ("DedicatedVideoMemory", ctypes.c_size_t),
                ("DedicatedSystemMemory", ctypes.c_size_t),
                ("SharedSystemMemory", ctypes.c_size_t),
                ("AdapterLuid", LUID),
                ("Flags",
                 DXGIAdapterFlag.ctype)]  # type:ignore[attr-defined] # pylint:disable=no-member
|
|
|
|
|
|
class DXGIQueryVideoMemoryInfo(StructureRepr):  # pylint:disable=too-few-public-methods
    """ The video memory budgeting parameters currently in effect, as four unsigned 64 bit values.

    Reference
    ---------
    https://learn.microsoft.com/en-us/windows/win32/api/dxgi1_4/ns-dxgi1_4-dxgi_query_video_memory_info
    """
    _fields_ = [
        ("Budget", ctypes.c_uint64),
        ("CurrentUsage", ctypes.c_uint64),
        ("AvailableForReservation", ctypes.c_uint64),
        ("CurrentReservation", ctypes.c_uint64)]
|
|
|
|
|
|
# COM Objects
|
|
class IDXObject(IUnknown):  # pylint:disable=too-few-public-methods
    """ The interface that all DXGI objects derive from.

    Reference
    ---------
    https://learn.microsoft.com/en-us/windows/win32/api/dxgi/nn-dxgi-idxgiobject
    """
    _iid_ = GUID("{aec22fb8-76f3-4639-9be0-28eb43a67a2e}")
    # Entries are listed in vtable order, which must be preserved
    _methods_ = [
        STDMETHOD(HRESULT, "SetPrivateData", [GUID, ctypes.c_uint, POINTER(ctypes.c_void_p)]),
        STDMETHOD(HRESULT, "SetPrivateDataInterface", [GUID, POINTER(IUnknown)]),
        STDMETHOD(HRESULT,
                  "GetPrivateData",
                  [GUID, POINTER(ctypes.c_uint), POINTER(ctypes.c_void_p)]),
        STDMETHOD(HRESULT, "GetParent", [GUID, POINTER(POINTER(ctypes.c_void_p))])]
|
|
|
|
|
|
class IDXGIFactory6(IDXObject):  # pylint:disable=too-few-public-methods
    """ Declares the methods of the DXGI factory, from ``IDXGIFactory`` through to
    ``IDXGIFactory6``, which is used for generating DXGI objects.

    Most methods are declared by name only. Argument types are given solely for
    ``EnumAdapterByLuid`` and ``EnumAdapterByGpuPreference``.

    Reference
    ---------
    https://learn.microsoft.com/en-us/windows/win32/api/dxgi/nn-dxgi-idxgifactory
    https://learn.microsoft.com/en-us/windows/win32/api/dxgi1_6/nn-dxgi1_6-idxgifactory6
    """
    _iid_ = GUID("{c1b6694f-ff09-44a9-b03c-77900a0a1d17}")

    # Entries are listed in vtable order, which must be preserved
    _methods_ = [
        # IDXGIFactory
        STDMETHOD(HRESULT, "EnumAdapters"),
        STDMETHOD(HRESULT, "MakeWindowAssociation"),
        STDMETHOD(HRESULT, "GetWindowAssociation"),
        STDMETHOD(HRESULT, "CreateSwapChain"),
        STDMETHOD(HRESULT, "CreateSoftwareAdapter"),
        # IDXGIFactory1
        STDMETHOD(HRESULT, "EnumAdapters1"),
        STDMETHOD(ctypes.c_bool, "IsCurrent"),
        # IDXGIFactory2
        STDMETHOD(ctypes.c_bool, "IsWindowedStereoEnabled"),
        STDMETHOD(HRESULT, "CreateSwapChainForHwnd"),
        STDMETHOD(HRESULT, "CreateSwapChainForCoreWindow"),
        STDMETHOD(HRESULT, "GetSharedResourceAdapterLuid"),
        STDMETHOD(HRESULT, "RegisterStereoStatusWindow"),
        STDMETHOD(HRESULT, "RegisterStereoStatusEvent"),
        STDMETHOD(None, "UnregisterStereoStatus"),
        STDMETHOD(HRESULT, "RegisterOcclusionStatusWindow"),
        STDMETHOD(HRESULT, "RegisterOcclusionStatusEvent"),
        STDMETHOD(None, "UnregisterOcclusionStatus"),
        STDMETHOD(HRESULT, "CreateSwapChainForComposition"),
        # IDXGIFactory3
        STDMETHOD(ctypes.c_uint, "GetCreationFlags"),
        # IDXGIFactory4
        STDMETHOD(HRESULT, "EnumAdapterByLuid", [LUID, GUID, POINTER(POINTER(ctypes.c_void_p))]),
        STDMETHOD(HRESULT, "EnumWarpAdapter"),
        # IDXGIFactory5
        STDMETHOD(HRESULT, "CheckFeatureSupport"),
        # IDXGIFactory6
        STDMETHOD(
            HRESULT,
            "EnumAdapterByGpuPreference",
            [ctypes.c_uint,
             DXGIGpuPreference.ctype,  # type:ignore[attr-defined] # pylint:disable=no-member # noqa:E501
             GUID,
             POINTER(ctypes.c_void_p)])]
|
|
|
|
|
|
class IDXGIAdapter3(IDXObject):  # pylint:disable=too-few-public-methods
    """ Declares the methods of a DXGI adapter, from ``IDXGIAdapter`` through to
    ``IDXGIAdapter3``. An adapter represents a display sub-system (including one or more GPU's,
    DACs and video memory).

    Argument types are given only for ``CheckInterfaceSupport``, ``GetDesc1`` and
    ``QueryVideoMemoryInfo``. The remaining methods are declared by name only.

    Reference
    ---------
    https://learn.microsoft.com/en-us/windows/win32/api/dxgi1_4/nn-dxgi1_4-idxgiadapter3
    """
    _iid_ = GUID("{645967a4-1392-4310-a798-8053ce3e93fd}")
    # Entries are listed in vtable order, which must be preserved
    _methods_ = [
        # IDXGIAdapter
        STDMETHOD(HRESULT, "EnumOutputs"),
        STDMETHOD(HRESULT, "GetDesc"),
        STDMETHOD(HRESULT, "CheckInterfaceSupport", [GUID, POINTER(DriverVersion)]),
        # IDXGIAdapter1
        STDMETHOD(HRESULT, "GetDesc1", [POINTER(DXGIAdapterDesc1)]),
        # IDXGIAdapter2
        STDMETHOD(HRESULT, "GetDesc2"),
        # IDXGIAdapter3
        STDMETHOD(HRESULT, "RegisterHardwareContentProtectionTeardownStatusEvent"),
        STDMETHOD(None, "UnregisterHardwareContentProtectionTeardownStatus"),
        STDMETHOD(
            HRESULT,
            "QueryVideoMemoryInfo",
            [ctypes.c_uint,
             DXGIMemorySegmentGroup.ctype,  # type:ignore[attr-defined] # pylint:disable=no-member # noqa:E501
             POINTER(DXGIQueryVideoMemoryInfo)]),
        STDMETHOD(HRESULT, "SetVideoMemoryReservation"),
        STDMETHOD(HRESULT, "RegisterVideoMemoryBudgetChangeNotificationEvent"),
        STDMETHOD(None, "UnregisterVideoMemoryBudgetChangeNotification")]
|
|
|
|
|
|
###########################
|
|
# PYTHON COLLATED OBJECTS #
|
|
###########################
|
|
@dataclass
class Device:
    """ Holds information about a device attached to an adapter.

    Parameters
    ----------
    description: :class:`DXGIAdapterDesc1`
        The information returned from DXGI.dll about the device
    driver_version: str
        The driver version of the device
    local_mem: :class:`DXGIQueryVideoMemoryInfo`
        The video memory information queried for the device's local memory segment group
    non_local_mem: :class:`DXGIQueryVideoMemoryInfo`
        The video memory information queried for the device's non-local memory segment group
    is_d3d12: bool
        ``True`` if the device supports DirectX12
    is_compute_only: bool, optional
        ``True`` if the device is only compute (no graphics). Default: ``False``
    """
    description: DXGIAdapterDesc1
    driver_version: str
    local_mem: DXGIQueryVideoMemoryInfo
    non_local_mem: DXGIQueryVideoMemoryInfo
    is_d3d12: bool
    is_compute_only: bool = False

    @property
    def is_software_adapter(self) -> bool:
        """ bool: ``True`` if this is a software adapter. """
        # Flags holds a combination of DXGIAdapterFlag bits, so test for the software bit rather
        # than requiring the whole value to equal it
        software_bit = DXGIAdapterFlag.DXGI_ADAPTER_FLAG_SOFTWARE.value
        return bool(self.description.Flags & software_bit)

    @property
    def is_valid(self) -> bool:
        """ bool: ``True`` if this adapter is a hardware adaptor and is not the basic renderer """
        if self.is_software_adapter:
            return False

        # Vendor/Device ID pair that identifies the basic renderer
        is_basic_renderer = (self.description.VendorId == VendorID.MICROSOFT.value and
                             self.description.DeviceId == 0x8c)
        return not is_basic_renderer
|
|
|
|
|
|
class Adapters():  # pylint:disable=too-few-public-methods
    """ Wrapper to obtain connected DirectX Graphics interface adapters from Windows

    The adapters are enumerated and queried once, when the class is instantiated.

    Parameters
    ----------
    log_func: :func:`~lib.gpu_stats._base._log`
        The logging function to use from the parent GPUStats class
    """
    def __init__(self, log_func: Callable[[str, str], None]) -> None:
        self._log = log_func
        self._log("debug", f"Initializing {self.__class__.__name__}: (log_func: {log_func})")

        # Each step depends on the result of the one before it
        self._factory = self._get_factory()
        self._adapters = self._get_adapters()
        self._devices = self._process_adapters()

        self._valid_adaptors: list[Device] = []
        self._log("debug", f"Initialized {self.__class__.__name__}")

    def _get_factory(self) -> ctypes._Pointer:
        """ Get a DXGI Factory object by calling ``CreateDXGIFactory`` and requesting the
        :class:`IDXGIFactory6` interface

        Reference
        ---------
        https://learn.microsoft.com/en-us/windows/win32/api/dxgi/nf-dxgi-createdxgifactory

        Returns
        -------
        :class:`ctypes._Pointer`
            A pointer to a :class:`IDXGIFactory6` COM instance
        """
        factory_func = windll.dxgi.CreateDXGIFactory
        factory_func.argtypes = (GUID, POINTER(ctypes.c_void_p))
        factory_func.restype = HRESULT
        handle = ctypes.c_void_p(0)
        factory_func(IDXGIFactory6._iid_, ctypes.byref(handle))  # pylint:disable=protected-access
        retval = ctypes.POINTER(IDXGIFactory6)(T.cast(IDXGIFactory6, handle.value))
        self._log("debug", f"factory: {retval}")
        return retval

    @property
    def valid_adapters(self) -> list[Device]:
        """ list[:class:`Device`]: DirectX 12 compatible hardware :class:`Device` objects """
        if self._valid_adaptors:
            return self._valid_adaptors

        for device in self._devices:
            if not device.is_valid:
                # Sorted by most performant so everything after first basic adapter is skipped
                break
            if not device.is_d3d12:
                continue
            self._valid_adaptors.append(device)
        self._log("debug", f"valid_adaptors: {self._valid_adaptors}")
        return self._valid_adaptors

    def _get_adapters(self) -> list[ctypes._Pointer]:
        """ Obtain the adapter objects from the factory, requested in high performance order.

        Adapters are requested by increasing index until the enumeration call raises a
        :class:`COMError`, which is treated as the end of the list.

        Returns
        -------
        list
            List of :class:`ctypes._Pointer` objects, each to a :class:`IDXGIAdapter3` instance

        Raises
        ------
        AttributeError
            If the enumeration call returns a non-zero result
        """
        idx = 0
        retval = []
        while True:
            handle = ctypes.c_void_p(0)
            try:
                success = self._factory.EnumAdapterByGpuPreference(  # type:ignore[attr-defined]
                    idx,
                    DXGIGpuPreference.DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE.value,
                    IDXGIAdapter3._iid_,  # pylint:disable=protected-access
                    ctypes.byref(handle))
            except COMError as err:
                err_code = hex(ctypes.c_ulong(err.hresult).value)  # pylint:disable=no-member
                self._log(
                    "debug",
                    "COM Error. Breaking: "
                    f"{err.text}({err_code})")  # pylint:disable=no-member
                break

            if success != 0:
                raise AttributeError("Error calling EnumAdapterByGpuPreference. Result: "
                                     f"{hex(ctypes.c_ulong(success).value)}")
            adapter = POINTER(IDXGIAdapter3)(T.cast(IDXGIAdapter3, handle.value))
            self._log("debug", f"found adapter: {adapter}")
            retval.append(adapter)
            idx += 1

        self._log("debug", f"adapters: {retval}")
        return retval

    def _query_adapter(self, func: Callable[..., T.Any], *args: T.Any) -> None:
        """ Query an adapter function, logging if the returned HRESULT is non-zero

        Parameters
        ----------
        func: Callable[..., Any]
            The adaptor function to call
        args: Any
            The arguments to pass to the adaptor function
        """
        # NOTE(review): the enumeration in _get_adapters relies on a COMError being raised by a
        # failing COM call. If that holds for every method then failures propagate from here
        # rather than being logged - confirm
        check = func(*args)
        if check:
            self._log("debug", f"Failed HRESULT for func {func}({args}): "
                               f"{hex(ctypes.c_ulong(check).value)}")

    def _test_d3d12(self, adapter: ctypes._Pointer) -> bool:
        """ Test whether the given adapter supports DirectX 12

        ``D3D12CreateDevice`` is called with a null output pointer, which checks that a device
        could be created without actually creating one.

        Reference
        ---------
        https://learn.microsoft.com/en-us/windows/win32/api/d3d12/nf-d3d12-d3d12createdevice

        Parameters
        ----------
        adapter: :class:`ctypes._Pointer`
            A pointer to an adapter instance

        Returns
        -------
        bool
            ``True`` if the given adapter supports DirectX 12
        """
        try:
            factory_func = windll.d3d12.D3D12CreateDevice
            # The function takes 4 parameters. All must be declared and passed, otherwise the
            # output pointer is left undefined
            factory_func.argtypes = (
                POINTER(IUnknown),
                D3DFeatureLevel.ctype,  # type:ignore[attr-defined] # pylint:disable=no-member
                GUID,
                ctypes.c_void_p)
            factory_func.restype = HRESULT
            success = factory_func(adapter,
                                   D3DFeatureLevel.D3D_FEATURE_LEVEL_11_0.value,
                                   LookupGUID.ID3D12Device,
                                   None)
        except OSError as err:
            # A HRESULT return type raises OSError for a failure code, as does a missing library
            self._log("debug", f"D3D12CreateDevice failed for adapter {adapter}: {err}")
            return False
        # S_OK (0) or S_FALSE (1). The latter is returned for a successful null pointer test
        return success in (0, 1)

    def _process_adapters(self) -> list[Device]:
        """ Process the adapters to add discovered information.

        Returns
        -------
        list[:class:`Device`]
            List of device of objects found in the adapters
        """
        retval = []
        for adapter in self._adapters:
            # Description
            desc = DXGIAdapterDesc1()
            self._query_adapter(adapter.GetDesc1, ctypes.byref(desc))  # type:ignore[attr-defined]

            # Driver Version
            driver = DriverVersion()
            self._query_adapter(adapter.CheckInterfaceSupport,  # type:ignore[attr-defined]
                                LookupGUID.IDXGIDevice,
                                ctypes.byref(driver))
            driver_version = f"{driver.parts_d}.{driver.parts_c}.{driver.parts_b}.{driver.parts_a}"

            # Current Memory
            local_mem = DXGIQueryVideoMemoryInfo()
            self._query_adapter(adapter.QueryVideoMemoryInfo,  # type:ignore[attr-defined]
                                0,
                                DXGIMemorySegmentGroup.DXGI_MEMORY_SEGMENT_GROUP_LOCAL.value,
                                ctypes.byref(local_mem))
            non_local_mem = DXGIQueryVideoMemoryInfo()
            self._query_adapter(adapter.QueryVideoMemoryInfo,  # type:ignore[attr-defined]
                                0,
                                DXGIMemorySegmentGroup.DXGI_MEMORY_SEGMENT_GROUP_NON_LOCAL.value,
                                ctypes.byref(non_local_mem))

            # is_d3d12
            is_d3d12 = self._test_d3d12(adapter)

            retval.append(Device(desc, driver_version, local_mem, non_local_mem, is_d3d12))

        return retval
|
|
|
|
|
|
class DirectML(_GPUStats):
    """ Holds information and statistics about GPUs connected using Windows API

    Parameters
    ----------
    log: bool, optional
        Whether the class should output information to the logger. There may be occasions where the
        logger has not yet been set up when this class is queried. Attempting to log in these
        instances will raise an error. If GPU stats are being queried prior to the logger being
        available then this parameter should be set to ``False``. Otherwise set to ``True``.
        Default: ``True``
    """
    def __init__(self, log: bool = True) -> None:
        # Must exist before the parent is initialized
        self._devices: list[Device] = []
        super().__init__(log=log)

    @property
    def _all_vram(self) -> list[int]:
        """ list: The VRAM, in Megabytes, of each GPU device that the DX API has discovered. """
        return [device.description.DedicatedVideoMemory // (1024 * 1024)
                for device in self._devices]

    @property
    def names(self) -> list[str]:
        """ list: The name of each GPU device that the DX API has discovered. """
        return [device.description.Description for device in self._devices]

    def _get_active_devices(self) -> list[int]:
        """ Obtain the indices of active GPUs (those that have not been explicitly excluded by
        DML_VISIBLE_DEVICES environment variable or explicitly excluded in the command line
        arguments).

        Returns
        -------
        list
            The list of device indices that are available for Faceswap to use

        Raises
        ------
        ValueError
            If DML_VISIBLE_DEVICES contains an entry that is not an integer
        """
        devices = super()._get_active_devices()
        env_devices = os.environ.get("DML_VISIBLE_DEVICES")
        if env_devices:
            # Blank entries (e.g. from a trailing comma) are skipped rather than failing in int()
            visible = {int(entry) for entry in env_devices.split(",") if entry.strip()}
            devices = [idx for idx in devices if idx in visible]
        self._log("debug", f"Active GPU Devices: {devices}")
        return devices

    def _get_devices(self) -> list[Device]:
        """ Obtain all detected DX API devices.

        Returns
        -------
        list
            The :class:`Device` objects for GPUs that the DX API has discovered.
        """
        adapters = Adapters(log_func=self._log)
        devices = adapters.valid_adapters
        self._log("debug", f"Obtained Devices: {devices}")
        return devices

    def _initialize(self) -> None:
        """ Initialize DX Core for DirectML backend.

        If :attr:`_is_initialized` is ``True`` then this function just returns performing no
        action.

        if ``False`` then the devices are collected from the DX API before the parent class is
        initialized.
        """
        if self._is_initialized:
            return
        self._log("debug", "Initializing Win DX API for DirectML.")
        self._devices = self._get_devices()
        super()._initialize()

    def _get_device_count(self) -> int:
        """ Detect the number of GPUs available from the DX API.

        Returns
        -------
        int
            The total number of GPUs available
        """
        retval = len(self._devices)
        self._log("debug", f"GPU Device count: {retval}")
        return retval

    def _get_handles(self) -> list:
        """ The DX API doesn't really use device handles, so we just return the all devices list

        Returns
        -------
        list
            The list of all discovered GPUs
        """
        handles = self._devices
        self._log("debug", f"DirectML GPU Handles found: {handles}")
        return handles

    def _get_driver(self) -> str:
        """ Obtain the driver versions currently in use.

        Returns
        -------
        str
            The current DirectX 12 GPU driver versions, separated by a ``|`` character
        """
        drivers = "|".join(device.driver_version or "No Driver Found"
                           for device in self._devices)
        self._log("debug", f"GPU Drivers: {drivers}")
        return drivers

    def _get_device_names(self) -> list[str]:
        """ Obtain the list of names of connected GPUs as identified in :attr:`_handles`.

        Returns
        -------
        list
            The list of connected DirectML GPU names
        """
        names = self.names
        self._log("debug", f"GPU Devices: {names}")
        return names

    def _get_vram(self) -> list[int]:
        """ Obtain the VRAM in Megabytes for each connected DirectML GPU as identified in
        :attr:`_handles`.

        Returns
        -------
        list
            The VRAM in Megabytes for each connected DirectML GPU
        """
        vram = self._all_vram
        self._log("debug", f"GPU VRAM: {vram}")
        return vram

    def _get_free_vram(self) -> list[int]:
        """ Obtain the amount of VRAM that is available, in Megabytes, for each connected DirectX
        12 supporting GPU.

        The figure reported is the ``Budget`` held in each device's local memory information.

        Returns
        -------
        list
            List of `int`s containing the amount of VRAM available, in Megabytes, for each
            connected GPU as corresponding to the values in :attr:`_handles`
        """
        vram = [device.local_mem.Budget // (1024 * 1024) for device in self._devices]
        self._log("debug", f"GPU VRAM free: {vram}")
        return vram
|