Skip to content

Commit

Permalink
Fix570 (#571)
Browse files Browse the repository at this point in the history
* Doc: message BodySection and possible types.

* Handle DataList in GetEstimatedBodySize

* [#570] ConnectionFactory should fail connection as early as possible

* Build error
  • Loading branch information
xinchen10 committed Sep 8, 2023
1 parent ac53888 commit 772d0d3
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 30 deletions.
38 changes: 29 additions & 9 deletions src/Message.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ public class Message : IDisposable
public ApplicationProperties ApplicationProperties;

/// <summary>
/// The body section. The library supports one section only.
/// The body section. It can be one of the following,
/// * <see cref="AmqpValue"/>
/// * <see cref="AmqpSequence"/>
/// * <see cref="Data"/>
/// * <see cref="DataList"/>
/// </summary>
public RestrictedDescribed BodySection;

Expand Down Expand Up @@ -301,14 +305,7 @@ static int GetEstimatedBodySize(RestrictedDescribed body)
var data = body as Data;
if (data != null)
{
if (data.Buffer != null)
{
return data.Buffer.Length;
}
else
{
return data.Binary.Length;
}
return GetEstimatedDataSize(data);
}

var value = body as AmqpValue;
Expand All @@ -327,9 +324,32 @@ static int GetEstimatedBodySize(RestrictedDescribed body)
}
}

var dataList = body as DataList;
if (dataList != null)
{
int size = 0;
for (int i = 0; i < dataList.Count; i++)
{
size += GetEstimatedDataSize(dataList[i]);
}
return size;
}

return 64;
}

static int GetEstimatedDataSize(Data data)
{
if (data.Buffer != null)
{
return data.Buffer.Length;
}
else
{
return data.Binary.Length;
}
}

/// <summary>
/// Gets estimated message size in bytes.
/// </summary>
Expand Down
15 changes: 11 additions & 4 deletions src/Net/AsyncPump.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ public AsyncPump(IBufferManager bufferManager, IAsyncTransport transport)
this.transport = transport;
}

public void Start(Connection connection)
public void Start(Connection connection, Action<Exception> onException = null)
{
Task task = this.StartAsync(connection);
Task task = this.StartAsync(connection, onException);
}

public async Task PumpAsync(uint maxFrameSize, Func<ProtocolHeader, bool> onHeader, Func<ByteBuffer, bool> onBuffer)
Expand Down Expand Up @@ -85,7 +85,7 @@ public async Task PumpAsync(uint maxFrameSize, Func<ProtocolHeader, bool> onHead
}
}

async Task StartAsync(Connection connection)
async Task StartAsync(Connection connection, Action<Exception> onException)
{
try
{
Expand All @@ -94,12 +94,19 @@ async Task StartAsync(Connection connection)
catch (AmqpException amqpException)
{
connection.OnException(amqpException);
if (onException != null)
{
onException(amqpException);
}
}
catch (Exception exception)
{
connection.OnIoException(exception);
if (onException != null)
{
onException(exception);
}
}

}

async Task ReceiveBufferAsync(byte[] buffer, int offset, int count)
Expand Down
39 changes: 34 additions & 5 deletions src/Net/ConnectionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,9 @@ async Task<Connection> CreateAsync(Address address, Open open, OnOpened onOpened
}

IAsyncTransport transport = await this.CreateTransportAsync(address, saslProfile, handler).ConfigureAwait(false);
Connection connection = new Connection(this.BufferManager, this.AMQP, address, transport, open, onOpened, handler);

AsyncPump pump = new AsyncPump(this.BufferManager, transport);
pump.Start(connection);

return connection;
var tcs = new ConnectTaskCompletionSource(this, address, open, onOpened, handler, transport);
return await tcs.Task.ConfigureAwait(false);
}

/// <summary>
Expand Down Expand Up @@ -281,5 +278,37 @@ public SaslProfile Profile
set;
}
}

