Skip to content

Commit

Permalink
Repo sync (#694)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc committed May 18, 2024
1 parent 965e0ac commit 7c80eab
Show file tree
Hide file tree
Showing 17 changed files with 112 additions and 47 deletions.
12 changes: 6 additions & 6 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,21 @@ def _yacl():
http_archive,
name = "yacl",
urls = [
"https://github.com/secretflow/yacl/archive/refs/tags/0.4.4b3.tar.gz",
"https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b0.tar.gz",
],
strip_prefix = "yacl-0.4.4b3",
sha256 = "c6b5f32e92d2e31c1c5d7176792965fcf332d1ae892ab8b049d2e66f6f47e4f2",
strip_prefix = "yacl-0.4.5b0",
sha256 = "68d1dbeb255d404606d3ba9380b915fbbe3886cde575bbe89795657286742bd2",
)

def _libpsi():
maybe(
http_archive,
name = "psi",
urls = [
"https://github.com/secretflow/psi/archive/refs/tags/v0.4.0.dev240401.tar.gz",
"https://github.com/secretflow/psi/archive/refs/tags/v0.4.0.dev240517.tar.gz",
],
strip_prefix = "psi-0.4.0.dev240401",
sha256 = "bc91e5c635fc94f865004e61e3896eb334d76549c1125fbc98caf8c6b3a82463",
strip_prefix = "psi-0.4.0.dev240517",
sha256 = "43a475d44798d0a634f9cff2d2bd3a2c2c5f0f0dee34f01ac5de803f2a0de328",
)

def _rules_proto_grpc():
Expand Down
20 changes: 14 additions & 6 deletions libspu/kernel/hal/fxp_approx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ Value log_minmax_normalized(SPUContext* ctx, const Value& x) {
const auto k1 = constant(ctx, 1.0F, x.dtype(), x.shape());
auto xm1 = f_sub(ctx, x, k1);

return detail::polynomial(ctx, xm1, kLogCoefficient);
return detail::polynomial(ctx, xm1, kLogCoefficient, SignType::Positive,
SignType::Positive);
}

