Skip to content

Commit

Permalink
Fix encoding polyfill issue around null pointers (npgsql#5453)
Browse files Browse the repository at this point in the history
Fixes npgsql#5446

Signed-off-by: monjowe <jonas.westman@monitor.se>
  • Loading branch information
NinoFloris authored and JonasWestman committed Dec 20, 2023
1 parent 0bb48cd commit 8adce6c
Showing 1 changed file with 31 additions and 13 deletions.
44 changes: 31 additions & 13 deletions src/Npgsql/Shims/EncodingExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// ReSharper disable RedundantUsingDirective
using System.Buffers;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
// ReSharper restore RedundantUsingDirective

Expand All @@ -10,70 +11,87 @@ namespace System.Text;
static class EncodingExtensions
{
#if NETSTANDARD2_0

/// <summary>
/// Returns a reference to the 0th element of the ReadOnlySpan. If the ReadOnlySpan is empty, returns a reference to fake non-null pointer. Such a reference
/// can be used for pinning but must never be dereferenced. This is useful for interop with methods that do not accept null pointers for zero-sized buffers.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static unsafe ref readonly T GetNonNullPinnableReference<T>(ReadOnlySpan<T> span)
=> ref span.Length != 0 ? ref span.GetPinnableReference() : ref Unsafe.AsRef<T>((void*)1);

/// <summary>
/// Returns a reference to the 0th element of the ReadOnlySpan. If the ReadOnlySpan is empty, returns a reference to fake non-null pointer. Such a reference
/// can be used for pinning but must never be dereferenced. This is useful for interop with methods that do not accept null pointers for zero-sized buffers.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static unsafe ref T GetNonNullPinnableReference<T>(Span<T> span)
=> ref span.Length != 0 ? ref span.GetPinnableReference() : ref Unsafe.AsRef<T>((void*)1);

public static unsafe int GetByteCount(this Encoding encoding, ReadOnlySpan<char> chars)
{
fixed (char* charsPtr = chars)
fixed (char* charsPtr = &GetNonNullPinnableReference(chars))
{
return encoding.GetByteCount(charsPtr, chars.Length);
}
}

public static unsafe int GetBytes(this Encoding encoding, ReadOnlySpan<char> chars, Span<byte> bytes)
{
fixed (char* charsPtr = chars)
fixed (byte* bytesPtr = bytes)
fixed (char* charsPtr = &GetNonNullPinnableReference(chars))
fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes))
{
return encoding.GetBytes(charsPtr, chars.Length, bytesPtr, bytes.Length);
}
}

public static unsafe int GetCharCount(this Encoding encoding, ReadOnlySpan<byte> bytes)
{
fixed (byte* bytesPtr = bytes)
fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes))
{
return encoding.GetCharCount(bytesPtr, bytes.Length);
}
}

public static unsafe int GetCharCount(this Decoder encoding, ReadOnlySpan<byte> bytes, bool flush)
{
fixed (byte* bytesPtr = bytes)
fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes))
{
return encoding.GetCharCount(bytesPtr, bytes.Length, flush);
}
}

public static unsafe int GetChars(this Decoder encoding, ReadOnlySpan<byte> bytes, Span<char> chars, bool flush)
{
fixed (byte* bytesPtr = bytes)
fixed (char* charsPtr = chars)
fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes))
fixed (char* charsPtr = &GetNonNullPinnableReference(chars))
{
return encoding.GetChars(bytesPtr, bytes.Length, charsPtr, chars.Length, flush);
}
}

public static unsafe int GetChars(this Encoding encoding, ReadOnlySpan<byte> bytes, Span<char> chars)
{
fixed (byte* bytesPtr = bytes)
fixed (char* charsPtr = chars)
fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes))
fixed (char* charsPtr = &GetNonNullPinnableReference(chars))
{
return encoding.GetChars(bytesPtr, bytes.Length, charsPtr, chars.Length);
}
}

public static unsafe void Convert(this Encoder encoder, ReadOnlySpan<char> chars, Span<byte> bytes, bool flush, out int charsUsed, out int bytesUsed, out bool completed)
{
fixed (char* charsPtr = chars)
fixed (byte* bytesPtr = bytes)
fixed (char* charsPtr = &GetNonNullPinnableReference(chars))
fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes))
{
encoder.Convert(charsPtr, chars.Length, bytesPtr, bytes.Length, flush, out charsUsed, out bytesUsed, out completed);
}
}

public static unsafe void Convert(this Decoder encoder, ReadOnlySpan<byte> bytes, Span<char> chars, bool flush, out int bytesUsed, out int charsUsed, out bool completed)
{
fixed (byte* bytesPtr = bytes)
fixed (char* charsPtr = chars)
fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes))
fixed (char* charsPtr = &GetNonNullPinnableReference(chars))
{
encoder.Convert(bytesPtr, bytes.Length, charsPtr, chars.Length, flush, out bytesUsed, out charsUsed, out completed);
}
Expand Down

0 comments on commit 8adce6c

Please sign in to comment.