{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Merge a LoRA adapter into a base diffusion checkpoint\n",
    "\n",
    "Loads an adapter stored as `lora_A` / `lora_B` weight pairs and a base safetensors checkpoint, adds `scale * (B @ A)` to every matching base weight, and writes the merged state dict to a new safetensors file.\n",
    "\n",
    "Run top to bottom on a fresh kernel. Everything a reader might tune lives in the configuration cell. Both files are read fully into CPU memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import math\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "from safetensors import safe_open\n",
    "from safetensors.torch import save_file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input / output locations. Edit these instead of the cells below.\n",
    "LORA_PATH = Path(\"/data2/bjh/diffusion-pipe/cosmos_test/20250327_02-37-25/epoch5/adapter_model.safetensors\")\n",
    "BASE_MODEL_PATH = Path(\"/data2/bjh/ComfyUI/models/diffusion_models/Cosmos-1_0-Diffusion-14B-Text2World.safetensors\")\n",
    "OUTPUT_PATH = Path(\"test.safetensors\")\n",
    "\n",
    "# Adapter keys are looked up as `diffusion_model.<module>.lora_A.weight` / `.lora_B.weight`.\n",
    "LORA_KEY_PREFIX = \"diffusion_model.\"\n",
    "LORA_A_SUFFIX = \".lora_A.weight\"\n",
    "LORA_B_SUFFIX = \".lora_B.weight\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_tensors(path):\n",
    "    \"\"\"Read every tensor of a safetensors file into a dict of CPU tensors.\"\"\"\n",
    "    with safe_open(str(path), framework=\"pt\", device=\"cpu\") as f:\n",
    "        return {key: f.get_tensor(key) for key in f.keys()}\n",
    "\n",
    "\n",
    "def read_adapter_config(lora_path):\n",
    "    \"\"\"Return the `adapter_config.json` stored next to the adapter, or `{}` if there is none.\"\"\"\n",
    "    config_path = Path(lora_path).with_name(\"adapter_config.json\")\n",
    "    if not config_path.is_file():\n",
    "        return {}\n",
    "    with open(config_path, encoding=\"utf-8\") as f:\n",
    "        return json.load(f)\n",
    "\n",
    "\n",
    "def lora_scale(adapter_config, rank):\n",
    "    \"\"\"Return the multiplier applied to `B @ A` for a module of the given rank.\n",
    "\n",
    "    Without a `lora_alpha` entry the scale is 1.0 (plain `W + B @ A`).\n",
    "    Otherwise it is `lora_alpha / rank`, or `lora_alpha / sqrt(rank)` when\n",
    "    `use_rslora` is set.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if the config carries a non-empty `alpha_pattern`\n",
    "            (per-module alphas), which this notebook does not handle.\n",
    "    \"\"\"\n",
    "    if adapter_config.get(\"alpha_pattern\"):\n",
    "        raise NotImplementedError(\"per-module alpha_pattern is not supported by this notebook\")\n",
    "    alpha = adapter_config.get(\"lora_alpha\")\n",
    "    if alpha is None:\n",
    "        return 1.0\n",
    "    if adapter_config.get(\"use_rslora\", False):\n",
    "        return alpha / math.sqrt(rank)\n",
    "    return alpha / rank\n",
    "\n",
    "\n",
    "def lora_module_names(lora):\n",
    "    \"\"\"Collect module paths present in the adapter.\n",
    "\n",
    "    A module path is the adapter key without its first component and without\n",
    "    its last two components (`lora_A.weight` / `lora_B.weight`).\n",
    "    \"\"\"\n",
    "    return {\".\".join(key.split(\".\")[1:-2]) for key in lora}\n",
    "\n",
    "\n",
    "def base_module_name(key):\n",
    "    \"\"\"Map a base checkpoint key to its module path, or `None` for non-weight tensors.\n",
    "\n",
    "    The first component of the key and the trailing `weight` are dropped.\n",
    "    Only `.weight` tensors qualify, so a `.bias` of a LoRA-wrapped module is\n",
    "    never touched.\n",
    "    \"\"\"\n",
    "    parts = key.split(\".\")\n",
    "    if parts[-1] != \"weight\":\n",
    "        return None\n",
    "    return \".\".join(parts[1:-1])\n",
    "\n",
    "\n",
    "def merge_lora_weight(weight, lora_a, lora_b, scale=1.0):\n",
    "    \"\"\"Return `weight + scale * (lora_b @ lora_a)` in the dtype of `weight`.\n",
    "\n",
    "    The product and the sum are computed in float32 and cast back, so the\n",
    "    merged tensor keeps the base checkpoint's dtype even when the adapter is\n",
    "    stored in a different one.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `lora_b @ lora_a` does not have the shape of `weight`.\n",
    "    \"\"\"\n",
    "    delta = torch.matmul(lora_b.to(torch.float32), lora_a.to(torch.float32))\n",
    "    if delta.shape != weight.shape:\n",
    "        raise ValueError(\n",
    "            f\"LoRA delta shape {tuple(delta.shape)} does not match weight shape {tuple(weight.shape)}\"\n",
    "        )\n",
    "    return (weight.to(torch.float32) + scale * delta).to(weight.dtype)\n",
    "\n",
    "\n",
    "def merge_lora(tensors, lora, adapter_config=None):\n",
    "    \"\"\"Merge every LoRA pair in `lora` into the matching weight of `tensors`.\n",
    "\n",
    "    Args:\n",
    "        tensors: base state dict (key -> tensor); it is not modified.\n",
    "        lora: adapter state dict with `lora_A.weight` / `lora_B.weight` keys.\n",
    "        adapter_config: optional adapter config dict used to derive the scale.\n",
    "\n",
    "    Returns:\n",
    "        `(merged, unmatched)`: the new state dict, and the sorted list of\n",
    "        adapter module paths for which no base weight was found.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if a matched module lacks its `lora_A` or `lora_B` tensor\n",
    "            under `LORA_KEY_PREFIX`.\n",
    "    \"\"\"\n",
    "    adapter_config = adapter_config or {}\n",
    "    modules = lora_module_names(lora)\n",
    "    merged = {}\n",
    "    applied = set()\n",
    "    for key, weight in tensors.items():\n",
    "        module = base_module_name(key)\n",
    "        if module is None or module not in modules:\n",
    "            # Untouched tensors are shared with the input dict, not copied.\n",
    "            merged[key] = weight\n",
    "            continue\n",
    "        lora_a = lora[LORA_KEY_PREFIX + module + LORA_A_SUFFIX]\n",
    "        lora_b = lora[LORA_KEY_PREFIX + module + LORA_B_SUFFIX]\n",
    "        # The rank is taken from the tensor itself: lora_A has one row per rank.\n",
    "        scale = lora_scale(adapter_config, rank=lora_a.shape[0])\n",
    "        merged[key] = merge_lora_weight(weight, lora_a, lora_b, scale)\n",
    "        applied.add(module)\n",
    "    return merged, sorted(modules - applied)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the adapter and the base checkpoint\n",
    "\n",
    "In the originally recorded run, `len(lora)` was 1152."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lora = load_tensors(LORA_PATH)\n",
    "adapter_config = read_adapter_config(LORA_PATH)\n",
    "len(lora)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tensors = load_tensors(BASE_MODEL_PATH)\n",
    "len(tensors)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Merge\n",
    "\n",
    "Every adapter module must find its base weight; a partial merge would otherwise go unnoticed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_dic, unmatched_modules = merge_lora(tensors, lora, adapter_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lora_modules = lora_module_names(lora)\n",
    "print(f\"merged {len(lora_modules) - len(unmatched_modules)} of {len(lora_modules)} LoRA modules\")\n",
    "assert len(new_dic) == len(tensors)\n",
    "assert not unmatched_modules, f\"LoRA modules without a base weight: {unmatched_modules[:5]}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_file(new_dic, str(OUTPUT_PATH))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dp",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}