# EveryDream2trainer/utils/patch_bnb.py
# (repository file view: 134 lines, 4.7 KiB, Python; dated 2022-12-17 20:32:48 -07:00)
"""
Copyright [2022] Victor C Hall
Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at
https://www.gnu.org/licenses/agpl-3.0.en.html
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# see: https://github.com/TimDettmers/bitsandbytes/issues/30 for explanation
import sys
import os
from subprocess import check_output
import shutil
# Replacement text that patch_cext() writes over the target lines (0-based
# indices 27 and 30) of bitsandbytes/cextension.py: a
# `ct.cdll.LoadLibrary(str(binary_path))` call assigned to `self.lib`.
# NOTE(review): the leading whitespace of both patch strings must match the
# indentation of the line being replaced in the bitsandbytes 0.35.0 sources,
# otherwise the patched file will not parse -- verify against the installed files.
_CEXT_PATCH = " self.lib = ct.cdll.LoadLibrary(str(binary_path))"
# Replacement text that patch_main() writes over line 113 (0-based index 112) of
# bitsandbytes/cuda_setup/main.py: makes that line return the name
# 'libbitsandbytes_cuda116.dll'.
_MAIN_PATCH = " return 'libbitsandbytes_cuda116.dll'"
def patch_main(
    bnbpath_main="venv/Lib/site-packages/bitsandbytes/cuda_setup/main.py",
    patch_line=None,
):
    """
    Patch bitsandbytes' cuda_setup/main.py so it reports the Windows cuda116 dll.

    Line index 112 (the 113th line) of the file is replaced with ``patch_line``.
    The file is only rewritten when that line actually changes, so running the
    patch repeatedly leaves the file byte-identical.

    Args:
        bnbpath_main: path to bitsandbytes/cuda_setup/main.py. Defaults to the
            location inside the ``venv`` folder relative to the current directory.
        patch_line: replacement text for the line; ``None`` means ``_MAIN_PATCH``.

    Returns:
        True when the line is patched (or was already patched), False when the
        file cannot be read.

    Raises:
        AssertionError: the file has fewer than 113 lines, i.e. it is probably
            not bitsandbytes 0.35.0.
    """
    if patch_line is None:
        patch_line = _MAIN_PATCH
    target = 112  # 0-based index of the line to replace in bitsandbytes 0.35.0

    try:
        with open(bnbpath_main, "r", encoding="utf-8") as f:
            contents = f.read().split("\n")
    except (OSError, UnicodeError) as ex:
        print(f"cannot find bitsandbytes install, aborting, error: {ex}")
        return False

    if len(contents) <= target:
        # explicit raise rather than `assert` so the check survives `python -O`
        raise AssertionError("unable to patch bitsandbytes, may be mismatched version, requires 0.35.0")

    if contents[target] == patch_line:
        print(" *** Already patched!")
        return True

    contents[target] = patch_line
    # "\n".join mirrors split("\n") exactly; appending "\n" to every element
    # (including the trailing "" of a newline-terminated file) would add a blank
    # line to the end of the file on every run.
    with open(bnbpath_main, "w", encoding="utf-8") as f:
        f.write("\n".join(contents))
    return True
def patch_cext(
    bnbpath_cextension="venv/Lib/site-packages/bitsandbytes/cextension.py",
    patch_line=None,
):
    """
    Patch bitsandbytes/cextension.py so the correct dll is loaded.

    Lines 28 and 31 (0-based indices 27 and 30) are both replaced with
    ``patch_line`` -- by default ``_CEXT_PATCH``, a
    ``ct.cdll.LoadLibrary(str(binary_path))`` call. Indices beyond the end of the
    file are skipped. The file is only rewritten when something actually changes.

    Args:
        bnbpath_cextension: path to bitsandbytes/cextension.py. Defaults to the
            location inside the ``venv`` folder relative to the current directory.
        patch_line: replacement text; ``None`` means ``_CEXT_PATCH``.

    Returns:
        True when at least one target line exists and is now patched, False when
        the file cannot be read.

    Raises:
        AssertionError: none of the target lines exist in the file.
    """
    if patch_line is None:
        patch_line = _CEXT_PATCH

    try:
        with open(bnbpath_cextension, "r", encoding="utf-8") as f:
            contents = f.read().split("\n")
    except (OSError, UnicodeError) as ex:
        print(f"cannot find bitsandbytes install, aborting, error: {ex}")
        return False

    # update both lines 28 and 31 to be sure the correct dll is returned
    targets = [i for i in (27, 30) if i < len(contents)]
    if not targets:
        # explicit raise rather than `assert` so the check survives `python -O`
        raise AssertionError("unable to patch bitsandbytes, died midprocess, something broke and may need to reinstall bitsandbytes==0.35.0")

    changed = False
    for i in targets:
        if contents[i] != patch_line:
            contents[i] = patch_line
            changed = True

    if changed:
        # "\n".join mirrors split("\n"), so no extra blank line is appended per run
        with open(bnbpath_cextension, "w", encoding="utf-8") as f:
            f.write("\n".join(contents))
    return True
def iswindows():
    """Report whether the interpreter is running on Windows.

    Returns:
        True when ``sys.platform`` begins with "win", otherwise False.
    """
    platform_name = sys.platform
    return platform_name[:3] == "win"
def error():
    """
    Print troubleshooting hints for a failed bitsandbytes patch, then abort.

    Raises:
        RuntimeError: always.
    """
    print("Something went wrong trying to patch bitsandbytes, aborting")
    print("make sure your venv is activated and try again")
    print("or if activated try: ")
    print(" pip install bitsandbytes==0.35.0")
    raise RuntimeError("** FATAL ERROR: unable to patch bitsandbytes for Windows env")
def check_dlls(bnb_dir="venv/Lib/site-packages/bitsandbytes", cache_dir="tmp/bnb_cache"):
    """
    Ensure libbitsandbytes_cuda116.dll is present in the bitsandbytes package dir.

    If the dll is missing, the prebuilt Windows binaries are cloned from
    https://github.com/DeXtmL/bitsandbytes-win-prebuilt into ``cache_dir`` (the
    clone is skipped when ``cache_dir`` already exists), and the dll is copied
    into ``bnb_dir``.

    Args:
        bnb_dir: bitsandbytes package directory that should contain the dll.
        cache_dir: directory holding the cloned prebuilt repository.

    Returns:
        True if the dll exists in ``bnb_dir`` afterwards. False otherwise,
        including when the git clone or the copy fails, so the caller can report
        the failure.
    """
    from subprocess import CalledProcessError

    dll_name = "libbitsandbytes_cuda116.dll"
    dll_path = os.path.join(bnb_dir, dll_name)
    if os.path.exists(dll_path):
        return True

    try:
        if not os.path.exists(cache_dir):
            # argument list instead of a shell string: no shell parsing/quoting involved
            check_output(["git", "clone", "https://github.com/DeXtmL/bitsandbytes-win-prebuilt", cache_dir])
        shutil.copy(os.path.join(cache_dir, dll_name), dll_path)
    except (CalledProcessError, OSError) as ex:
        # git missing / clone failed / incomplete cache / missing bitsandbytes dir
        print(f"failed to fetch {dll_name}: {ex}")
        return False
    return os.path.exists(dll_path)
def main():
    """
    Apply a Windows compatibility patch for bitsandbytes 0.35.0 so its AdamW8bit
    optimizer can be used.

    On non-Windows platforms this only prints a message. On Windows it makes
    sure the cuda116 dll is present, patches cuda_setup/main.py and
    cextension.py, and then imports bitsandbytes as a smoke test.

    Raises:
        RuntimeError: the dll cannot be obtained, a patch step reports failure,
            or the patched package cannot be imported.
    """
    if not iswindows():
        print(" *** not using windows environment, skipping bitsandbytes patch ***")
        return

    print()
    print(" *** Applying bitsandbytes patch for windows ***")
    if not check_dlls():
        print("unable to find bitsandbytes dll or clone them from git, aborting")
        raise RuntimeError("** FATAL ERROR: unable to patch bitsandbytes for Windows env")

    main_patched = patch_main()
    cext_patched = patch_cext()
    if not (main_patched and cext_patched):
        error()  # always raises RuntimeError

    print(" *************************************************************")
    print(" *** bitsandbytes windows patch applied, attempting import *** ")
    try:
        import bitsandbytes  # noqa: F401 -- imported only to verify the patch works
    except Exception:
        # `Exception` instead of a bare `except:` so Ctrl+C / SystemExit are not
        # reported as a failed patch; error() always raises RuntimeError
        error()
    print(" *** bitsandbytes patch succeeded, everything looks good! ***")


if __name__ == "__main__":
    main()