Skip to content

Commit

Permalink
Scale columns to -1..1 range
Browse files Browse the repository at this point in the history
  • Loading branch information
saharNooby committed Apr 7, 2023
1 parent c73e984 commit b60c38e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 12 deletions.
13 changes: 6 additions & 7 deletions rwkv/compare_with_reference_implementation.py
Expand Up @@ -103,16 +103,16 @@ def main() -> None:
threshold = 0.00002 if model_has_scaled_matrices else 0.000005
elif data_type == 1:
# FP16
threshold = 0.001 if model_has_scaled_matrices else 0.0032
threshold = 0.0022 if model_has_scaled_matrices else 0.0032
elif data_type == 2:
# Q4_0
threshold = 0.0054 if model_has_scaled_matrices else 0.4
threshold = 1.59 if model_has_scaled_matrices else 0.4
elif data_type == 3:
# Q4_1
threshold = 0.84 if model_has_scaled_matrices else 1.21
threshold = 0.74 if model_has_scaled_matrices else 1.21
elif data_type == 4:
# Q4_1_O
threshold = 0.32 if model_has_scaled_matrices else 0.2
threshold = 0.53 if model_has_scaled_matrices else 0.2

model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path)

Expand All @@ -124,9 +124,6 @@ def compare_logits(tokens_subset: List[int]) -> None:
for i in range(token_count):
token: int = tokens_subset[i]

if token_count <= 4 or i % (token_count // 4) == 0:
print(f'{i + 1}/{token_count}')

logits, state = model.eval(token, state, state, logits)

actual_logits = logits
Expand All @@ -145,6 +142,8 @@ def compare_logits(tokens_subset: List[int]) -> None:

difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()

torch.set_printoptions(sci_mode=False)

print(f'Reference logits: {expected_logits}')
print(f'Actual logits: {actual_logits}')
print('Difference per token: %.8f' % (difference,))
Expand Down
16 changes: 11 additions & 5 deletions rwkv/scale_columns.py
Expand Up @@ -41,10 +41,9 @@ def test_scale(max_by_column: np.ndarray, min_by_column: np.ndarray, normalized_
actual_result: np.ndarray = (normalized_parameter @ in_by_max) + min_dot_in
expected_result: np.ndarray = parameter @ random_input

diff_0: float = np.sum(actual_result - expected_result) / len(random_input)
diff_1: float = np.sum(expected_result - actual_result) / len(random_input)
diff: float = np.sum(actual_result - expected_result) / len(random_input)

assert diff_0 < 0.000001 or diff_1 < 0.000001, f'{diff_0}, {diff_1}'
assert abs(diff) < 0.000001, f'Difference {diff} is too big'

def main() -> None:
args = parse_args()
Expand Down Expand Up @@ -112,11 +111,18 @@ def main() -> None:

if dim_count == 2 and key != 'emb.weight' and key != 'head.weight':
# Scale
# Do all computation in float32 for better precision
dtype = parameter.dtype
parameter = np.float32(parameter)

min_by_column: np.ndarray = np.amin(parameter, axis=0)
normalized_parameter: np.ndarray = parameter - min_by_column
max_by_column: np.ndarray = np.amax(normalized_parameter, axis=0)
normalized_parameter: np.ndarray = normalized_parameter / max_by_column
normalized_parameter.tofile(out_file)
normalized_parameter: np.ndarray = normalized_parameter / max_by_column * 2 - 1
normalized_parameter.astype(dtype).tofile(out_file)

min_by_column += max_by_column / 2
max_by_column /= 2

test_scale(max_by_column, min_by_column, normalized_parameter, parameter)

Expand Down

0 comments on commit b60c38e

Please sign in to comment.