sealed class ConnectTaskCompletionSource : TaskCompletionSource<Connection>
{
readonly ConnectionFactory factory;
readonly OnOpened onOpened;
Connection connection;

public ConnectTaskCompletionSource(ConnectionFactory factory, Address address, Open open, OnOpened onOpened, IHandler handler, IAsyncTransport transport)
{
this.factory = factory;
this.onOpened = onOpened;

this.connection = new Connection(this.factory.BufferManager, this.factory.AMQP, address, transport, open, this.OnOpen, handler);
AsyncPump pump = new AsyncPump(this.factory.BufferManager, transport);
pump.Start(this.connection, this.OnException);
}

void OnOpen(IConnection connection, Open open)
{
if (this.onOpened != null)
{
this.onOpened(connection, open);
}

this.TrySetResult(this.connection);
}

void OnException(Exception exception)
{
this.TrySetException(exception);
}
}
}
}
26 changes: 16 additions & 10 deletions test/Common/ProtocolTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ public void ClosedEventOnTransportResetTest()
this.testListener.RegisterTarget(TestPoint.Begin, (stream, channel, fields) =>
{
stream.Dispose();
return TestOutcome.Continue;
return TestOutcome.Stop;
});

Trace.WriteLine(TraceLevel.Information, "sync test");
Expand All @@ -802,7 +802,8 @@ public void ClosedEventOnTransportResetTest()
Connection connection = await Connection.Factory.CreateAsync(this.address);
connection.Closed += (o, e) => closed.Set();
Session session = new Session(connection);
Assert.IsTrue(closed.WaitOne(5000), "closed event not fired");
await Task.Factory.StartNew(o => ((ManualResetEvent)o).WaitOne(5000), closed);
Assert.IsTrue(closed.WaitOne(10), "closed event not fired");
Assert.AreEqual(ErrorCode.ConnectionForced, (string)connection.Error.Condition);
}).Unwrap().GetAwaiter().GetResult();
}
Expand Down Expand Up @@ -1053,8 +1054,9 @@ public void ClosedCallbackGuaranteeTest()
{
this.testListener.RegisterTarget(TestPoint.Open, (stream, channel, fields) =>
{
stream.Dispose();
return TestOutcome.Continue;
TestListener.FRM(stream, 0x10UL, 0, 0, "Test"); // open
TestListener.FRM(stream, 0x18UL, 0, channel, new Error(ErrorCode.UnauthorizedAccess)); // close
return TestOutcome.Stop;
});

Trace.WriteLine(TraceLevel.Information, "sync test");
Expand All @@ -1063,7 +1065,7 @@ public void ClosedCallbackGuaranteeTest()
Connection connection = new Connection(this.address);
connection.AddClosedCallback((o, e) => closed.Set());
Assert.IsTrue(closed.WaitOne(5000), "closed event not fired");
Assert.AreEqual(ErrorCode.ConnectionForced, (string)connection.Error.Condition);
Assert.AreEqual(ErrorCode.UnauthorizedAccess, (string)connection.Error.Condition);
closed.Reset();
connection.AddClosedCallback((o, e) => closed.Set());
Assert.IsTrue(closed.WaitOne(5000), "closed event not fired again");
Expand All @@ -1075,8 +1077,9 @@ public void ClosedCallbackGuaranteeTest()
ManualResetEvent closed = new ManualResetEvent(false);
Connection connection = await Connection.Factory.CreateAsync(this.address);
connection.AddClosedCallback((o, e) => closed.Set());
Assert.IsTrue(closed.WaitOne(5000), "closed event not fired");
Assert.AreEqual(ErrorCode.ConnectionForced, (string)connection.Error.Condition);
await Task.Factory.StartNew(o => ((ManualResetEvent)o).WaitOne(5000), closed);
Assert.IsTrue(closed.WaitOne(10), "closed event not fired");
Assert.AreEqual(ErrorCode.UnauthorizedAccess, (string)connection.Error.Condition);
closed.Reset();
connection.AddClosedCallback((o, e) => closed.Set());
Assert.IsTrue(closed.WaitOne(5000), "closed event not fired again");
Expand Down Expand Up @@ -1455,7 +1458,8 @@ public void ReceiveWithNoCreditTest()
connection.Closed += (s, a) => closed.Set();
Session session = new Session(connection);
ReceiverLink receiver = new ReceiverLink(session, "receiver-" + testName, "any");
Assert.IsTrue(closed.WaitOne(5000), "Connection not closed");
await Task.Factory.StartNew(o => ((ManualResetEvent)o).WaitOne(5000), closed);
Assert.IsTrue(closed.WaitOne(10), "Connection not closed");
Assert.AreEqual(ErrorCode.TransferLimitExceeded, (string)connection.Error.Condition);
Assert.IsTrue(receiver.IsClosed);
}).Unwrap().GetAwaiter().GetResult();
Expand Down Expand Up @@ -1794,8 +1798,10 @@ public void ConnectionEventsOnProtocolError()
Connection connection = await Connection.Factory.CreateAsync(this.address);
connection.Closed += (o, e) => closedNotified.Set();
Session session = new Session(connection);
Assert.IsTrue(closeReceived.WaitOne(5000), "Close not received");
Assert.IsTrue(closedNotified.WaitOne(5000), "Closed event not fired");
await Task.Factory.StartNew(o => ((ManualResetEvent)o).WaitOne(5000), closeReceived);
await Task.Factory.StartNew(o => ((ManualResetEvent)o).WaitOne(5000), closedNotified);
Assert.IsTrue(closeReceived.WaitOne(10), "Close not received");
Assert.IsTrue(closedNotified.WaitOne(10), "Closed event not fired");
Assert.AreEqual(ErrorCode.NotFound, (string)connection.Error.Condition);
Assert.IsTrue(session.IsClosed);
Assert.IsTrue(connection.IsClosed);
Expand Down
38 changes: 36 additions & 2 deletions test/Test.Amqp.Net/TaskTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@
// ------------------------------------------------------------------------------------

