diff --git a/rwkv/compare_with_reference_implementation.py b/rwkv/compare_with_reference_implementation.py
index 05c1a58..d45e2de 100644
--- a/rwkv/compare_with_reference_implementation.py
+++ b/rwkv/compare_with_reference_implementation.py
@@ -103,16 +103,16 @@ def main() -> None:
         threshold = 0.00002 if model_has_scaled_matrices else 0.000005
     elif data_type == 1:
         # FP16
-        threshold = 0.001 if model_has_scaled_matrices else 0.0032
+        threshold = 0.0022 if model_has_scaled_matrices else 0.0032
     elif data_type == 2:
         # Q4_0
-        threshold = 0.0054 if model_has_scaled_matrices else 0.4
+        threshold = 1.59 if model_has_scaled_matrices else 0.4
     elif data_type == 3:
         # Q4_1
-        threshold = 0.84 if model_has_scaled_matrices else 1.21
+        threshold = 0.74 if model_has_scaled_matrices else 1.21
     elif data_type == 4:
         # Q4_1_O
-        threshold = 0.32 if model_has_scaled_matrices else 0.2
+        threshold = 0.53 if model_has_scaled_matrices else 0.2
 
     model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path)
 
@@ -124,9 +124,6 @@ def compare_logits(tokens_subset: List[int]) -> None:
         for i in range(token_count):
             token: int = tokens_subset[i]
 
-            if token_count <= 4 or i % (token_count // 4) == 0:
-                print(f'{i + 1}/{token_count}')
-
             logits, state = model.eval(token, state, state, logits)
 
         actual_logits = logits
@@ -145,6 +142,8 @@ def compare_logits(tokens_subset: List[int]) -> None:
         difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()
 
+        torch.set_printoptions(sci_mode=False)
+
         print(f'Reference logits: {expected_logits}')
         print(f'Actual logits: {actual_logits}')
         print('Difference per token: %.8f' % (difference,))
 
diff --git a/rwkv/scale_columns.py b/rwkv/scale_columns.py
index e469dde..ea12462 100644
--- a/rwkv/scale_columns.py
+++ b/rwkv/scale_columns.py
@@ -41,10 +41,9 @@ def test_scale(max_by_column: np.ndarray, min_by_column: np.ndarray, normalized_
     actual_result: np.ndarray = (normalized_parameter @ in_by_max) + min_dot_in
     expected_result: np.ndarray = parameter @ random_input
 
-    diff_0: float = np.sum(actual_result - expected_result) / len(random_input)
-    diff_1: float = np.sum(expected_result - actual_result) / len(random_input)
+    diff: float = np.sum(actual_result - expected_result) / len(random_input)
 
-    assert diff_0 < 0.000001 or diff_1 < 0.000001, f'{diff_0}, {diff_1}'
+    assert abs(diff) < 0.000001, f'Difference {diff} is too big'
 
 def main() -> None:
     args = parse_args()
@@ -112,11 +111,18 @@ def main() -> None:
 
         if dim_count == 2 and key != 'emb.weight' and key != 'head.weight':
             # Scale
+            # Do all computation in float32 for better precision
+            dtype = parameter.dtype
+            parameter = np.float32(parameter)
+
             min_by_column: np.ndarray = np.amin(parameter, axis=0)
             normalized_parameter: np.ndarray = parameter - min_by_column
             max_by_column: np.ndarray = np.amax(normalized_parameter, axis=0)
-            normalized_parameter: np.ndarray = normalized_parameter / max_by_column
-            normalized_parameter.tofile(out_file)
+            normalized_parameter: np.ndarray = normalized_parameter / max_by_column * 2 - 1
+            normalized_parameter.astype(dtype).tofile(out_file)
+
+            min_by_column += max_by_column / 2
+            max_by_column /= 2
 
             test_scale(max_by_column, min_by_column, normalized_parameter, parameter)
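
For reference, below is a standalone NumPy sketch of the [-1, 1] column scaling that the scale_columns.py hunk introduces, together with the reconstruction identity that test_scale() relies on. The matrix shape, random seed, and tolerance are illustrative assumptions, not values from the patch.

import numpy as np

rng = np.random.default_rng(42)

# A stand-in weight matrix; shape and seed are arbitrary.
parameter = rng.standard_normal((8, 16)).astype(np.float32)

# Shift every column to start at 0, then remap it to [-1, 1],
# mirroring the patched scale_columns.py.
min_by_column = np.amin(parameter, axis=0)
normalized_parameter = parameter - min_by_column
max_by_column = np.amax(normalized_parameter, axis=0)
normalized_parameter = normalized_parameter / max_by_column * 2 - 1

# Adjust (min, max) so the per-column reconstruction still holds:
# ((N + 1) / 2) * max + min == N * (max / 2) + (min + max / 2),
# which is exactly why the patch does `min += max / 2; max /= 2`.
min_by_column += max_by_column / 2
max_by_column /= 2

# The identity asserted by test_scale():
# parameter @ x == normalized_parameter @ (x * max) + dot(min, x)
random_input = rng.standard_normal(16).astype(np.float32)
actual_result = normalized_parameter @ (random_input * max_by_column) + np.dot(min_by_column, random_input)
expected_result = parameter @ random_input

assert np.allclose(actual_result, expected_result, atol=1e-4)

Because W[i][j] == N[i][j] * max[j] + min[j] column by column, the matrix-vector product W @ x splits into N @ (x * max) plus the scalar dot(min, x), so the scaled matrix can be stored in [-1, 1] and the original product recovered exactly.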