| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| import os |
|
|
| import pytest |
| import torch |
| from accelerate.utils.imports import is_bf16_available |
| from torch import nn |
|
|
| from peft import PeftModel, ShiraConfig, get_peft_model |
|
|
|
|
def custom_random_mask_function_with_custom_kwargs(custom_arg):
    """Build a SHiRA-style mask function whose random seed is ``custom_arg``.

    The returned callable has the signature ``mask_fn(base_layer, r)`` and closes over ``custom_arg``.
    """

    def mask_fn(base_layer, r):
        """
        Example of a custom sparse mask that depends on custom kwargs: it mirrors the random_mask in
        src/peft/tuners/shira/mask_functions.py, except that the seed comes from ``custom_arg``. For a pretrained
        weight of shape (m, n), a mask function must return exactly one mask of shape (m, n) whose entries are
        binary (0 or 1), with num_shira_parameters = r(m+n) ones for linear layers. The mask must live on the same
        device and have the same dtype as the base layer's weight.
        """
        weight = base_layer.weight
        weight_shape = weight.shape
        num_selected = r * (weight_shape[0] + weight_shape[1])

        # A dedicated generator keeps the selection reproducible without touching torch's global RNG state.
        generator = torch.Generator()
        generator.manual_seed(custom_arg)

        # Draw a permutation of all flat positions and keep the first `num_selected` of them.
        permutation = torch.randperm(weight.numel(), generator=generator)
        selected = permutation[:num_selected].to(weight.device)

        # Scatter ones (in the weight's dtype) into a flat zero mask, then restore the weight's shape.
        ones = torch.ones_like(selected.type(weight.dtype))
        flat_mask = torch.zeros_like(weight.view(1, -1))
        flat_mask.scatter_(1, selected.unsqueeze(0), ones.unsqueeze(0))
        return flat_mask.view(weight_shape)

    return mask_fn
|
|
|
|
class MLP(nn.Module):
    """Small feed-forward network used as the base model in the SHiRA tests.

    Four linear layers (10 -> 20 -> 40 -> 30 -> 10) with a ReLU after each of the first three and a log-softmax
    over the last dimension at the end.
    """

    def __init__(self, bias=True):
        """Create the layers; ``bias`` is forwarded to every ``nn.Linear``."""
        super().__init__()
        # Submodules are registered in this exact order; the attribute names are what tests target.
        self.relu = nn.ReLU()
        self.lin0 = nn.Linear(in_features=10, out_features=20, bias=bias)
        self.lin1 = nn.Linear(in_features=20, out_features=40, bias=bias)
        self.lin2 = nn.Linear(in_features=40, out_features=30, bias=bias)
        self.lin3 = nn.Linear(in_features=30, out_features=10, bias=bias)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        """Run ``X`` through the four linear layers and return log-probabilities over the last dimension."""
        hidden = X
        # The first three linear layers are each followed by the shared ReLU.
        for linear in (self.lin0, self.lin1, self.lin2):
            hidden = self.relu(linear(hidden))
        return self.sm(self.lin3(hidden))