using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Amqp;
using Amqp.Framing;
using Amqp.Types;
using System.Linq;
#if NETFX_CORE
using Microsoft.VisualStudio.TestPlatform.UnitTestFramework;
#else
using System.Net.Sockets;
using Microsoft.VisualStudio.TestTools.UnitTesting;
#endif

Expand Down Expand Up @@ -75,6 +76,22 @@ public async Task BasicSendReceiveAsync()
await connection.CloseAsync();
}

[TestMethod]
public async Task ConnectInvalidSASLAsync()
{
try
{
var address = new Address(this.testTarget.Address.Host, this.testTarget.Address.Port, this.testTarget.Address.User,
string.Empty, this.testTarget.Address.Path, this.testTarget.Address.Scheme);
await Connection.Factory.CreateAsync(address);
Assert.IsTrue(false, "expect AmqpException");
}
catch (AmqpException ex)
{
Trace.WriteLine(TraceLevel.Information, "exception: {0}", ex.Message);
}
}

[TestMethod]
public async Task InterfaceSendReceiveAsync()
{
Expand Down Expand Up @@ -243,6 +260,21 @@ async Task BasicSendReceiveAsyncTest()
#endif

#if NETFX && !NETFX40
[TestMethod]
public async Task ConnectInvalidAddressAsync()
{
try
{
var address = new Address("sth.invalid", 5672);
await Connection.Factory.CreateAsync(address);
Assert.IsTrue(false, "expect SocketException");
}
catch (SocketException ex)
{
Trace.WriteLine(TraceLevel.Information, "exception: {0}", ex.Message);
}
}

[TestMethod]
public async Task CustomMessageBody()
{
Expand Down Expand Up @@ -322,6 +354,8 @@ public async Task LargeMessageOnMessageCallback()

Connection connection = await Connection.Factory.CreateAsync(
this.testTarget.Address, new Open() { ContainerId = "c1", MaxFrameSize = 4096 }, null);
await Task.Yield();

Session session = new Session(connection);
SenderLink sender = new SenderLink(session, "sender-" + testName, testTarget.Path);

Expand All @@ -348,7 +382,7 @@ public async Task LargeMessageOnMessageCallback()
if (++count == nMsgs) done.Set();
});

Assert.IsTrue(done.WaitOne(120000));
Assert.IsTrue(done.WaitOne(10000));

connection.Close();
}
Expand Down

0 comments on commit 772d0d3

Please sign in to comment.