Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix encoding polyfill issue around null pointers #5453

Merged
merged 1 commit into from
Nov 28, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
44 changes: 31 additions & 13 deletions src/Npgsql/Shims/EncodingExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// ReSharper disable RedundantUsingDirective
using System.Buffers;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
// ReSharper restore RedundantUsingDirective

Expand All @@ -10,70 +11,87 @@ namespace System.Text;
static class EncodingExtensions
{
#if NETSTANDARD2_0

/// <summary>
/// Returns a reference to the 0th element of the ReadOnlySpan. If the ReadOnlySpan is empty, returns a reference to fake non-null pointer. Such a reference
/// can be used for pinning but must never be dereferenced. This is useful for interop with methods that do not accept null pointers for zero-sized buffers.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static unsafe ref readonly T GetNonNullPinnableReference<T>(ReadOnlySpan<T> span)
=> ref span.Length != 0 ? ref span.GetPinnableReference() : ref Unsafe.AsRef<T>((void*)1);

/// <summary>
/// Returns a reference to the 0th element of the ReadOnlySpan. If the ReadOnlySpan is empty, returns a reference to fake non-null pointer. Such a reference
/// can be used for pinning but must never be dereferenced. This is useful for interop with methods that do not accept null pointers for zero-sized buffers.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static unsafe ref T GetNonNullPinnableReference<T>(Span<T> span)
=> ref span.Length != 0 ? ref span.GetPinnableReference() : ref Unsafe.AsRef<T>((void*)1);

public static unsafe int GetByteCount(this Encoding encoding, ReadOnlySpan<char> chars)
{
fixed (char* charsPtr = chars)
fixed (char* charsPtr = &GetNonNullPinnableReference(chars))
{
return encoding.GetByteCount(charsPtr, chars.Length);
}
}

public static unsafe int GetBytes(this Encoding encoding, ReadOnlySpan<char> chars, Span<byte> bytes)
{
fixed (char* charsPtr = chars)
fixed (byte* bytesPtr = bytes)
fixed (char* charsPtr = &GetNonNullPinnableReference(chars))
fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes))
{
return encoding.GetBytes(charsPtr, chars.Length, bytesPtr, bytes.Length);
}
}

public static unsafe int GetCharCount(this Encoding encoding, ReadOnlySpan<byte> bytes)
{
fixed (byte* bytesPtr = bytes)
fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes))
{
return encoding.GetCharCount(bytesPtr, bytes.Length);
}
}

public static unsafe int GetCharCount(this Decoder encoding, ReadOnlySpan<byte> bytes, bool flush)
{
fixed (byte* bytesPtr = bytes)
fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes))
{
return encoding.GetCharCount(bytesPtr, bytes.Length, flush);
}
}

public static unsafe int GetChars(this Decoder encoding, ReadOnlySpan<byte> bytes, Span<char> chars, bool flush)
{
fixed (byte* bytesPtr = bytes)
fixed (char* charsPtr = chars)
fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes))
fixed (char* charsPtr = &GetNonNullPinnableReference(chars))
{
return encoding.GetChars(bytesPtr, bytes.Length, charsPtr, chars.Length, flush);
}
}

public static unsafe int GetChars(this Encoding encoding, ReadOnlySpan<byte> bytes, Span<char> chars)
{
fixed (byte* bytesPtr = bytes)
fixed (char* charsPtr = chars)
fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes))
fixed (char* charsPtr = &GetNonNullPinnableReference(chars))
{
return encoding.GetChars(bytesPtr, bytes.Length, charsPtr, chars.Length);
}
}

public static unsafe void Convert(this Encoder encoder, ReadOnlySpan<char> chars, Span<byte> bytes, bool flush, out int charsUsed, out int bytesUsed, out bool completed)
{
fixed (char* charsPtr = chars)
fixed (byte* bytesPtr = bytes)
fixed (char* charsPtr = &GetNonNullPinnableReference(chars))
fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes))
{
encoder.Convert(charsPtr, chars.Length, bytesPtr, bytes.Length, flush, out charsUsed, out bytesUsed, out completed);
}
}

public static unsafe void Convert(this Decoder encoder, ReadOnlySpan<byte> bytes, Span<char> chars, bool flush, out int bytesUsed, out int charsUsed, out bool completed)
{
fixed (byte* bytesPtr = bytes)
fixed (char* charsPtr = chars)
fixed (byte* bytesPtr = &GetNonNullPinnableReference(bytes))
fixed (char* charsPtr = &GetNonNullPinnableReference(chars))
{
encoder.Convert(bytesPtr, bytes.Length, charsPtr, chars.Length, flush, out bytesUsed, out charsUsed, out completed);
}
Expand Down