Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ object MultiPartPaymentLifecycle {
* @param assistedRoutes routing hints (usually from a Bolt 11 invoice).
* @param routeParams parameters to fine-tune the routing algorithm.
* @param additionalTlvs when provided, additional tlvs that will be added to the onion sent to the target node.
* @param userCustomTlvs when provided, additional user-defined custom tlvs that will be added to the onion sent to the target node.
*/
case class SendMultiPartPayment(paymentSecret: ByteVector32,
targetNodeId: PublicKey,
Expand All @@ -312,7 +313,8 @@ object MultiPartPaymentLifecycle {
maxAttempts: Int,
assistedRoutes: Seq[Seq[ExtraHop]] = Nil,
routeParams: Option[RouteParams] = None,
additionalTlvs: Seq[OnionTlv] = Nil) {
additionalTlvs: Seq[OnionTlv] = Nil,
userCustomTlvs: Seq[GenericTlv] = Nil) {
require(totalAmount > 0.msat, s"total amount must be > 0")
}

Expand Down Expand Up @@ -416,7 +418,7 @@ object MultiPartPaymentLifecycle {
private def createChildPayment(nodeParams: NodeParams, request: SendMultiPartPayment, childAmount: MilliSatoshi, channel: OutgoingChannel): SendPayment = {
SendPayment(
request.targetNodeId,
Onion.createMultiPartPayload(childAmount, request.totalAmount, request.targetExpiry, request.paymentSecret, request.additionalTlvs),
Onion.createMultiPartPayload(childAmount, request.totalAmount, request.targetExpiry, request.paymentSecret, request.additionalTlvs, request.userCustomTlvs),
request.maxAttempts,
request.assistedRoutes,
request.routeParams,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,14 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, relayer: ActorR
case Some(invoice) if invoice.features.allowMultiPart && Features.hasFeature(nodeParams.features, Features.BasicMultiPartPayment) =>
invoice.paymentSecret match {
case Some(paymentSecret) =>
spawnMultiPartPaymentFsm(paymentCfg) forward SendMultiPartPayment(paymentSecret, r.recipientNodeId, r.recipientAmount, finalExpiry, r.maxAttempts, r.assistedRoutes, r.routeParams)
spawnMultiPartPaymentFsm(paymentCfg) forward SendMultiPartPayment(paymentSecret, r.recipientNodeId, r.recipientAmount, finalExpiry, r.maxAttempts, r.assistedRoutes, r.routeParams, userCustomTlvs = r.userCustomTlvs)
case None =>
sender ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(PaymentSecretMissing) :: Nil)
}
case _ =>
// NB: we only generate legacy payment onions for now for maximum compatibility.
spawnPaymentFsm(paymentCfg) forward SendPayment(r.recipientNodeId, FinalLegacyPayload(r.recipientAmount, finalExpiry), r.maxAttempts, r.assistedRoutes, r.routeParams)
val paymentSecret = r.paymentRequest.flatMap(_.paymentSecret)
val finalPayload = Onion.createSinglePartPayload(r.recipientAmount, finalExpiry, paymentSecret, r.userCustomTlvs)
spawnPaymentFsm(paymentCfg) forward SendPayment(r.recipientNodeId, finalPayload, r.maxAttempts, r.assistedRoutes, r.routeParams)
}

case r: SendTrampolinePaymentRequest =>
Expand Down Expand Up @@ -201,6 +202,7 @@ object PaymentInitiator {
* @param externalId (optional) externally-controlled identifier (to reconcile between application DB and eclair DB).
* @param assistedRoutes (optional) routing hints (usually from a Bolt 11 invoice).
* @param routeParams (optional) parameters to fine-tune the routing algorithm.
* @param userCustomTlvs (optional) user-defined custom tlvs that will be added to the onion sent to the target node.
*/
case class SendPaymentRequest(recipientAmount: MilliSatoshi,
paymentHash: ByteVector32,
Expand All @@ -210,7 +212,8 @@ object PaymentInitiator {
paymentRequest: Option[PaymentRequest] = None,
externalId: Option[String] = None,
assistedRoutes: Seq[Seq[ExtraHop]] = Nil,
routeParams: Option[RouteParams] = None) {
routeParams: Option[RouteParams] = None,
userCustomTlvs: Seq[GenericTlv] = Nil) {
// We add one block in order to not have our htlcs fail when a new block has just been found.
def finalExpiry(currentBlockHeight: Long) = finalExpiryDelta.toCltvExpiry(currentBlockHeight + 1)
}
Expand Down
13 changes: 5 additions & 8 deletions eclair-core/src/main/scala/fr/acinq/eclair/wire/Onion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -276,17 +276,14 @@ object Onion {
NodeRelayPayload(TlvStream(tlvs2))
}

def createSinglePartPayload(amount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: Option[ByteVector32] = None): FinalPayload = paymentSecret match {
// We try to use the legacy format as much as possible for maximum compatibility, but when we have a payment secret we need to use TLV to include it.
case Some(paymentSecret) => FinalTlvPayload(TlvStream(AmountToForward(amount), OutgoingCltv(expiry), PaymentData(paymentSecret, amount)))
def createSinglePartPayload(amount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: Option[ByteVector32] = None, userCustomTlvs: Seq[GenericTlv] = Nil): FinalPayload = paymentSecret match {
case Some(paymentSecret) => FinalTlvPayload(TlvStream(Seq(AmountToForward(amount), OutgoingCltv(expiry), PaymentData(paymentSecret, amount)), userCustomTlvs))
case None if userCustomTlvs.nonEmpty => FinalTlvPayload(TlvStream(Seq(AmountToForward(amount), OutgoingCltv(expiry)), userCustomTlvs))
case None => FinalLegacyPayload(amount, expiry)
}

def createMultiPartPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32): FinalPayload =
FinalTlvPayload(TlvStream(AmountToForward(amount), OutgoingCltv(expiry), PaymentData(paymentSecret, totalAmount)))

def createMultiPartPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, additionalTlvs: Seq[OnionTlv]): FinalPayload =
FinalTlvPayload(TlvStream(AmountToForward(amount) +: OutgoingCltv(expiry) +: PaymentData(paymentSecret, totalAmount) +: additionalTlvs))
def createMultiPartPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, additionalTlvs: Seq[OnionTlv] = Nil, userCustomTlvs: Seq[GenericTlv] = Nil): FinalPayload =
FinalTlvPayload(TlvStream(AmountToForward(amount) +: OutgoingCltv(expiry) +: PaymentData(paymentSecret, totalAmount) +: additionalTlvs, userCustomTlvs))

/** Create a trampoline outer payload. */
def createTrampolinePayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, trampolinePacket: OnionRoutingPacket): FinalPayload = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ import scodec.{Attempt, Codec, Err}
/**
* Created by t-bast on 20/06/2019.
*/

object TlvCodecs {
// high range types are greater than or equal 2^16, see https://github.com/lightningnetwork/lightning-rfc/blob/master/01-messaging.md#type-length-value-format
private val TLV_TYPE_HIGH_RANGE = 65536

/**
* Truncated uint64 (0 to 8 bytes unsigned integer).
Expand Down Expand Up @@ -104,7 +105,7 @@ object TlvCodecs {
val ltu16: Codec[Int] = variableSizeBytes(uint8, tu16)

private def validateGenericTlv(g: GenericTlv): Attempt[GenericTlv] = {
if (g.tag.toBigInt % 2 == 0) {
if (g.tag < TLV_TYPE_HIGH_RANGE && g.tag.toBigInt % 2 == 0) {
Attempt.Failure(Err("unknown even tlv type"))
} else {
Attempt.Successful(g)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ import fr.acinq.eclair.payment.send.PaymentInitiator._
import fr.acinq.eclair.payment.send.PaymentLifecycle.{SendPayment, SendPaymentToRoute}
import fr.acinq.eclair.payment.send.{PaymentError, PaymentInitiator}
import fr.acinq.eclair.router.{NodeHop, RouteParams}
import fr.acinq.eclair.wire.Onion.FinalLegacyPayload
import fr.acinq.eclair.wire.{Onion, OnionCodecs, OnionTlv, TrampolineFeeInsufficient}
import fr.acinq.eclair.wire.Onion.{FinalLegacyPayload, FinalTlvPayload}
import fr.acinq.eclair.wire.OnionTlv.{AmountToForward, OutgoingCltv}
import fr.acinq.eclair.wire._
import fr.acinq.eclair.{CltvExpiryDelta, LongToBtcAmount, NodeParams, TestConstants, randomBytes32, randomKey}
import fr.acinq.eclair.UInt64.Conversions._
import org.scalatest.{Outcome, Tag, fixture}
import scodec.bits.HexStringSyntax

Expand Down Expand Up @@ -69,6 +71,19 @@ class PaymentInitiatorSpec extends TestKit(ActorSystem("test")) with fixture.Fun
withFixture(test.toNoArgTest(FixtureParam(nodeParams, initiator, payFsm, multiPartPayFsm, sender, eventListener)))
}

test("forward payment with user custom tlv records") { f =>
import f._
val keySendTlvRecords = Seq(GenericTlv(5482373484L, paymentPreimage))
val req = SendPaymentRequest(finalAmount, paymentHash, c, 1, CltvExpiryDelta(42), userCustomTlvs = keySendTlvRecords)
sender.send(initiator, req)
sender.expectMsgType[UUID]
payFsm.expectMsgType[SendPaymentConfig]
val FinalTlvPayload(tlvs) = payFsm.expectMsgType[SendPayment].finalPayload
assert(tlvs.get[AmountToForward].get.amount == finalAmount)
assert(tlvs.get[OutgoingCltv].get.cltv == req.finalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight + 1))
assert(tlvs.unknown == keySendTlvRecords)
}

test("reject payment with unknown mandatory feature") { f =>
import f._
val unknownFeature = 42
Expand Down Expand Up @@ -105,14 +120,14 @@ class PaymentInitiatorSpec extends TestKit(ActorSystem("test")) with fixture.Fun
payFsm.expectMsg(SendPayment(e, FinalLegacyPayload(finalAmount, Channel.MIN_CLTV_EXPIRY_DELTA.toCltvExpiry(nodeParams.currentBlockHeight + 1)), 3))
}

test("forward legacy payment when multi-part deactivated", Tag("mpp_disabled")) { f =>
test("forward single-part payment when multi-part deactivated", Tag("mpp_disabled")) { f =>
import f._
val pr = PaymentRequest(Block.LivenetGenesisBlock.hash, Some(finalAmount), paymentHash, randomKey, "Some MPP invoice", features = Some(Features(VariableLengthOnion.optional, PaymentSecret.optional, BasicMultiPartPayment.optional)))
val req = SendPaymentRequest(finalAmount, paymentHash, c, 1, CltvExpiryDelta(42), Some(pr))
sender.send(initiator, req)
val id = sender.expectMsgType[UUID]
payFsm.expectMsg(SendPaymentConfig(id, id, None, paymentHash, finalAmount, c, Upstream.Local(id), Some(pr), storeInDb = true, publishEvent = true, Nil))
payFsm.expectMsg(SendPayment(c, FinalLegacyPayload(finalAmount, req.finalExpiry(nodeParams.currentBlockHeight)), 1))
payFsm.expectMsg(SendPayment(c, FinalTlvPayload(TlvStream(OnionTlv.AmountToForward(finalAmount), OnionTlv.OutgoingCltv(req.finalExpiry(nodeParams.currentBlockHeight)), OnionTlv.PaymentData(pr.paymentSecret.get, finalAmount))), 1))
}

test("forward multi-part payment") { f =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,15 @@ class OnionCodecsSpec extends FunSuite {
}
}

test("encode/decode variable-length (tlv) final per-hop payload with custom user records") {
val tlvs = TlvStream[OnionTlv](Seq(AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42))), Seq(GenericTlv(5482373484L, hex"16c7ec71663784ff100b6eface1e60a97b92ea9d18b8ece5e558586bc7453828")))
val bin = hex"31 02020231 04012a ff0000000146c6616c2016c7ec71663784ff100b6eface1e60a97b92ea9d18b8ece5e558586bc7453828"

val encoded = finalPerHopPayloadCodec.encode(FinalTlvPayload(tlvs)).require.bytes
assert(encoded === bin)
assert(finalPerHopPayloadCodec.decode(bin.bits).require.value == FinalTlvPayload(tlvs))
}

test("decode multi-part final per-hop payload") {
val notMultiPart = finalPerHopPayloadCodec.decode(hex"07 02020231 04012a".bits).require.value
assert(notMultiPart.totalAmount === 561.msat)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,7 @@ class TlvCodecsSpec extends FunSuite {
hex"12 00",
hex"0a 00",
hex"fd0102 00",
hex"fe01000002 00",
hex"01020101 0a0101",
hex"ff0100000000000002 00",
// Invalid TestTlv1.
hex"01 01 00", // not minimally-encoded
hex"01 02 0001", // not minimally-encoded
Expand Down Expand Up @@ -308,6 +306,16 @@ class TlvCodecsSpec extends FunSuite {
assert(lengthPrefixedTestTlvStreamCodec.encode(stream).require.toByteVector === hex"0f 01012a 0b012b 0d012a fd00fe02002a")
}

test("encode/decode custom even tlv records") {
val lowRangeEven = TlvStream[TestTlv](records = Nil, unknown = Seq(GenericTlv(124, hex"2a")))
val highRangeEven = TlvStream[TestTlv](records = Nil, unknown = Seq(GenericTlv(67876545678L, hex"2b")))

assert(testTlvStreamCodec.encode(lowRangeEven).isFailure)
assert(testTlvStreamCodec.encode(highRangeEven).isSuccessful)
assert(testTlvStreamCodec.decode(hex"7c 01 2a".toBitVector).isFailure) // lowRangeEven
assert(testTlvStreamCodec.decode(testTlvStreamCodec.encode(highRangeEven).require).isSuccessful)
}

test("encode invalid tlv stream") {
val testCases = Seq(
// Unknown even type.
Expand Down