|
|
|
|
class TestShira:
    """Tests for SHiRA adapters applied to a small ``MLP``.

    Covers adapter tensor shapes, save/load round-trips (several adapters, a custom mask function, and the
    default random mask with an explicit seed), and dtype preservation of the forward pass.

    Test methods deliberately return nothing: pytest expects tests to return ``None``.
    """

    @pytest.fixture
    def mlp(self):
        """Build an ``MLP`` immediately after seeding torch's global RNG with 0."""
        torch.manual_seed(0)
        return MLP()

    def _check_single_adapter_save_load(self, mlp, config, tmp_path):
        """Attach ``config`` to ``mlp`` as adapter "first", save it, reload it onto a new ``MLP`` and compare.

        Shared body of the single-adapter save/load tests. The leading underscore keeps pytest from collecting
        it as a test.

        Args:
            mlp: base model to wrap with the adapter.
            config: ``ShiraConfig`` targeting ``lin1`` and ``lin2``.
            tmp_path: directory under which the adapter is saved (in a ``shira`` subdirectory).
        """
        peft_model = get_peft_model(mlp, config, adapter_name="first")

        # Keep references to the adapter tensors of the original model for the comparison after reloading.
        original_tensors = {}
        for layer_name in ("lin1", "lin2"):
            layer = getattr(peft_model.base_model.model, layer_name)
            original_tensors[layer_name] = (layer.shira_weight["first"], layer.shira_indices["first"])

        inputs = torch.randn(5, 10)
        peft_model.set_adapter("first")
        output_first = peft_model(inputs)

        save_path = os.path.join(tmp_path, "shira")
        peft_model.save_pretrained(save_path)
        assert os.path.exists(os.path.join(save_path, "first", "adapter_config.json"))
        del peft_model

        # Rebuild the base model under the same seed the `mlp` fixture uses, then load the saved adapter.
        torch.manual_seed(0)
        fresh_mlp = MLP()
        peft_model = PeftModel.from_pretrained(fresh_mlp, os.path.join(save_path, "first"), adapter_name="first")

        peft_model.set_adapter("first")
        output_first_loaded = peft_model(inputs)

        # The reloaded model must reproduce the output ...
        assert torch.allclose(output_first, output_first_loaded)

        # ... and carry exactly the same adapter values and indices.
        for layer_name, (weight, indices) in original_tensors.items():
            loaded_layer = getattr(peft_model.base_model.model, layer_name)
            assert torch.all(weight == loaded_layer.shira_weight["first"]), f"{layer_name}: weights differ"
            assert torch.all(indices == loaded_layer.shira_indices["first"]), f"{layer_name}: indices differ"

    def test_mlp_single_adapter_shapes(self, mlp):
        """Check the sizes of the SHiRA weight/index tensors and the shape of the delta weight."""
        r = 2
        config = ShiraConfig(r=r, target_modules=["lin1", "lin2"])
        peft_model = get_peft_model(mlp, config)

        for layer_name in ("lin1", "lin2"):
            layer = getattr(peft_model.base_model.model, layer_name)
            num_shira_weights = layer.shira_weight["default"].shape[0]
            num_shira_indices = layer.shira_indices["default"].shape[1]
            base_weight_shape = layer.base_layer.weight.shape
            delta_weight_shape = layer.get_delta_weight("default").shape

            # For a base weight of shape (m, n) the adapter holds r * (m + n) values.
            assert num_shira_weights == r * (base_weight_shape[0] + base_weight_shape[1]), layer_name
            # One index column per adapter value.
            assert num_shira_weights == num_shira_indices, layer_name
            # The delta weight has the shape of the base weight.
            assert delta_weight_shape == base_weight_shape, layer_name

    def test_multiple_adapters_save_load(self, mlp, tmp_path):
        """Save two adapters with different ranks/targets/seeds and check both survive a reload unchanged."""
        # Layers targeted by each adapter; "second" additionally targets lin3.
        adapter_layers = {
            "first": ("lin1", "lin2"),
            "second": ("lin1", "lin2", "lin3"),
        }

        config = ShiraConfig(r=2, target_modules=["lin1", "lin2"], random_seed=56)
        peft_model = get_peft_model(mlp, config, adapter_name="first")
        config2 = ShiraConfig(r=3, target_modules=["lin1", "lin2", "lin3"], random_seed=67)
        peft_model.add_adapter("second", config2)

        # With the default initialisation every adapter weight is expected to start at zero.
        for adapter_name, layer_names in adapter_layers.items():
            for layer_name in layer_names:
                layer = getattr(peft_model.base_model.model, layer_name)
                assert torch.all(layer.shira_weight[adapter_name] == 0), f"{adapter_name}/{layer_name}"

        # Overwrite the adapter weights with random values so the adapters actually change the output, and
        # remember the values and indices for the comparison after reloading.
        expected = {}
        for adapter_name, layer_names in adapter_layers.items():
            for layer_name in layer_names:
                layer = getattr(peft_model.base_model.model, layer_name)
                new_weight = torch.randn_like(layer.shira_weight[adapter_name])
                layer.shira_weight[adapter_name] = new_weight
                expected[adapter_name, layer_name] = (new_weight, layer.shira_indices[adapter_name])

        inputs = torch.randn(5, 10)
        peft_model.set_adapter("first")
        output_first = peft_model(inputs)
        peft_model.set_adapter("second")
        output_second = peft_model(inputs)

        # The adapters hold different random weights, so their outputs should not coincide.
        assert not torch.allclose(output_first, output_second, atol=1e-3, rtol=1e-3)

        save_path = os.path.join(tmp_path, "shira")
        peft_model.save_pretrained(save_path)
        assert os.path.exists(os.path.join(save_path, "first", "adapter_config.json"))
        assert os.path.exists(os.path.join(save_path, "second", "adapter_config.json"))
        del peft_model

        # Rebuild the base model under the same seed the `mlp` fixture uses, then load both adapters.
        torch.manual_seed(0)
        fresh_mlp = MLP()
        peft_model = PeftModel.from_pretrained(fresh_mlp, os.path.join(save_path, "first"), adapter_name="first")
        peft_model.load_adapter(os.path.join(save_path, "second"), "second")

        peft_model.set_adapter("first")
        output_first_loaded = peft_model(inputs)
        peft_model.set_adapter("second")
        output_second_loaded = peft_model(inputs)

        assert torch.allclose(output_first, output_first_loaded)
        assert torch.allclose(output_second, output_second_loaded)

        # Every adapter tensor must come back with exactly the values and indices it was saved with.
        for (adapter_name, layer_name), (weight, indices) in expected.items():
            loaded_layer = getattr(peft_model.base_model.model, layer_name)
            assert torch.all(weight == loaded_layer.shira_weight[adapter_name]), (
                f"{adapter_name}/{layer_name}: weights differ"
            )
            assert torch.all(indices == loaded_layer.shira_indices[adapter_name]), (
                f"{adapter_name}/{layer_name}: indices differ"
            )

    def test_save_load_custom_mask_function(self, mlp, tmp_path):
        """Round-trip an adapter whose mask comes from a user-supplied mask function."""
        # NOTE(review): init_weights=False presumably yields non-zero adapter weights so the round-trip is
        # meaningful — confirm against ShiraConfig.
        config = ShiraConfig(r=2, mask_type="custom", target_modules=["lin1", "lin2"], init_weights=False)
        custom_arg = 120
        # The custom mask function is attached to the config after construction.
        config.mask_fn = custom_random_mask_function_with_custom_kwargs(custom_arg)

        self._check_single_adapter_save_load(mlp, config, tmp_path)

    def test_save_load_default_random_mask_with_seed_function(self, mlp, tmp_path):
        """Round-trip an adapter that uses the default mask type with an explicit ``random_seed``."""
        # NOTE(review): init_weights=False presumably yields non-zero adapter weights so the round-trip is
        # meaningful — confirm against ShiraConfig.
        config = ShiraConfig(r=2, target_modules=["lin1", "lin2"], random_seed=567, init_weights=False)

        self._check_single_adapter_save_load(mlp, config, tmp_path)

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_shira_dtypes(self, dtype):
        """The forward output of a SHiRA-wrapped model keeps the dtype the base model was cast to."""
        if dtype == torch.bfloat16 and not is_bf16_available():
            pytest.skip("bfloat16 not supported on this system, skipping the test")

        model = MLP().to(dtype)
        config = ShiraConfig(r=2, target_modules=["lin1", "lin2"])
        peft_model = get_peft_model(model, config)
        inputs = torch.randn(5, 10).to(dtype)
        output = peft_model(inputs)
        assert output.dtype == dtype
|
|