// Ref:
Expand Down Expand Up @@ -552,11 +553,15 @@ static Value rsqrt_init_guess(SPUContext* ctx, const Value& x, const Value& z) {
if (!ctx->config().enable_lower_accuracy_rsqrt()) {
auto coeffs = {0.0F, -15.47994394F, 38.4714796F, -49.86605845F,
26.02942339F};
r = f_add(ctx, detail::polynomial(ctx, u, coeffs),
r = f_add(ctx,
detail::polynomial(ctx, u, coeffs, SignType::Positive,
SignType::Positive),
constant(ctx, 4.14285016F, x.dtype(), x.shape()));
} else {
auto coeffs = {0.0F, -5.9417F, 4.7979F};
r = f_add(ctx, detail::polynomial(ctx, u, coeffs),
r = f_add(ctx,
detail::polynomial(ctx, u, coeffs, SignType::Positive,
SignType::Positive),
constant(ctx, 3.1855F, x.dtype(), x.shape()));
}

Expand Down Expand Up @@ -764,7 +769,8 @@ Value ErfImpl(SPUContext* ctx, const Value& x) {
0.078108};
auto one = constant(ctx, 1.0, x.dtype(), x.shape());

auto z = detail::polynomial(ctx, x, kErfCoefficient);
auto z = detail::polynomial(ctx, x, kErfCoefficient, SignType::Positive,
SignType::Positive);
z = f_square(ctx, z);
z = f_square(ctx, z);
z = detail::reciprocal_goldschmidt_positive(ctx, z);
Expand Down Expand Up @@ -816,9 +822,11 @@ Value AtanApproxLocal(SPUContext* ctx, const Value& x) {
-0.1337452245060563, 0.022023163399866309};

if (ctx->getFxpBits() <= 20) {
return detail::polynomial(ctx, x, kAtanCoefficientSmall);
return detail::polynomial(ctx, x, kAtanCoefficientSmall, SignType::Positive,
SignType::Positive);
} else {
return detail::polynomial(ctx, x, kAtanCoefficientLarge);
return detail::polynomial(ctx, x, kAtanCoefficientLarge, SignType::Positive,
SignType::Positive);
}
}

Expand Down
19 changes: 13 additions & 6 deletions libspu/kernel/hal/fxp_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ namespace detail {
// Calc:
// y = c0 + x*c1 + x^2*c2 + x^3*c3 + ... + x^n*c[n]
Value polynomial(SPUContext* ctx, const Value& x,
absl::Span<Value const> coeffs) {
absl::Span<Value const> coeffs, SignType sign_x,
SignType sign_ret) {
SPU_TRACE_HAL_DISP(ctx, x);
SPU_ENFORCE(x.isFxp());
SPU_ENFORCE(!coeffs.empty());
Expand All @@ -42,25 +43,31 @@ Value polynomial(SPUContext* ctx, const Value& x,
const auto fbits = ctx->getFxpBits();
for (size_t i = 1; i < coeffs.size(); i++) {
if ((i & 1) == 0U) {
// x^{even order} is always positive
x_pow = _trunc(ctx, _mul(ctx, x_pow, x), fbits, SignType::Positive);
} else {
// x^{even order} is always positive
x_pow = _trunc(ctx, _mul(ctx, x_pow, x), fbits);
if (i > 1) {
x_pow = _trunc(ctx, _mul(ctx, x_pow, x), fbits, sign_x);
} else {
// i=1, then save a _trunc
x_pow = x;
}
}
res = _add(ctx, res, _mul(ctx, x_pow, coeffs[i]));
}

return _trunc(ctx, res).setDtype(x.dtype());
return _trunc(ctx, res, fbits, sign_ret).setDtype(x.dtype());
}

Value polynomial(SPUContext* ctx, const Value& x,
absl::Span<float const> coeffs) {
absl::Span<float const> coeffs, SignType sign_x,
SignType sign_ret) {
std::vector<Value> cs;
cs.reserve(coeffs.size());
for (const auto& c : coeffs) {
cs.push_back(constant(ctx, c, x.dtype(), x.shape()));
}
return polynomial(ctx, x, cs);
return polynomial(ctx, x, cs, sign_x, sign_ret);
}

Value highestOneBit(SPUContext* ctx, const Value& x) {
Expand Down
8 changes: 6 additions & 2 deletions libspu/kernel/hal/fxp_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,14 @@ Value reciprocal_goldschmidt_positive(SPUContext* ctx, const Value& b_abs);
Value reciprocal_goldschmidt(SPUContext* ctx, const Value& b);

Value polynomial(SPUContext* ctx, const Value& x,
absl::Span<Value const> coeffs);
absl::Span<Value const> coeffs,
SignType sign_x = SignType::Unknown,
SignType sign_ret = SignType::Unknown);

Value polynomial(SPUContext* ctx, const Value& x,
absl::Span<float const> coeffs);
absl::Span<float const> coeffs,
SignType sign_x = SignType::Unknown,
SignType sign_ret = SignType::Unknown);

} // namespace detail

Expand Down
6 changes: 6 additions & 0 deletions libspu/mpc/ab_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,12 @@ TEST_P(ConversionTest, MSB) {

/* GIVEN */
auto p0 = rand_p(obj.get(), kShape);

// SECURENN has an msb input range here
if (conf.protocol() == ProtocolKind::SECURENN) {
p0 = arshift_p(obj.get(), p0, 1);
}

auto a0 = p2a(obj.get(), p0);

/* WHEN */
Expand Down
26 changes: 25 additions & 1 deletion libspu/mpc/api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,31 @@ TEST_BINARY_OP(xor)
TEST_UNARY_OP_P(OP)

TEST_UNARY_OP(not )
TEST_UNARY_OP(msb)
TEST_UNARY_OP_V(msb)
TEST_UNARY_OP_P(msb)

TEST_P(ApiTest, MsbS) {
const auto factory = std::get<0>(GetParam());
const RuntimeConfig& conf = std::get<1>(GetParam());
const size_t npc = std::get<2>(GetParam());

utils::simulate(npc, [&](const std::shared_ptr<yacl::link::Context>& lctx) {
auto sctx = factory(conf, lctx);

auto p0 = rand_p(sctx.get(), kShape);

// SECURENN has an msb input range requirement here
if (conf.protocol() == ProtocolKind::SECURENN) {
p0 = arshift_p(sctx.get(), p0, 1);
}

auto r_s = s2p(sctx.get(), msb_s(sctx.get(), p2s(sctx.get(), p0)));
auto r_p = msb_p(sctx.get(), p0);

/* THEN */
EXPECT_VALUE_EQ(r_s, r_p);
});
}

#define TEST_UNARY_OP_WITH_BIT_S(OP) \
TEST_P(ApiTest, OP##S) { \
Expand Down
8 changes: 4 additions & 4 deletions libspu/mpc/cheetah/ot/yacl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ spu_cc_library(
"@yacl//yacl/crypto/rand",
"@yacl//yacl/crypto/tools:crhash",
"@yacl//yacl/crypto/tools:rp",
"@yacl//yacl/kernels/algorithms:base_ot",
"@yacl//yacl/kernels/algorithms:ferret_ote",
"@yacl//yacl/kernels/algorithms:iknp_ote",
"@yacl//yacl/kernels/algorithms:softspoken_ote",
"@yacl//yacl/kernel/algorithms:base_ot",
"@yacl//yacl/kernel/algorithms:ferret_ote",
"@yacl//yacl/kernel/algorithms:iknp_ote",
"@yacl//yacl/kernel/algorithms:softspoken_ote",
"@yacl//yacl/link",
],
)
Expand Down
10 changes: 5 additions & 5 deletions libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

#include "yacl/base/dynamic_bitset.h"
#include "yacl/crypto/rand/rand.h"
#include "yacl/kernels/algorithms/base_ot.h"
#include "yacl/kernels/algorithms/ferret_ote.h"
#include "yacl/kernels/algorithms/iknp_ote.h"
#include "yacl/kernels/algorithms/ot_store.h"
#include "yacl/kernels/algorithms/softspoken_ote.h"
#include "yacl/kernel/algorithms/base_ot.h"
#include "yacl/kernel/algorithms/ferret_ote.h"
#include "yacl/kernel/algorithms/iknp_ote.h"
#include "yacl/kernel/algorithms/ot_store.h"
#include "yacl/kernel/algorithms/softspoken_ote.h"

#include "libspu/core/prelude.h"
#include "libspu/mpc/cheetah/ot/ot_util.h"
Expand Down
13 changes: 8 additions & 5 deletions libspu/mpc/securenn/conversion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,15 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx,
}

NdArrayRef Msb_a2b::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
// SC
// auto in_ = ring_add(in, in);
// auto in_ = in;
// auto in_ = ShareConvert().proc(ctx, in);
#ifndef OPT_SECURENN_MSB
// this is the default securenn implementation
auto in_ = ring_add(in, in);
in_ = ShareConvert().proc(ctx, in_);
auto res = Msb().proc(ctx, in_);
#else
// this is optimized but cannot calculate msb(-1) where all bits are 1
auto res = Msb_opt().proc(ctx, in);
// auto res = Msb().proc(ctx, in_);
#endif
res = A2B().proc(ctx, res);
return res;
}
Expand Down
19 changes: 16 additions & 3 deletions libspu/mpc/securenn/conversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

#pragma once

#include "libspu/mpc/kernel.h"

namespace spu::mpc::securenn {
Expand Down Expand Up @@ -76,20 +75,34 @@ class B2A_Randbit : public UnaryKernel {
class Msb_a2b : public UnaryKernel {
public:
static constexpr char kBindName[] = "msb_a2b";
// static constexpr char kBindName[] = "msb_a2b_nosc";

ce::CExpr latency() const override {
#ifndef OPT_SECURENN_MSB
return ce::Const(4) // share convert
+ ce::Const(5) // msb
+ Log(ce::K() + 1) // adder-circuit;
* Log(ce::N()); // tree-reduce parties;
#else
return ce::Const(5) // msb_a2a
+ (Log(ce::K()) + 1) // adder-circuit;
* Log(ce::N()); // tree-reduce parties.;
* Log(ce::N()); // tree-reduce parties;
#endif
}
ce::CExpr comm() const override {
const auto log_p =
9; // in fact, now the element is ring2k_t rather than [0, p-1]
#ifndef OPT_SECURENN_MSB
return (6 * ce::K() + 4 * log_p * ce::K()) // share convert
+ (13 * ce::K() + 4 * ce::K() * log_p) // msb
+ (2 * Log(ce::K()) + 1) // KS-adder-circuit
* 2 * ce::K() * (ce::N() - 1) // And gate, for nPC
* (ce::N() - 1); // (no-matter tree or ring) reduce
#else
return (9 * ce::K() + 3 * ce::K() * log_p) // msb_a2a
+ (2 * Log(ce::K()) + 1) // KS-adder-circuit
* 2 * ce::K() * (ce::N() - 1) // And gate, for nPC
* (ce::N() - 1); // (no-matter tree or ring) reduce
#endif
}

NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override;
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/beaver/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ spu_cc_library(
"//libspu/mpc/spdz2k/ot:tiny_ot",
"//libspu/mpc/utils:ring_ops",
"@yacl//yacl/crypto/tools:prg",
"@yacl//yacl/kernels/algorithms:ot_store",
"@yacl//yacl/kernel/algorithms:ot_store",
"@yacl//yacl/link",
"@yacl//yacl/utils:matrix_utils",
"@yacl//yacl/utils:serialize",
Expand Down
4 changes: 2 additions & 2 deletions libspu/mpc/spdz2k/beaver/beaver_tinyot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
#include "yacl/base/dynamic_bitset.h"
#include "yacl/crypto/rand/rand.h"
#include "yacl/crypto/tools/prg.h"
#include "yacl/kernels/algorithms/base_ot.h"
#include "yacl/kernels/algorithms/ot_store.h"
#include "yacl/kernel/algorithms/base_ot.h"
#include "yacl/kernel/algorithms/ot_store.h"
#include "yacl/utils/serialize.h"

#include "libspu/mpc/common/prg_tensor.h"
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/beaver/beaver_tinyot.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

#pragma once

#include "yacl/kernels/algorithms/ot_store.h"
#include "yacl/kernel/algorithms/ot_store.h"
#include "yacl/link/context.h"

#include "libspu/mpc/common/prg_state.h"
Expand Down
4 changes: 2 additions & 2 deletions libspu/mpc/spdz2k/ot/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ spu_cc_library(
"@yacl//yacl/crypto/tools:prg",
"@yacl//yacl/crypto/tools:ro",
"@yacl//yacl/crypto/tools:rp",
"@yacl//yacl/kernels/algorithms:base_ot",
"@yacl//yacl/kernel/algorithms:base_ot",
"@yacl//yacl/link",
"@yacl//yacl/utils:matrix_utils",
"@yacl//yacl/utils:serialize",
Expand All @@ -68,7 +68,7 @@ spu_cc_library(
"//libspu/mpc/utils:ring_ops",
"@com_github_emptoolkit_emp_tool//:emp-tool",
"@yacl//yacl/crypto/tools:prg",
"@yacl//yacl/kernels/algorithms:ot_store",
"@yacl//yacl/kernel/algorithms:ot_store",
"@yacl//yacl/link",
],
)
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/ot/kos_ote.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#pragma once
#include "absl/types/span.h"
#include "yacl/base/dynamic_bitset.h"
#include "yacl/kernels/algorithms/ot_store.h"
#include "yacl/kernel/algorithms/ot_store.h"
#include "yacl/link/link.h"
namespace spu::mpc::spdz2k {

Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/ot/tiny_ot.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.
#include <vector>

#include "yacl/kernels/algorithms/ot_store.h"
#include "yacl/kernel/algorithms/ot_store.h"

#include "libspu/mpc/common/communicator.h"

Expand Down
2 changes: 1 addition & 1 deletion spu/tests/jnp_testbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def post(x):
2,
float_dtypes,
all_shapes,
jtu.rand_small,
rand_default,
),
REC("arcsinh", 1, number_dtypes, all_shapes, jtu.rand_small),
REC(
Expand Down

0 comments on commit 7c80eab

Please sign in to comment.