diff --git a/jest.config.js b/jest.config.js index 153f417f..069bc8a9 100644 --- a/jest.config.js +++ b/jest.config.js @@ -6,10 +6,10 @@ module.exports = { coverageReporters: ['text', 'html'], coverageThreshold: { global: { - branches: 76.5, + branches: 75.52, functions: 92.5, - lines: 93.35, - statements: 93.35, + lines: 92.64, + statements: 92.65, }, }, moduleFileExtensions: ['js', 'json', 'jsx', 'ts', 'tsx', 'node'], diff --git a/src/SmartTransactionsController.test.ts b/src/SmartTransactionsController.test.ts index 8656e684..fa666d2b 100644 --- a/src/SmartTransactionsController.test.ts +++ b/src/SmartTransactionsController.test.ts @@ -1,22 +1,28 @@ +import { ControllerMessenger } from '@metamask/base-controller'; import { NetworkType, convertHexToDecimal, ChainId, } from '@metamask/controller-utils'; -import type { NetworkState } from '@metamask/network-controller'; -import { NetworkStatus } from '@metamask/network-controller'; +import { NetworkStatus, type NetworkState } from '@metamask/network-controller'; import { + type TransactionParams, TransactionStatus, TransactionType, } from '@metamask/transaction-controller'; import nock from 'nock'; import * as sinon from 'sinon'; -// eslint-disable-next-line @typescript-eslint/ban-ts-comment -// @ts-ignore import { API_BASE_URL } from './constants'; import SmartTransactionsController, { DEFAULT_INTERVAL, + getDefaultSmartTransactionsControllerState, +} from './SmartTransactionsController'; +import type { + AllowedActions, + AllowedEvents, + SmartTransactionsControllerActions, + SmartTransactionsControllerEvents, } from './SmartTransactionsController'; import { advanceTime, flushPromises, getFakeProvider } from './test-helpers'; import type { SmartTransaction, UnsignedTransaction, Hex } from './types'; @@ -192,17 +198,17 @@ const createSignedTransaction = () => { return '0xf86c098504a817c800825208943535353535353535353535353535353535353535880de0b6b3a76400008025a02b79f322a625d623a2bb2911e0c6b3e7eaf741a7c7c5d2e8c67ef3ff4acf146ca01ae168fea63dc3391b75b586c8a7c0cb55cdf3b8e2e4d8e097957a3a56c6f2c5'; }; -const createTxParams = () => { +const createTxParams = (): TransactionParams => { return { from: addressFrom, to: '0x0000000000000000000000000000000000000000', - value: 0, + value: '0', data: '0x', - nonce: 0, - type: 2, - chainId: 4, - maxFeePerGas: 2310003200, - maxPriorityFeePerGas: 513154852, + nonce: '0', + type: '2', + chainId: '0x4', + maxFeePerGas: '2310003200', + maxPriorityFeePerGas: '513154852', }; }; @@ -319,884 +325,917 @@ const ethereumChainIdDec = parseInt(ChainId.mainnet, 16); const sepoliaChainIdDec = parseInt(ChainId.sepolia, 16); const trackMetaMetricsEventSpy = jest.fn(); -const defaultState = { - smartTransactionsState: { - smartTransactions: { - [ChainId.mainnet]: [], - }, - userOptIn: undefined, - userOptInV2: undefined, - fees: { - approvalTxFees: undefined, - tradeTxFees: undefined, - }, - feesByChainId: { - [ChainId.mainnet]: { - approvalTxFees: undefined, - tradeTxFees: undefined, - }, - [ChainId.sepolia]: { - approvalTxFees: undefined, - tradeTxFees: undefined, - }, - }, - liveness: true, - livenessByChainId: { - [ChainId.mainnet]: true, - [ChainId.sepolia]: true, - }, - }, -}; - -const mockNetworkState = { - selectedNetworkClientId: NetworkType.mainnet, - networkConfigurations: { - id: { - id: 'id', - rpcUrl: 'string', - chainId: ChainId.mainnet, - ticker: 'string', - }, - }, - networksMetadata: { - id: { - EIPS: { - 1155: true, - }, - status: NetworkStatus.Available, - }, - }, -}; describe('SmartTransactionsController', () => { - let smartTransactionsController: SmartTransactionsController; - let networkListener: (networkState: NetworkState) => void; - - beforeEach(() => { - smartTransactionsController = new SmartTransactionsController({ - onNetworkStateChange: ( - listener: (networkState: NetworkState) => void, - ) => { - networkListener = listener; - }, - getNonceLock: jest.fn(() => { - return { - nextNonce: 'nextNonce', - releaseLock: jest.fn(), - }; - }), - confirmExternalTransaction: jest.fn(), - getTransactions: jest.fn(), - trackMetaMetricsEvent: trackMetaMetricsEventSpy, - getNetworkClientById: jest.fn().mockImplementation((networkClientId) => { - switch (networkClientId) { - case NetworkType.mainnet: - return { - configuration: { - chainId: ChainId.mainnet, - }, - provider: getFakeProvider(), - }; - case NetworkType.sepolia: - return { - configuration: { - chainId: ChainId.sepolia, - }, - provider: getFakeProvider(), - }; - default: - throw new Error('Invalid network client id'); - } - }), - getMetaMetricsProps: jest.fn(async () => { - return Promise.resolve({ - accountHardwareType: 'Ledger Hardware', - accountType: 'hardware', - deviceModel: 'ledger', - }); - }), - }); - // eslint-disable-next-line jest/prefer-spy-on - smartTransactionsController.subscribe = jest.fn(); - - networkListener(mockNetworkState); - }); - afterEach(async () => { jest.clearAllMocks(); nock.cleanAll(); - await smartTransactionsController.stop(); }); - it('initializes with default config', () => { - expect(smartTransactionsController.config).toStrictEqual({ - interval: DEFAULT_INTERVAL, - supportedChainIds: [ChainId.mainnet, ChainId.sepolia], - chainId: ChainId.mainnet, - clientId: 'default', + it('initializes with default state', async () => { + const defaultState = getDefaultSmartTransactionsControllerState(); + await withController(({ controller }) => { + expect(controller.state).toStrictEqual({ + ...defaultState, + smartTransactionsState: { + ...defaultState.smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: [], + }, + }, + }); }); }); - it('initializes with default state', () => { - expect(smartTransactionsController.state).toStrictEqual(defaultState); - }); - describe('onNetworkChange', () => { - it('is triggered', () => { - networkListener({ - selectedNetworkClientId: NetworkType.sepolia, - networkConfigurations: {}, - networksMetadata: {}, - } as NetworkState); - expect(smartTransactionsController.config.chainId).toBe(ChainId.sepolia); - }); + it('calls poll', async () => { + await withController(({ controller, triggerNetworStateChange }) => { + const checkPollSpy = jest.spyOn(controller, 'checkPoll'); - it('calls poll', () => { - const checkPollSpy = jest.spyOn(smartTransactionsController, 'checkPoll'); - networkListener({ - selectedNetworkClientId: NetworkType.sepolia, - networkConfigurations: {}, - networksMetadata: {}, - } as NetworkState); - expect(checkPollSpy).toHaveBeenCalled(); + triggerNetworStateChange({ + selectedNetworkClientId: NetworkType.sepolia, + networkConfigurations: {}, + networksMetadata: {}, + } as NetworkState); + + expect(checkPollSpy).toHaveBeenCalled(); + }); }); }); describe('checkPoll', () => { - it('calls poll if there is no pending transaction and pending transactions', () => { + it('calls poll if there is no pending transaction and pending transactions', async () => { const pollSpy = jest - .spyOn(smartTransactionsController, 'poll') + .spyOn(SmartTransactionsController.prototype, 'poll') .mockImplementation(async () => { return new Promise(() => ({})); }); - const { smartTransactionsState } = smartTransactionsController.state; + const { smartTransactionsState } = + getDefaultSmartTransactionsControllerState(); const pendingStx = createStateAfterPending(); - smartTransactionsController.update({ - smartTransactionsState: { - ...smartTransactionsState, - smartTransactions: { - [ChainId.mainnet]: pendingStx as SmartTransaction[], + await withController( + { + options: { + state: { + smartTransactionsState: { + ...smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: pendingStx as SmartTransaction[], + }, + }, + }, }, }, - }); - expect(pollSpy).toHaveBeenCalled(); + () => { + expect(pollSpy).toHaveBeenCalled(); + }, + ); }); - it('calls stop if there is a timeoutHandle and no pending transactions', () => { - const stopSpy = jest.spyOn(smartTransactionsController, 'stop'); - smartTransactionsController.timeoutHandle = setTimeout(() => ({})); - smartTransactionsController.checkPoll(smartTransactionsController.state); - expect(stopSpy).toHaveBeenCalled(); - clearInterval(smartTransactionsController.timeoutHandle); + it('calls stop if there is a timeoutHandle and no pending transactions', async () => { + await withController(({ controller }) => { + const stopSpy = jest.spyOn(controller, 'stop'); + controller.timeoutHandle = setTimeout(() => ({})); + + controller.checkPoll(controller.state); + + expect(stopSpy).toHaveBeenCalled(); + + clearInterval(controller.timeoutHandle); + }); }); }); describe('poll', () => { it('does not call updateSmartTransactions on unsupported networks', async () => { - const updateSmartTransactionsSpy = jest.spyOn( - smartTransactionsController, - 'updateSmartTransactions', + await withController( + { + options: { + supportedChainIds: [ChainId.mainnet], + }, + }, + ({ controller, triggerNetworStateChange }) => { + const updateSmartTransactionsSpy = jest.spyOn( + controller, + 'updateSmartTransactions', + ); + + expect(updateSmartTransactionsSpy).not.toHaveBeenCalled(); + + triggerNetworStateChange({ + selectedNetworkClientId: NetworkType.sepolia, + networkConfigurations: {}, + networksMetadata: {}, + } as NetworkState); + + expect(updateSmartTransactionsSpy).not.toHaveBeenCalled(); + }, ); - expect(updateSmartTransactionsSpy).not.toHaveBeenCalled(); - smartTransactionsController.config.supportedChainIds = [ChainId.mainnet]; - networkListener({ - selectedNetworkClientId: NetworkType.sepolia, - networkConfigurations: {}, - networksMetadata: {}, - } as NetworkState); - expect(updateSmartTransactionsSpy).not.toHaveBeenCalled(); }); }); describe('updateSmartTransactions', () => { // TODO rewrite this test... updateSmartTransactions is getting called via the checkPoll method which is called whenever state is updated. // this test should be more isolated to the updateSmartTransactions method. - it('calls fetchSmartTransactionsStatus if there are pending transactions', () => { + it('calls fetchSmartTransactionsStatus if there are pending transactions', async () => { const fetchSmartTransactionsStatusSpy = jest - .spyOn(smartTransactionsController, 'fetchSmartTransactionsStatus') + .spyOn( + SmartTransactionsController.prototype, + 'fetchSmartTransactionsStatus', + ) .mockImplementation(async () => { return new Promise(() => ({})); }); - const { smartTransactionsState } = smartTransactionsController.state; + const { smartTransactionsState } = + getDefaultSmartTransactionsControllerState(); const pendingStx = createStateAfterPending(); - smartTransactionsController.update({ - smartTransactionsState: { - ...smartTransactionsState, - smartTransactions: { - [ChainId.mainnet]: pendingStx as SmartTransaction[], + await withController( + { + options: { + state: { + smartTransactionsState: { + ...smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: pendingStx as SmartTransaction[], + }, + }, + }, }, }, - }); - expect(fetchSmartTransactionsStatusSpy).toHaveBeenCalled(); + () => { + expect(fetchSmartTransactionsStatusSpy).toHaveBeenCalled(); + }, + ); }); }); describe('trackStxStatusChange', () => { - it('tracks status change if prevSmartTransactions is undefined', () => { - const smartTransaction = { - ...createStateAfterPending()[0], - swapMetaData: {}, - }; - smartTransactionsController.trackStxStatusChange( - smartTransaction as SmartTransaction, - ); - expect(trackMetaMetricsEventSpy).toHaveBeenCalledWith( - expect.objectContaining({ - event: 'STX Status Updated', - category: 'Transactions', - properties: expect.objectContaining({ - stx_status: SmartTransactionStatuses.PENDING, - is_smart_transaction: true, + it('tracks status change if prevSmartTransactions is undefined', async () => { + await withController(({ controller }) => { + const smartTransaction = { + ...createStateAfterPending()[0], + swapMetaData: {}, + } as SmartTransaction; + + controller.trackStxStatusChange(smartTransaction); + + expect(trackMetaMetricsEventSpy).toHaveBeenCalledWith( + expect.objectContaining({ + event: 'STX Status Updated', + category: 'Transactions', + properties: expect.objectContaining({ + stx_status: SmartTransactionStatuses.PENDING, + is_smart_transaction: true, + }), + sensitiveProperties: expect.objectContaining({ + account_hardware_type: 'Ledger Hardware', + account_type: 'hardware', + device_model: 'ledger', + }), }), - sensitiveProperties: expect.objectContaining({ - account_hardware_type: 'Ledger Hardware', - account_type: 'hardware', - device_model: 'ledger', - }), - }), - ); + ); + }); }); - it('does not track if smartTransaction and prevSmartTransaction have the same status', () => { - const smartTransaction = createStateAfterPending()[0]; - smartTransactionsController.trackStxStatusChange( - smartTransaction as SmartTransaction, - smartTransaction as SmartTransaction, - ); - expect(trackMetaMetricsEventSpy).not.toHaveBeenCalled(); + it('does not track if smartTransaction and prevSmartTransaction have the same status', async () => { + await withController(({ controller }) => { + const smartTransaction = createStateAfterPending()[0]; + + controller.trackStxStatusChange( + smartTransaction as SmartTransaction, + smartTransaction as SmartTransaction, + ); + + expect(trackMetaMetricsEventSpy).not.toHaveBeenCalled(); + }); }); - it('tracks status change if smartTransaction and prevSmartTransaction have different statuses', () => { - const smartTransaction = { - ...createStateAfterSuccess()[0], - swapMetaData: {}, - }; - const prevSmartTransaction = { - ...smartTransaction, - status: SmartTransactionStatuses.PENDING, - }; - smartTransactionsController.trackStxStatusChange( - smartTransaction as SmartTransaction, - prevSmartTransaction as SmartTransaction, - ); - expect(trackMetaMetricsEventSpy).toHaveBeenCalledWith( - expect.objectContaining({ - event: 'STX Status Updated', - category: 'Transactions', - properties: expect.objectContaining({ - stx_status: SmartTransactionStatuses.SUCCESS, - is_smart_transaction: true, - }), - sensitiveProperties: expect.objectContaining({ - account_hardware_type: 'Ledger Hardware', - account_type: 'hardware', - device_model: 'ledger', + it('tracks status change if smartTransaction and prevSmartTransaction have different statuses', async () => { + await withController(({ controller }) => { + const smartTransaction = { + ...createStateAfterSuccess()[0], + swapMetaData: {}, + }; + const prevSmartTransaction = { + ...smartTransaction, + status: SmartTransactionStatuses.PENDING, + }; + + controller.trackStxStatusChange( + smartTransaction as SmartTransaction, + prevSmartTransaction as SmartTransaction, + ); + + expect(trackMetaMetricsEventSpy).toHaveBeenCalledWith( + expect.objectContaining({ + event: 'STX Status Updated', + category: 'Transactions', + properties: expect.objectContaining({ + stx_status: SmartTransactionStatuses.SUCCESS, + is_smart_transaction: true, + }), + sensitiveProperties: expect.objectContaining({ + account_hardware_type: 'Ledger Hardware', + account_type: 'hardware', + device_model: 'ledger', + }), }), - }), - ); + ); + }); }); }); describe('setOptInState', () => { - it('sets optIn state', () => { - smartTransactionsController.setOptInState(true); - expect( - smartTransactionsController.state.smartTransactionsState.userOptInV2, - ).toBe(true); - smartTransactionsController.setOptInState(false); - expect( - smartTransactionsController.state.smartTransactionsState.userOptInV2, - ).toBe(false); - smartTransactionsController.setOptInState(undefined); - expect( - smartTransactionsController.state.smartTransactionsState.userOptInV2, - ).toBeUndefined(); + it('sets optIn state', async () => { + await withController(({ controller }) => { + controller.setOptInState(true); + + expect(controller.state.smartTransactionsState.userOptInV2).toBe(true); + + controller.setOptInState(false); + + expect(controller.state.smartTransactionsState.userOptInV2).toBe(false); + + controller.setOptInState(null); + + expect(controller.state.smartTransactionsState.userOptInV2).toBeNull(); + }); }); }); describe('clearFees', () => { it('clears fees', async () => { - const tradeTx = createUnsignedTransaction(ethereumChainIdDec); - const approvalTx = createUnsignedTransaction(ethereumChainIdDec); - const getFeesApiResponse = createGetFeesApiResponse(); - nock(API_BASE_URL) - .post(`/networks/${ethereumChainIdDec}/getFees`) - .reply(200, getFeesApiResponse); - const fees = await smartTransactionsController.getFees( - tradeTx, - approvalTx, - ); - expect(fees).toMatchObject({ - approvalTxFees: getFeesApiResponse.txs[0], - tradeTxFees: getFeesApiResponse.txs[1], - }); - smartTransactionsController.clearFees(); - expect( - smartTransactionsController.state.smartTransactionsState.fees, - ).toStrictEqual({ - approvalTxFees: undefined, - tradeTxFees: undefined, + await withController(async ({ controller }) => { + const tradeTx = createUnsignedTransaction(ethereumChainIdDec); + const approvalTx = createUnsignedTransaction(ethereumChainIdDec); + const getFeesApiResponse = createGetFeesApiResponse(); + nock(API_BASE_URL) + .post(`/networks/${ethereumChainIdDec}/getFees`) + .reply(200, getFeesApiResponse); + + const fees = await controller.getFees(tradeTx, approvalTx); + + expect(fees).toMatchObject({ + approvalTxFees: getFeesApiResponse.txs[0], + tradeTxFees: getFeesApiResponse.txs[1], + }); + + controller.clearFees(); + + expect(controller.state.smartTransactionsState.fees).toStrictEqual({ + approvalTxFees: null, + tradeTxFees: null, + }); }); }); }); describe('getFees', () => { it('gets unsigned transactions and estimates based on an unsigned transaction', async () => { - const tradeTx = createUnsignedTransaction(ethereumChainIdDec); - const approvalTx = createUnsignedTransaction(ethereumChainIdDec); - const getFeesApiResponse = createGetFeesApiResponse(); - nock(API_BASE_URL) - .post(`/networks/${ethereumChainIdDec}/getFees`) - .reply(200, getFeesApiResponse); - const fees = await smartTransactionsController.getFees( - tradeTx, - approvalTx, - ); - expect(fees).toMatchObject({ - approvalTxFees: getFeesApiResponse.txs[0], - tradeTxFees: getFeesApiResponse.txs[1], + await withController(async ({ controller }) => { + const tradeTx = createUnsignedTransaction(ethereumChainIdDec); + const approvalTx = createUnsignedTransaction(ethereumChainIdDec); + const getFeesApiResponse = createGetFeesApiResponse(); + nock(API_BASE_URL) + .post(`/networks/${ethereumChainIdDec}/getFees`) + .reply(200, getFeesApiResponse); + + const fees = await controller.getFees(tradeTx, approvalTx); + + expect(fees).toMatchObject({ + approvalTxFees: getFeesApiResponse.txs[0], + tradeTxFees: getFeesApiResponse.txs[1], + }); }); }); it('gets estimates based on an unsigned transaction with an undefined nonce', async () => { - const tradeTx: UnsignedTransaction = - createUnsignedTransaction(ethereumChainIdDec); - tradeTx.nonce = undefined; - const getFeesApiResponse = createGetFeesApiResponse(); - nock(API_BASE_URL) - .post(`/networks/${ethereumChainIdDec}/getFees`) - .reply(200, getFeesApiResponse); - const fees = await smartTransactionsController.getFees(tradeTx); - expect(fees).toMatchObject({ - tradeTxFees: getFeesApiResponse.txs[0], + await withController(async ({ controller }) => { + const tradeTx: UnsignedTransaction = + createUnsignedTransaction(ethereumChainIdDec); + tradeTx.nonce = undefined; + const getFeesApiResponse = createGetFeesApiResponse(); + nock(API_BASE_URL) + .post(`/networks/${ethereumChainIdDec}/getFees`) + .reply(200, getFeesApiResponse); + + const fees = await controller.getFees(tradeTx); + + expect(fees).toMatchObject({ + tradeTxFees: getFeesApiResponse.txs[0], + }); }); }); it('should add fee data to feesByChainId state using the networkClientId passed in to identify the appropriate chain', async () => { - const tradeTx = createUnsignedTransaction(sepoliaChainIdDec); - const approvalTx = createUnsignedTransaction(sepoliaChainIdDec); - const getFeesApiResponse = createGetFeesApiResponse(); - nock(API_BASE_URL) - .post(`/networks/${sepoliaChainIdDec}/getFees`) - .reply(200, getFeesApiResponse); - - expect( - smartTransactionsController.state.smartTransactionsState.feesByChainId, - ).toStrictEqual(defaultState.smartTransactionsState.feesByChainId); - - await smartTransactionsController.getFees(tradeTx, approvalTx, { - networkClientId: NetworkType.sepolia, - }); + await withController(async ({ controller }) => { + const tradeTx = createUnsignedTransaction(sepoliaChainIdDec); + const approvalTx = createUnsignedTransaction(sepoliaChainIdDec); + const getFeesApiResponse = createGetFeesApiResponse(); + nock(API_BASE_URL) + .post(`/networks/${sepoliaChainIdDec}/getFees`) + .reply(200, getFeesApiResponse); - expect( - smartTransactionsController.state.smartTransactionsState.feesByChainId, - ).toMatchObject({ - [ChainId.mainnet]: { - approvalTxFees: undefined, - tradeTxFees: undefined, - }, - [ChainId.sepolia]: { - approvalTxFees: getFeesApiResponse.txs[0], - tradeTxFees: getFeesApiResponse.txs[1], - }, + expect( + controller.state.smartTransactionsState.feesByChainId, + ).toStrictEqual( + getDefaultSmartTransactionsControllerState().smartTransactionsState + .feesByChainId, + ); + + await controller.getFees(tradeTx, approvalTx, { + networkClientId: NetworkType.sepolia, + }); + + expect( + controller.state.smartTransactionsState.feesByChainId, + ).toMatchObject({ + [ChainId.mainnet]: { + approvalTxFees: null, + tradeTxFees: null, + }, + [ChainId.sepolia]: { + approvalTxFees: getFeesApiResponse.txs[0], + tradeTxFees: getFeesApiResponse.txs[1], + }, + }); }); }); }); describe('submitSignedTransactions', () => { beforeEach(() => { - // eslint-disable-next-line jest/prefer-spy-on - smartTransactionsController.checkPoll = jest.fn(() => ({})); + jest + .spyOn(SmartTransactionsController.prototype, 'checkPoll') + .mockImplementation(() => ({})); }); it('submits a smart transaction with signed transactions', async () => { - const signedTransaction = createSignedTransaction(); - const signedCanceledTransaction = createSignedCanceledTransaction(); - const submitTransactionsApiResponse = - createSubmitTransactionsApiResponse(); // It has uuid. - nock(API_BASE_URL) - .post( - `/networks/${ethereumChainIdDec}/submitTransactions?stxControllerVersion=${packageJson.version}`, - ) - .reply(200, submitTransactionsApiResponse); - await smartTransactionsController.submitSignedTransactions({ - signedTransactions: [signedTransaction], - signedCanceledTransactions: [signedCanceledTransaction], - txParams: createTxParams(), + await withController(async ({ controller }) => { + const signedTransaction = createSignedTransaction(); + const signedCanceledTransaction = createSignedCanceledTransaction(); + const submitTransactionsApiResponse = + createSubmitTransactionsApiResponse(); // It has uuid. + nock(API_BASE_URL) + .post( + `/networks/${ethereumChainIdDec}/submitTransactions?stxControllerVersion=${packageJson.version}`, + ) + .reply(200, submitTransactionsApiResponse); + + await controller.submitSignedTransactions({ + signedTransactions: [signedTransaction], + signedCanceledTransactions: [signedCanceledTransaction], + txParams: createTxParams(), + }); + + const submittedSmartTransaction = + controller.state.smartTransactionsState.smartTransactions[ + ChainId.mainnet + ][0]; + expect(submittedSmartTransaction.uuid).toBe( + 'dP23W7c2kt4FK9TmXOkz1UM2F20', + ); + expect(submittedSmartTransaction.accountHardwareType).toBe( + 'Ledger Hardware', + ); + expect(submittedSmartTransaction.accountType).toBe('hardware'); + expect(submittedSmartTransaction.deviceModel).toBe('ledger'); }); - const submittedSmartTransaction = - smartTransactionsController.state.smartTransactionsState - .smartTransactions[ChainId.mainnet][0]; - expect(submittedSmartTransaction.uuid).toBe( - 'dP23W7c2kt4FK9TmXOkz1UM2F20', - ); - expect(submittedSmartTransaction.accountHardwareType).toBe( - 'Ledger Hardware', - ); - expect(submittedSmartTransaction.accountType).toBe('hardware'); - expect(submittedSmartTransaction.deviceModel).toBe('ledger'); }); }); describe('fetchSmartTransactionsStatus', () => { beforeEach(() => { - // eslint-disable-next-line jest/prefer-spy-on - smartTransactionsController.checkPoll = jest.fn(() => ({})); + jest + .spyOn(SmartTransactionsController.prototype, 'checkPoll') + .mockImplementation(() => ({})); }); it('fetches a pending status for a single smart transaction via batchStatus API', async () => { - const uuids = ['uuid1']; - const pendingBatchStatusApiResponse = - createPendingBatchStatusApiResponse(); - nock(API_BASE_URL) - .get(`/networks/${ethereumChainIdDec}/batchStatus?uuids=uuid1`) - .reply(200, pendingBatchStatusApiResponse); - - await smartTransactionsController.fetchSmartTransactionsStatus(uuids, { - networkClientId: NetworkType.mainnet, - }); - const pendingState = createStateAfterPending()[0]; - const pendingTransaction = { ...pendingState, history: [pendingState] }; - expect(smartTransactionsController.state).toMatchObject({ - smartTransactionsState: { - smartTransactions: { - [ChainId.mainnet]: [pendingTransaction], - }, - userOptIn: undefined, - userOptInV2: undefined, - fees: { - approvalTxFees: undefined, - tradeTxFees: undefined, - }, - feesByChainId: { - [ChainId.mainnet]: { - approvalTxFees: undefined, - tradeTxFees: undefined, + await withController(async ({ controller }) => { + const uuids = ['uuid1']; + const pendingBatchStatusApiResponse = + createPendingBatchStatusApiResponse(); + nock(API_BASE_URL) + .get(`/networks/${ethereumChainIdDec}/batchStatus?uuids=uuid1`) + .reply(200, pendingBatchStatusApiResponse); + + await controller.fetchSmartTransactionsStatus(uuids, { + networkClientId: NetworkType.mainnet, + }); + + const pendingState = createStateAfterPending()[0]; + const pendingTransaction = { ...pendingState, history: [pendingState] }; + expect(controller.state).toMatchObject({ + smartTransactionsState: { + smartTransactions: { + [ChainId.mainnet]: [pendingTransaction], }, - [ChainId.sepolia]: { - approvalTxFees: undefined, - tradeTxFees: undefined, + userOptIn: null, + userOptInV2: null, + fees: { + approvalTxFees: null, + tradeTxFees: null, + }, + feesByChainId: { + [ChainId.mainnet]: { + approvalTxFees: null, + tradeTxFees: null, + }, + [ChainId.sepolia]: { + approvalTxFees: null, + tradeTxFees: null, + }, + }, + liveness: true, + livenessByChainId: { + [ChainId.mainnet]: true, + [ChainId.sepolia]: true, }, }, - liveness: true, - livenessByChainId: { - [ChainId.mainnet]: true, - [ChainId.sepolia]: true, - }, - }, + }); }); }); it('fetches a success status for a single smart transaction via batchStatus API', async () => { - const uuids = ['uuid2']; - const successBatchStatusApiResponse = - createSuccessBatchStatusApiResponse(); - smartTransactionsController.update({ - smartTransactionsState: { - ...smartTransactionsController.state.smartTransactionsState, - smartTransactions: { - [ChainId.mainnet]: createStateAfterPending() as SmartTransaction[], + await withController( + { + options: { + state: { + smartTransactionsState: { + ...getDefaultSmartTransactionsControllerState() + .smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: + createStateAfterPending() as SmartTransaction[], + }, + }, + }, }, }, - }); - - nock(API_BASE_URL) - .get(`/networks/${ethereumChainIdDec}/batchStatus?uuids=uuid2`) - .reply(200, successBatchStatusApiResponse); + async ({ controller }) => { + const uuids = ['uuid2']; + const successBatchStatusApiResponse = + createSuccessBatchStatusApiResponse(); + nock(API_BASE_URL) + .get(`/networks/${ethereumChainIdDec}/batchStatus?uuids=uuid2`) + .reply(200, successBatchStatusApiResponse); + + await controller.fetchSmartTransactionsStatus(uuids, { + networkClientId: NetworkType.mainnet, + }); - await smartTransactionsController.fetchSmartTransactionsStatus(uuids, { - networkClientId: NetworkType.mainnet, - }); - const successState = createStateAfterSuccess()[0]; - const successTransaction = { ...successState, history: [successState] }; - expect(smartTransactionsController.state).toMatchObject({ - smartTransactionsState: { - smartTransactions: { - [ChainId.mainnet]: [ - ...createStateAfterPending(), - ...[successTransaction], - ], - }, - userOptIn: undefined, - userOptInV2: undefined, - fees: { - approvalTxFees: undefined, - tradeTxFees: undefined, - }, - liveness: true, - feesByChainId: { - [ChainId.mainnet]: { - approvalTxFees: undefined, - tradeTxFees: undefined, - }, - [ChainId.sepolia]: { - approvalTxFees: undefined, - tradeTxFees: undefined, + const [successState] = createStateAfterSuccess(); + const successTransaction = { + ...successState, + history: [successState], + }; + expect(controller.state).toMatchObject({ + smartTransactionsState: { + smartTransactions: { + [ChainId.mainnet]: [ + ...createStateAfterPending(), + ...[successTransaction], + ], + }, + userOptIn: null, + userOptInV2: null, + fees: { + approvalTxFees: null, + tradeTxFees: null, + }, + liveness: true, + feesByChainId: { + [ChainId.mainnet]: { + approvalTxFees: null, + tradeTxFees: null, + }, + [ChainId.sepolia]: { + approvalTxFees: null, + tradeTxFees: null, + }, + }, + livenessByChainId: { + [ChainId.mainnet]: true, + [ChainId.sepolia]: true, + }, }, - }, - livenessByChainId: { - [ChainId.mainnet]: true, - [ChainId.sepolia]: true, - }, + }); }, - }); + ); }); }); describe('fetchLiveness', () => { it('fetches a liveness for Smart Transactions API', async () => { - const successLivenessApiResponse = createSuccessLivenessApiResponse(); - nock(API_BASE_URL) - .get(`/networks/${ethereumChainIdDec}/health`) - .reply(200, successLivenessApiResponse); - const liveness = await smartTransactionsController.fetchLiveness(); - expect(liveness).toBe(true); + await withController(async ({ controller }) => { + const successLivenessApiResponse = createSuccessLivenessApiResponse(); + nock(API_BASE_URL) + .get(`/networks/${ethereumChainIdDec}/health`) + .reply(200, successLivenessApiResponse); + + const liveness = await controller.fetchLiveness(); + + expect(liveness).toBe(true); + }); }); it('fetches liveness and sets in feesByChainId state for the Smart Transactions API for the chainId of the networkClientId passed in', async () => { - nock(API_BASE_URL) - .get(`/networks/${sepoliaChainIdDec}/health`) - .replyWithError('random error'); - - expect( - smartTransactionsController.state.smartTransactionsState - .livenessByChainId, - ).toStrictEqual({ - [ChainId.mainnet]: true, - [ChainId.sepolia]: true, - }); + await withController(async ({ controller }) => { + nock(API_BASE_URL) + .get(`/networks/${sepoliaChainIdDec}/health`) + .replyWithError('random error'); - await smartTransactionsController.fetchLiveness({ - networkClientId: NetworkType.sepolia, - }); + expect( + controller.state.smartTransactionsState.livenessByChainId, + ).toStrictEqual({ + [ChainId.mainnet]: true, + [ChainId.sepolia]: true, + }); + + await controller.fetchLiveness({ + networkClientId: NetworkType.sepolia, + }); - expect( - smartTransactionsController.state.smartTransactionsState - .livenessByChainId, - ).toStrictEqual({ - [ChainId.mainnet]: true, - [ChainId.sepolia]: false, + expect( + controller.state.smartTransactionsState.livenessByChainId, + ).toStrictEqual({ + [ChainId.mainnet]: true, + [ChainId.sepolia]: false, + }); }); }); }); describe('updateSmartTransaction', () => { beforeEach(() => { - // eslint-disable-next-line jest/prefer-spy-on - smartTransactionsController.checkPoll = jest.fn(() => ({})); + jest + .spyOn(SmartTransactionsController.prototype, 'checkPoll') + .mockImplementation(() => ({})); }); - it('updates smart transaction based on uuid', () => { + it('updates smart transaction based on uuid', async () => { + const { smartTransactionsState } = + getDefaultSmartTransactionsControllerState(); const pendingStx = { ...createStateAfterPending()[0], history: testHistory, }; - const { smartTransactionsState } = smartTransactionsController.state; - smartTransactionsController.update({ - smartTransactionsState: { - ...smartTransactionsState, - smartTransactions: { - [ChainId.mainnet]: [pendingStx] as SmartTransaction[], + await withController( + { + options: { + state: { + smartTransactionsState: { + ...smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: [pendingStx] as SmartTransaction[], + }, + }, + }, }, }, - }); - const updateTransaction = { - ...pendingStx, - status: 'test', - }; - smartTransactionsController.updateSmartTransaction( - updateTransaction as SmartTransaction, - { - networkClientId: NetworkType.mainnet, + ({ controller }) => { + const updateTransaction = { + ...pendingStx, + status: 'test', + }; + + controller.updateSmartTransaction( + updateTransaction as SmartTransaction, + { + networkClientId: NetworkType.mainnet, + }, + ); + + expect( + controller.state.smartTransactionsState.smartTransactions[ + ChainId.mainnet + ][0].status, + ).toBe('test'); }, ); - - expect( - smartTransactionsController.state.smartTransactionsState - .smartTransactions[ChainId.mainnet][0].status, - ).toBe('test'); }); it('confirms a smart transaction that has status success', async () => { - const { smartTransactionsState } = smartTransactionsController.state; + const { smartTransactionsState } = + getDefaultSmartTransactionsControllerState(); const pendingStx = { ...createStateAfterPending()[0], history: testHistory, }; - - jest - .spyOn(smartTransactionsController, 'getRegularTransactions') - .mockImplementation(() => { - return [createTransactionMeta()]; - }); - smartTransactionsController.update({ - smartTransactionsState: { - ...smartTransactionsState, - smartTransactions: { - [ChainId.mainnet]: [pendingStx] as SmartTransaction[], - }, - }, + const confirmExternalTransactionSpy = jest.fn(); + const getRegularTransactionsSpy = jest.fn().mockImplementation(() => { + return [createTransactionMeta()]; }); - const updateTransaction = { - ...pendingStx, - statusMetadata: { - ...pendingStx.statusMetadata, - minedHash: txHash, + await withController( + { + options: { + state: { + smartTransactionsState: { + ...smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: [pendingStx] as SmartTransaction[], + }, + }, + }, + confirmExternalTransaction: confirmExternalTransactionSpy, + getTransactions: getRegularTransactionsSpy, + }, }, - status: SmartTransactionStatuses.SUCCESS, - }; + async ({ controller }) => { + const updateTransaction = { + ...pendingStx, + statusMetadata: { + ...pendingStx.statusMetadata, + minedHash: txHash, + }, + status: SmartTransactionStatuses.SUCCESS, + }; - smartTransactionsController.updateSmartTransaction( - updateTransaction as SmartTransaction, - { - networkClientId: NetworkType.mainnet, + controller.updateSmartTransaction( + updateTransaction as SmartTransaction, + { + networkClientId: NetworkType.mainnet, + }, + ); + await flushPromises(); + + expect(confirmExternalTransactionSpy).toHaveBeenCalledTimes(1); + expect( + controller.state.smartTransactionsState.smartTransactions[ + ChainId.mainnet + ], + ).toStrictEqual([ + { + ...updateTransaction, + confirmed: true, + }, + ]); }, ); - await flushPromises(); - expect( - smartTransactionsController.confirmExternalTransaction, - ).toHaveBeenCalledTimes(1); - expect( - smartTransactionsController.state.smartTransactionsState - .smartTransactions[ChainId.mainnet], - ).toStrictEqual([ - { - ...updateTransaction, - confirmed: true, - }, - ]); }); it('confirms a smart transaction that was not found in the list of regular transactions', async () => { - const { smartTransactionsState } = smartTransactionsController.state; + const { smartTransactionsState } = + getDefaultSmartTransactionsControllerState(); const pendingStx = { ...createStateAfterPending()[0], history: testHistory, }; - - jest - .spyOn(smartTransactionsController, 'getRegularTransactions') - .mockImplementation(() => { - return []; - }); - smartTransactionsController.update({ - smartTransactionsState: { - ...smartTransactionsState, - smartTransactions: { - [ChainId.mainnet]: [pendingStx] as SmartTransaction[], - }, - }, + const confirmExternalTransactionSpy = jest.fn(); + const getRegularTransactionsSpy = jest.fn().mockImplementation(() => { + return []; }); - const updateTransaction = { - ...pendingStx, - statusMetadata: { - ...pendingStx.statusMetadata, - minedHash: txHash, + await withController( + { + options: { + state: { + smartTransactionsState: { + ...smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: [pendingStx] as SmartTransaction[], + }, + }, + }, + confirmExternalTransaction: confirmExternalTransactionSpy, + getTransactions: getRegularTransactionsSpy, + }, }, - status: SmartTransactionStatuses.SUCCESS, - }; + async ({ controller }) => { + const updateTransaction = { + ...pendingStx, + statusMetadata: { + ...pendingStx.statusMetadata, + minedHash: txHash, + }, + status: SmartTransactionStatuses.SUCCESS, + }; - smartTransactionsController.updateSmartTransaction( - updateTransaction as SmartTransaction, - { - networkClientId: NetworkType.mainnet, + controller.updateSmartTransaction( + updateTransaction as SmartTransaction, + { + networkClientId: NetworkType.mainnet, + }, + ); + await flushPromises(); + + expect(confirmExternalTransactionSpy).toHaveBeenCalledTimes(1); + expect( + controller.state.smartTransactionsState.smartTransactions[ + ChainId.mainnet + ], + ).toStrictEqual([ + { + ...updateTransaction, + confirmed: true, + }, + ]); }, ); - await flushPromises(); - expect( - smartTransactionsController.confirmExternalTransaction, - ).toHaveBeenCalledTimes(1); - expect( - smartTransactionsController.state.smartTransactionsState - .smartTransactions[ChainId.mainnet], - ).toStrictEqual([ - { - ...updateTransaction, - confirmed: true, - }, - ]); }); it('confirms a smart transaction that does not have a minedHash', async () => { - const { smartTransactionsState } = smartTransactionsController.state; + const { smartTransactionsState } = + getDefaultSmartTransactionsControllerState(); const pendingStx = { ...createStateAfterPending()[0], history: testHistory, }; - - jest - .spyOn(smartTransactionsController, 'getRegularTransactions') - .mockImplementation(() => { - return [createTransactionMeta(TransactionStatus.confirmed)]; - }); - smartTransactionsController.update({ - smartTransactionsState: { - ...smartTransactionsState, - smartTransactions: { - [ChainId.mainnet]: [pendingStx] as SmartTransaction[], - }, - }, + const confirmExternalTransactionSpy = jest.fn(); + const getRegularTransactionsSpy = jest.fn().mockImplementation(() => { + return [createTransactionMeta(TransactionStatus.confirmed)]; }); - const updateTransaction = { - ...pendingStx, - statusMetadata: { - ...pendingStx.statusMetadata, - minedHash: '', + await withController( + { + options: { + state: { + smartTransactionsState: { + ...smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: [pendingStx] as SmartTransaction[], + }, + }, + }, + confirmExternalTransaction: confirmExternalTransactionSpy, + getTransactions: getRegularTransactionsSpy, + }, }, - status: SmartTransactionStatuses.SUCCESS, - }; + async ({ controller }) => { + const updateTransaction = { + ...pendingStx, + statusMetadata: { + ...pendingStx.statusMetadata, + minedHash: '', + }, + status: SmartTransactionStatuses.SUCCESS, + }; - smartTransactionsController.updateSmartTransaction( - updateTransaction as SmartTransaction, - { - networkClientId: NetworkType.mainnet, + controller.updateSmartTransaction( + updateTransaction as SmartTransaction, + { + networkClientId: NetworkType.mainnet, + }, + ); + await flushPromises(); + + expect(confirmExternalTransactionSpy).toHaveBeenCalledTimes(1); + expect( + controller.state.smartTransactionsState.smartTransactions[ + ChainId.mainnet + ], + ).toStrictEqual([ + { + ...updateTransaction, + confirmed: true, + }, + ]); }, ); - await flushPromises(); - expect( - smartTransactionsController.confirmExternalTransaction, - ).toHaveBeenCalledTimes(1); - expect( - smartTransactionsController.state.smartTransactionsState - .smartTransactions[ChainId.mainnet], - ).toStrictEqual([ - { - ...updateTransaction, - confirmed: true, - }, - ]); }); it('does not call the "confirmExternalTransaction" fn if a tx is already confirmed', async () => { - const { smartTransactionsState } = smartTransactionsController.state; + const { smartTransactionsState } = + getDefaultSmartTransactionsControllerState(); const pendingStx = { ...createStateAfterPending()[0], history: testHistory, }; - jest - .spyOn(smartTransactionsController, 'getRegularTransactions') - .mockImplementation(() => { - return [createTransactionMeta(TransactionStatus.confirmed)]; - }); - smartTransactionsController.update({ - smartTransactionsState: { - ...smartTransactionsState, - smartTransactions: { - [ChainId.mainnet]: [pendingStx] as SmartTransaction[], - }, - }, + const confirmExternalTransactionSpy = jest.fn(); + const getRegularTransactionsSpy = jest.fn().mockImplementation(() => { + return [createTransactionMeta(TransactionStatus.confirmed)]; }); - const updateTransaction = { - ...pendingStx, - status: SmartTransactionStatuses.SUCCESS, - statusMetadata: { - ...pendingStx.statusMetadata, - minedHash: txHash, + await withController( + { + options: { + state: { + smartTransactionsState: { + ...smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: [pendingStx] as SmartTransaction[], + }, + }, + }, + confirmExternalTransaction: confirmExternalTransactionSpy, + getTransactions: getRegularTransactionsSpy, + }, }, - }; + async ({ controller }) => { + const updateTransaction = { + ...pendingStx, + status: SmartTransactionStatuses.SUCCESS, + statusMetadata: { + ...pendingStx.statusMetadata, + minedHash: txHash, + }, + }; - smartTransactionsController.updateSmartTransaction( - updateTransaction as SmartTransaction, - { - networkClientId: NetworkType.mainnet, + controller.updateSmartTransaction( + updateTransaction as SmartTransaction, + { + networkClientId: NetworkType.mainnet, + }, + ); + await flushPromises(); + + expect(confirmExternalTransactionSpy).not.toHaveBeenCalled(); + expect( + controller.state.smartTransactionsState.smartTransactions[ + ChainId.mainnet + ], + ).toStrictEqual([ + { + ...updateTransaction, + confirmed: true, + }, + ]); }, ); - await flushPromises(); - expect( - smartTransactionsController.confirmExternalTransaction, - ).not.toHaveBeenCalled(); - expect( - smartTransactionsController.state.smartTransactionsState - .smartTransactions[ChainId.mainnet], - ).toStrictEqual([ - { - ...updateTransaction, - confirmed: true, - }, - ]); }); it('does not call the "confirmExternalTransaction" fn if a tx is already submitted', async () => { - const { smartTransactionsState } = smartTransactionsController.state; + const { smartTransactionsState } = + getDefaultSmartTransactionsControllerState(); const pendingStx = { ...createStateAfterPending()[0], history: testHistory, }; - jest - .spyOn(smartTransactionsController, 'getRegularTransactions') - .mockImplementation(() => { - return [createTransactionMeta(TransactionStatus.submitted)]; - }); - smartTransactionsController.update({ - smartTransactionsState: { - ...smartTransactionsState, - smartTransactions: { - [ChainId.mainnet]: [pendingStx] as SmartTransaction[], - }, - }, + const confirmExternalTransactionSpy = jest.fn(); + const getRegularTransactionsSpy = jest.fn().mockImplementation(() => { + return [createTransactionMeta(TransactionStatus.submitted)]; }); - const updateTransaction = { - ...pendingStx, - status: SmartTransactionStatuses.SUCCESS, - statusMetadata: { - ...pendingStx.statusMetadata, - minedHash: txHash, + await withController( + { + options: { + state: { + smartTransactionsState: { + ...smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: [pendingStx] as SmartTransaction[], + }, + }, + }, + confirmExternalTransaction: confirmExternalTransactionSpy, + getTransactions: getRegularTransactionsSpy, + }, }, - }; + async ({ controller }) => { + const updateTransaction = { + ...pendingStx, + status: SmartTransactionStatuses.SUCCESS, + statusMetadata: { + ...pendingStx.statusMetadata, + minedHash: txHash, + }, + }; - smartTransactionsController.updateSmartTransaction( - updateTransaction as SmartTransaction, - { - networkClientId: NetworkType.mainnet, + controller.updateSmartTransaction( + updateTransaction as SmartTransaction, + { + networkClientId: NetworkType.mainnet, + }, + ); + await flushPromises(); + + expect(confirmExternalTransactionSpy).not.toHaveBeenCalled(); + expect( + controller.state.smartTransactionsState.smartTransactions[ + ChainId.mainnet + ], + ).toStrictEqual([ + { + ...updateTransaction, + confirmed: true, + }, + ]); }, ); - await flushPromises(); - expect( - smartTransactionsController.confirmExternalTransaction, - ).not.toHaveBeenCalled(); - expect( - smartTransactionsController.state.smartTransactionsState - .smartTransactions[ChainId.mainnet], - ).toStrictEqual([ - { - ...updateTransaction, - confirmed: true, - }, - ]); }); }); describe('cancelSmartTransaction', () => { it('sends POST call to Transactions API', async () => { - const apiCall = nock(API_BASE_URL) - .post(`/networks/${ethereumChainIdDec}/cancel`) - .reply(200, { message: 'successful' }); - await smartTransactionsController.cancelSmartTransaction('uuid1'); - expect(apiCall.isDone()).toBe(true); - }); - }); + await withController(async ({ controller }) => { + const apiCall = nock(API_BASE_URL) + .post(`/networks/${ethereumChainIdDec}/cancel`) + .reply(200, { message: 'successful' }); - describe('setStatusRefreshInterval', () => { - it('sets refresh interval if different', () => { - smartTransactionsController.setStatusRefreshInterval(100); - expect(smartTransactionsController.config.interval).toBe(100); - }); + await controller.cancelSmartTransaction('uuid1'); - it('does not set refresh interval if they are the same', () => { - const configureSpy = jest.spyOn(smartTransactionsController, 'configure'); - smartTransactionsController.setStatusRefreshInterval(DEFAULT_INTERVAL); - expect(configureSpy).toHaveBeenCalledTimes(0); + expect(apiCall.isDone()).toBe(true); + }); }); }); describe('getTransactions', () => { beforeEach(() => { - // eslint-disable-next-line jest/prefer-spy-on - smartTransactionsController.checkPoll = jest.fn(() => ({})); + jest + .spyOn(SmartTransactionsController.prototype, 'checkPoll') + .mockImplementation(() => ({})); }); - it('retrieves smart transactions by addressFrom and status', () => { - const { smartTransactionsState } = smartTransactionsController.state; + it('retrieves smart transactions by addressFrom and status', async () => { + const { smartTransactionsState } = + getDefaultSmartTransactionsControllerState(); const pendingStx = { ...createStateAfterPending()[0], history: testHistory, @@ -1204,95 +1243,143 @@ describe('SmartTransactionsController', () => { from: addressFrom, }, }; - smartTransactionsController.update({ - smartTransactionsState: { - ...smartTransactionsState, - smartTransactions: { - [ChainId.mainnet]: [pendingStx] as SmartTransaction[], + await withController( + { + options: { + state: { + smartTransactionsState: { + ...smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: [pendingStx] as SmartTransaction[], + }, + }, + }, }, }, - }); - const pendingStxs = smartTransactionsController.getTransactions({ - addressFrom, - status: SmartTransactionStatuses.PENDING, - }); - expect(pendingStxs).toStrictEqual([pendingStx]); + ({ controller }) => { + const pendingStxs = controller.getTransactions({ + addressFrom, + status: SmartTransactionStatuses.PENDING, + }); + + expect(pendingStxs).toStrictEqual([pendingStx]); + }, + ); }); - it('returns empty array if there are no smart transactions', () => { - const transactions = smartTransactionsController.getTransactions({ - addressFrom, - status: SmartTransactionStatuses.PENDING, + it('returns empty array if there are no smart transactions', async () => { + await withController(({ controller }) => { + const transactions = controller.getTransactions({ + addressFrom, + status: SmartTransactionStatuses.PENDING, + }); + + expect(transactions).toStrictEqual([]); }); - expect(transactions).toStrictEqual([]); }); }); describe('getSmartTransactionByMinedTxHash', () => { - it('retrieves a smart transaction by a mined tx hash', () => { - const { smartTransactionsState } = smartTransactionsController.state; - const successfulSmartTransaction = createStateAfterSuccess()[0]; - smartTransactionsController.update({ - smartTransactionsState: { - ...smartTransactionsState, - smartTransactions: { - [ChainId.mainnet]: [ - successfulSmartTransaction, - ] as SmartTransaction[], + it('retrieves a smart transaction by a mined tx hash', async () => { + const { smartTransactionsState } = + getDefaultSmartTransactionsControllerState(); + const [successfulSmartTransaction] = createStateAfterSuccess(); + await withController( + { + options: { + state: { + smartTransactionsState: { + ...smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: [ + successfulSmartTransaction, + ] as SmartTransaction[], + }, + }, + }, }, }, - }); - const smartTransaction = - smartTransactionsController.getSmartTransactionByMinedTxHash( - successfulSmartTransaction.statusMetadata.minedHash, - ); - expect(smartTransaction).toStrictEqual(successfulSmartTransaction); + ({ controller }) => { + const smartTransaction = controller.getSmartTransactionByMinedTxHash( + successfulSmartTransaction.statusMetadata.minedHash, + ); + + expect(smartTransaction).toStrictEqual(successfulSmartTransaction); + }, + ); }); - it('returns undefined if there is no smart transaction found by tx hash', () => { - const { smartTransactionsState } = smartTransactionsController.state; - const successfulSmartTransaction = createStateAfterSuccess()[0]; - smartTransactionsController.update({ - smartTransactionsState: { - ...smartTransactionsState, - smartTransactions: { - [ChainId.mainnet]: [ - successfulSmartTransaction, - ] as SmartTransaction[], + it('returns undefined if there is no smart transaction found by tx hash', async () => { + const { smartTransactionsState } = + getDefaultSmartTransactionsControllerState(); + const [successfulSmartTransaction] = createStateAfterSuccess(); + await withController( + { + options: { + state: { + smartTransactionsState: { + ...smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: [ + successfulSmartTransaction, + ] as SmartTransaction[], + }, + }, + }, }, }, - }); - const smartTransaction = - smartTransactionsController.getSmartTransactionByMinedTxHash( - 'nonStxTxHash', - ); - expect(smartTransaction).toBeUndefined(); + ({ controller }) => { + const smartTransaction = + controller.getSmartTransactionByMinedTxHash('nonStxTxHash'); + + expect(smartTransaction).toBeUndefined(); + }, + ); }); }); describe('isNewSmartTransaction', () => { - it('returns true if it is a new STX', () => { - const actual = - smartTransactionsController.isNewSmartTransaction('newUuid'); - expect(actual).toBe(true); + beforeEach(() => { + jest + .spyOn(SmartTransactionsController.prototype, 'checkPoll') + .mockImplementation(() => ({})); }); - it('returns false if an STX already exist', () => { - smartTransactionsController.update({ - smartTransactionsState: { - ...smartTransactionsController.state.smartTransactionsState, - smartTransactions: { - [ChainId.mainnet]: createStateAfterPending() as SmartTransaction[], + it('returns true if it is a new STX', async () => { + await withController(({ controller }) => { + const actual = controller.isNewSmartTransaction('newUuid'); + + expect(actual).toBe(true); + }); + }); + + it('returns false if an STX already exist', async () => { + await withController( + { + options: { + state: { + smartTransactionsState: { + ...getDefaultSmartTransactionsControllerState() + .smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: + createStateAfterPending() as SmartTransaction[], + }, + }, + }, }, }, - }); - const actual = smartTransactionsController.isNewSmartTransaction('uuid1'); - expect(actual).toBe(false); + ({ controller }) => { + const actual = controller.isNewSmartTransaction('uuid1'); + expect(actual).toBe(false); + }, + ); }); }); describe('startPollingByNetworkClientId', () => { let clock: sinon.SinonFakeTimers; + beforeEach(() => { clock = sinon.useFakeTimers(); }); @@ -1304,231 +1391,472 @@ describe('SmartTransactionsController', () => { it('starts and stops calling smart transactions batch status api endpoint with the correct chainId at the polling interval', async () => { // mock this to a noop because it causes an extra fetch call to the API upon state changes jest - .spyOn(smartTransactionsController, 'checkPoll') + .spyOn(SmartTransactionsController.prototype, 'checkPoll') .mockImplementation(() => undefined); - - // pending transactions in state are required to test polling - smartTransactionsController.update({ - smartTransactionsState: { - ...defaultState.smartTransactionsState, - smartTransactions: { - [ChainId.mainnet]: [ - { - uuid: 'uuid1', - status: 'pending', - cancellable: true, - chainId: ChainId.mainnet, - }, - ], - [ChainId.sepolia]: [ - { - uuid: 'uuid2', - status: 'pending', - cancellable: true, - chainId: ChainId.sepolia, + await withController( + { + options: { + // pending transactions in state are required to test polling + state: { + smartTransactionsState: { + ...getDefaultSmartTransactionsControllerState() + .smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: [ + { + uuid: 'uuid1', + status: 'pending', + cancellable: true, + chainId: ChainId.mainnet, + }, + ], + [ChainId.sepolia]: [ + { + uuid: 'uuid2', + status: 'pending', + cancellable: true, + chainId: ChainId.sepolia, + }, + ], + }, }, - ], + }, }, }, - }); - - const handleFetchSpy = jest.spyOn(utils, 'handleFetch'); + async ({ controller }) => { + const handleFetchSpy = jest.spyOn(utils, 'handleFetch'); + const mainnetPollingToken = controller.startPollingByNetworkClientId( + NetworkType.mainnet, + ); + + await advanceTime({ clock, duration: 0 }); + + const fetchHeaders = { + headers: { + 'Content-Type': 'application/json', + 'X-Client-Id': 'default', + }, + }; + + expect(handleFetchSpy).toHaveBeenNthCalledWith( + 1, + `${API_BASE_URL}/networks/${convertHexToDecimal( + ChainId.mainnet, + )}/batchStatus?uuids=uuid1`, + fetchHeaders, + ); + + await advanceTime({ clock, duration: DEFAULT_INTERVAL }); + + expect(handleFetchSpy).toHaveBeenNthCalledWith( + 2, + `${API_BASE_URL}/networks/${convertHexToDecimal( + ChainId.mainnet, + )}/batchStatus?uuids=uuid1`, + fetchHeaders, + ); + + controller.startPollingByNetworkClientId(NetworkType.sepolia); + await advanceTime({ clock, duration: 0 }); + + expect(handleFetchSpy).toHaveBeenNthCalledWith( + 3, + `${API_BASE_URL}/networks/${convertHexToDecimal( + ChainId.sepolia, + )}/batchStatus?uuids=uuid2`, + fetchHeaders, + ); + + await advanceTime({ clock, duration: DEFAULT_INTERVAL }); + + expect(handleFetchSpy).toHaveBeenNthCalledWith( + 5, + `${API_BASE_URL}/networks/${convertHexToDecimal( + ChainId.sepolia, + )}/batchStatus?uuids=uuid2`, + fetchHeaders, + ); + + // stop the mainnet polling + controller.stopPollingByPollingToken(mainnetPollingToken); + + // cycle two polling intervals + await advanceTime({ clock, duration: DEFAULT_INTERVAL }); + + await advanceTime({ clock, duration: DEFAULT_INTERVAL }); + + // check that the mainnet polling has stopped while the sepolia polling continues + expect(handleFetchSpy).toHaveBeenNthCalledWith( + 6, + `${API_BASE_URL}/networks/${convertHexToDecimal( + ChainId.sepolia, + )}/batchStatus?uuids=uuid2`, + fetchHeaders, + ); + + expect(handleFetchSpy).toHaveBeenNthCalledWith( + 7, + `${API_BASE_URL}/networks/${convertHexToDecimal( + ChainId.sepolia, + )}/batchStatus?uuids=uuid2`, + fetchHeaders, + ); + }, + ); + }); + }); - const mainnetPollingToken = - smartTransactionsController.startPollingByNetworkClientId( - NetworkType.mainnet, - ); + describe('wipeSmartTransactions', () => { + it('does not modify state if no address is provided', async () => { + await withController( + { + options: { + state: { + smartTransactionsState: { + ...getDefaultSmartTransactionsControllerState() + .smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: [ + { uuid: 'some-uuid-1', txParams: { from: '0x123' } }, + { uuid: 'some-uuid-2', txParams: { from: '0x456' } }, + { uuid: 'some-uuid-3', txParams: { from: '0x123' } }, + ], + [ChainId.sepolia]: [ + { uuid: 'some-uuid-4', txParams: { from: '0x123' } }, + { uuid: 'some-uuid-5', txParams: { from: '0x789' } }, + { uuid: 'some-uuid-6', txParams: { from: '0x123' } }, + ], + }, + }, + }, + }, + }, + ({ controller }) => { + const prevState = { + ...controller.state, + }; - await advanceTime({ clock, duration: 0 }); + controller.wipeSmartTransactions({ address: '' }); - const fetchHeaders = { - headers: { - 'Content-Type': 'application/json', - 'X-Client-Id': 'default', + expect(controller.state).toStrictEqual(prevState); }, - }; - - expect(handleFetchSpy).toHaveBeenNthCalledWith( - 1, - `${API_BASE_URL}/networks/${convertHexToDecimal( - ChainId.mainnet, - )}/batchStatus?uuids=uuid1`, - fetchHeaders, ); + }); - await advanceTime({ clock, duration: DEFAULT_INTERVAL }); + it('removes transactions from all chains saved in the smartTransactionsState if ignoreNetwork is true', async () => { + await withController( + { + options: { + state: { + smartTransactionsState: { + ...getDefaultSmartTransactionsControllerState() + .smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: [ + { uuid: 'some-uuid-1', txParams: { from: '0x123' } }, + { uuid: 'some-uuid-2', txParams: { from: '0x456' } }, + { uuid: 'some-uuid-3', txParams: { from: '0x123' } }, + ], + [ChainId.sepolia]: [ + { uuid: 'some-uuid-4', txParams: { from: '0x123' } }, + { uuid: 'some-uuid-5', txParams: { from: '0x789' } }, + { uuid: 'some-uuid-6', txParams: { from: '0x123' } }, + ], + }, + }, + }, + }, + }, + ({ controller }) => { + const address = '0x123'; - expect(handleFetchSpy).toHaveBeenNthCalledWith( - 2, - `${API_BASE_URL}/networks/${convertHexToDecimal( - ChainId.mainnet, - )}/batchStatus?uuids=uuid1`, - fetchHeaders, - ); + controller.wipeSmartTransactions({ + address, + ignoreNetwork: true, + }); - smartTransactionsController.startPollingByNetworkClientId( - NetworkType.sepolia, - ); - await advanceTime({ clock, duration: 0 }); - - expect(handleFetchSpy).toHaveBeenNthCalledWith( - 3, - `${API_BASE_URL}/networks/${convertHexToDecimal( - ChainId.sepolia, - )}/batchStatus?uuids=uuid2`, - fetchHeaders, + const { + smartTransactionsState: { smartTransactions }, + } = controller.state; + Object.keys(smartTransactions).forEach((chainId) => { + const chainIdHex: Hex = chainId as Hex; + expect( + controller.state.smartTransactionsState.smartTransactions[ + chainIdHex + ], + ).not.toContainEqual({ txParams: { from: address } }); + }); + }, ); + }); - await advanceTime({ clock, duration: DEFAULT_INTERVAL }); - - expect(handleFetchSpy).toHaveBeenNthCalledWith( - 5, - `${API_BASE_URL}/networks/${convertHexToDecimal( - ChainId.sepolia, - )}/batchStatus?uuids=uuid2`, - fetchHeaders, - ); + it('removes transactions only from the current chainId if ignoreNetwork is false', async () => { + await withController( + { + options: { + state: { + smartTransactionsState: { + ...getDefaultSmartTransactionsControllerState() + .smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: [ + { uuid: 'some-uuid-1', txParams: { from: '0x123' } }, + { uuid: 'some-uuid-2', txParams: { from: '0x456' } }, + { uuid: 'some-uuid-3', txParams: { from: '0x123' } }, + ], + [ChainId.sepolia]: [ + { uuid: 'some-uuid-4', txParams: { from: '0x123' } }, + { uuid: 'some-uuid-5', txParams: { from: '0x789' } }, + { uuid: 'some-uuid-6', txParams: { from: '0x123' } }, + ], + }, + }, + }, + }, + }, + ({ controller }) => { + const address = '0x123'; + controller.wipeSmartTransactions({ + address, + ignoreNetwork: false, + }); - // stop the mainnet polling - smartTransactionsController.stopPollingByPollingToken( - mainnetPollingToken, + expect( + controller.state.smartTransactionsState.smartTransactions[ + ChainId.mainnet + ], + ).not.toContainEqual({ txParams: { from: address } }); + expect( + controller.state.smartTransactionsState.smartTransactions[ + ChainId.sepolia + ], + ).toContainEqual( + expect.objectContaining({ + txParams: expect.objectContaining({ from: address }), + }), + ); + }, ); + }); - // cycle two polling intervals - await advanceTime({ clock, duration: DEFAULT_INTERVAL }); + it('removes transactions from the current chainId (even if it is not in supportedChainIds) if ignoreNetwork is false', async () => { + await withController( + { + options: { + supportedChainIds: [ChainId.sepolia], + chainId: ChainId.mainnet, + state: { + smartTransactionsState: { + ...getDefaultSmartTransactionsControllerState() + .smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: [ + { uuid: 'some-uuid-1', txParams: { from: '0x123' } }, + { uuid: 'some-uuid-2', txParams: { from: '0x456' } }, + { uuid: 'some-uuid-3', txParams: { from: '0x123' } }, + ], + [ChainId.sepolia]: [ + { uuid: 'some-uuid-4', txParams: { from: '0x123' } }, + { uuid: 'some-uuid-5', txParams: { from: '0x789' } }, + { uuid: 'some-uuid-6', txParams: { from: '0x123' } }, + ], + }, + }, + }, + }, + }, + ({ controller }) => { + const address = '0x123'; - await advanceTime({ clock, duration: DEFAULT_INTERVAL }); + controller.wipeSmartTransactions({ + address, + ignoreNetwork: false, + }); - // check that the mainnet polling has stopped while the sepolia polling continues - expect(handleFetchSpy).toHaveBeenNthCalledWith( - 6, - `${API_BASE_URL}/networks/${convertHexToDecimal( - ChainId.sepolia, - )}/batchStatus?uuids=uuid2`, - fetchHeaders, + expect( + controller.state.smartTransactionsState.smartTransactions[ + ChainId.mainnet + ], + ).not.toContainEqual({ txParams: { from: address } }); + expect( + controller.state.smartTransactionsState.smartTransactions[ + ChainId.sepolia + ], + ).toContainEqual( + expect.objectContaining({ + txParams: expect.objectContaining({ from: address }), + }), + ); + }, ); + }); - expect(handleFetchSpy).toHaveBeenNthCalledWith( - 7, - `${API_BASE_URL}/networks/${convertHexToDecimal( - ChainId.sepolia, - )}/batchStatus?uuids=uuid2`, - fetchHeaders, - ); + it('removes transactions from all chains (even if they are not in supportedChainIds) if ignoreNetwork is true', async () => { + await withController( + { + options: { + supportedChainIds: [], + state: { + smartTransactionsState: { + ...getDefaultSmartTransactionsControllerState() + .smartTransactionsState, + smartTransactions: { + [ChainId.mainnet]: [ + { uuid: 'some-uuid-1', txParams: { from: '0x123' } }, + { uuid: 'some-uuid-2', txParams: { from: '0x456' } }, + { uuid: 'some-uuid-3', txParams: { from: '0x123' } }, + ], + [ChainId.sepolia]: [ + { uuid: 'some-uuid-4', txParams: { from: '0x123' } }, + { uuid: 'some-uuid-5', txParams: { from: '0x789' } }, + { uuid: 'some-uuid-6', txParams: { from: '0x123' } }, + ], + }, + }, + }, + }, + }, + ({ controller }) => { + const address = '0x123'; - // cleanup - smartTransactionsController.update(defaultState); + controller.wipeSmartTransactions({ + address, + ignoreNetwork: true, + }); - smartTransactionsController.stopAllPolling(); + const { + smartTransactionsState: { smartTransactions }, + } = controller.state; + Object.keys(smartTransactions).forEach((chainId) => { + const chainIdHex: Hex = chainId as Hex; + expect( + controller.state.smartTransactionsState.smartTransactions[ + chainIdHex + ], + ).not.toContainEqual({ txParams: { from: address } }); + }); + }, + ); }); }); +}); - describe('wipeSmartTransactions', () => { - beforeEach(() => { - const newSmartTransactions = { - [ChainId.mainnet]: [ - { uuid: 'some-uuid-1', txParams: { from: '0x123' } }, - { uuid: 'some-uuid-2', txParams: { from: '0x456' } }, - { uuid: 'some-uuid-3', txParams: { from: '0x123' } }, - ], - [ChainId.sepolia]: [ - { uuid: 'some-uuid-4', txParams: { from: '0x123' } }, - { uuid: 'some-uuid-5', txParams: { from: '0x789' } }, - { uuid: 'some-uuid-6', txParams: { from: '0x123' } }, - ], - }; - const { smartTransactionsState } = smartTransactionsController.state; - smartTransactionsController.update({ - smartTransactionsState: { - ...smartTransactionsState, - smartTransactions: newSmartTransactions, - }, - }); - }); +type WithControllerCallback = ({ + controller, + triggerNetworStateChange, +}: { + controller: SmartTransactionsController; + triggerNetworStateChange: (state: NetworkState) => void; +}) => Promise | ReturnValue; + +type WithControllerOptions = { + options?: Partial< + ConstructorParameters[0] + >; +}; - it('does not modify state if no address is provided', () => { - const prevState = { - ...smartTransactionsController.state, - }; - smartTransactionsController.wipeSmartTransactions({ address: '' }); - expect(smartTransactionsController.state).toStrictEqual(prevState); - }); +type WithControllerArgs = + | [WithControllerCallback] + | [WithControllerOptions, WithControllerCallback]; + +/** + * Builds a controller based on the given options, and calls the given function + * with that controller. + * + * @param args - Either a function, or an options bag + a function. The options + * bag is equivalent to the controller options; the function will be called + * with the built controller. + * @returns Whatever the callback returns. + */ +async function withController( + ...args: WithControllerArgs +): Promise { + const [{ ...rest }, fn] = args.length === 2 ? args : [{}, args[0]]; + const { options } = rest; + const controllerMessenger = new ControllerMessenger< + SmartTransactionsControllerActions | AllowedActions, + SmartTransactionsControllerEvents | AllowedEvents + >(); + controllerMessenger.registerActionHandler( + 'NetworkController:getNetworkClientById', + jest.fn().mockImplementation((networkClientId) => { + switch (networkClientId) { + case NetworkType.mainnet: + return { + configuration: { + chainId: ChainId.mainnet, + }, + provider: getFakeProvider(), + }; + case NetworkType.sepolia: + return { + configuration: { + chainId: ChainId.sepolia, + }, + provider: getFakeProvider(), + }; + default: + throw new Error('Invalid network client id'); + } + }), + ); - it('removes transactions from all chains saved in the smartTransactionsState if ignoreNetwork is true', () => { - const address = '0x123'; - smartTransactionsController.wipeSmartTransactions({ - address, - ignoreNetwork: true, - }); - const { smartTransactions } = - smartTransactionsController.state.smartTransactionsState; - Object.keys(smartTransactions).forEach((chainId) => { - const chainIdHex: Hex = chainId as Hex; - expect( - smartTransactionsController.state.smartTransactionsState - .smartTransactions[chainIdHex], - ).not.toContainEqual({ txParams: { from: address } }); - }); - }); + const messenger = controllerMessenger.getRestricted({ + name: 'SmartTransactionsController', + allowedActions: ['NetworkController:getNetworkClientById'], + allowedEvents: ['NetworkController:stateChange'], + }); - it('removes transactions only from the current chainId if ignoreNetwork is false', () => { - const address = '0x123'; - smartTransactionsController.wipeSmartTransactions({ - address, - ignoreNetwork: false, + const controller = new SmartTransactionsController({ + messenger, + getNonceLock: jest.fn().mockResolvedValue({ + nextNonce: 'nextNonce', + releaseLock: jest.fn(), + }), + confirmExternalTransaction: jest.fn(), + getTransactions: jest.fn(), + trackMetaMetricsEvent: trackMetaMetricsEventSpy, + getMetaMetricsProps: jest.fn(async () => { + return Promise.resolve({ + accountHardwareType: 'Ledger Hardware', + accountType: 'hardware', + deviceModel: 'ledger', }); - expect( - smartTransactionsController.state.smartTransactionsState - .smartTransactions[smartTransactionsController.config.chainId], - ).not.toContainEqual({ txParams: { from: address } }); - expect( - smartTransactionsController.state.smartTransactionsState - .smartTransactions[ChainId.sepolia], - ).toContainEqual( - expect.objectContaining({ - txParams: expect.objectContaining({ from: address }), - }), - ); - }); + }), + ...options, + }); - it('removes transactions from the current chainId (even if it is not in supportedChainIds) if ignoreNetwork is false', () => { - const address = '0x123'; - smartTransactionsController.config.supportedChainIds = [ChainId.mainnet]; - smartTransactionsController.config.chainId = ChainId.sepolia; - smartTransactionsController.wipeSmartTransactions({ - address, - ignoreNetwork: false, - }); - expect( - smartTransactionsController.state.smartTransactionsState - .smartTransactions[smartTransactionsController.config.chainId], - ).not.toContainEqual({ txParams: { from: address } }); - expect( - smartTransactionsController.state.smartTransactionsState - .smartTransactions[ChainId.mainnet], - ).toContainEqual( - expect.objectContaining({ - txParams: expect.objectContaining({ from: address }), - }), - ); - }); + function triggerNetworStateChange(state: NetworkState) { + controllerMessenger.publish('NetworkController:stateChange', state, []); + } + + triggerNetworStateChange({ + selectedNetworkClientId: NetworkType.mainnet, + networkConfigurations: { + id: { + id: 'id', + rpcUrl: 'string', + chainId: ChainId.mainnet, + ticker: 'string', + }, + }, + networksMetadata: { + id: { + EIPS: { + 1155: true, + }, + status: NetworkStatus.Available, + }, + }, + }); - it('removes transactions from all chains (even if they are not in supportedChainIds) if ignoreNetwork is true', () => { - const address = '0x123'; - smartTransactionsController.config.supportedChainIds = []; - smartTransactionsController.wipeSmartTransactions({ - address, - ignoreNetwork: true, - }); - const { smartTransactions } = - smartTransactionsController.state.smartTransactionsState; - Object.keys(smartTransactions).forEach((chainId) => { - const chainIdHex: Hex = chainId as Hex; - expect( - smartTransactionsController.state.smartTransactionsState - .smartTransactions[chainIdHex], - ).not.toContainEqual({ txParams: { from: address } }); - }); + try { + return await fn({ + controller, + triggerNetworStateChange, }); - }); -}); + } finally { + controller.stop(); + controller.stopAllPolling(); + } +} diff --git a/src/SmartTransactionsController.ts b/src/SmartTransactionsController.ts index 2a2b597a..5226cd3d 100644 --- a/src/SmartTransactionsController.ts +++ b/src/SmartTransactionsController.ts @@ -1,19 +1,29 @@ -// eslint-disable-next-line import/no-nodejs-modules import { hexlify } from '@ethersproject/bytes'; -import type { BaseConfig, BaseState } from '@metamask/base-controller'; -import { query, safelyExecute, ChainId } from '@metamask/controller-utils'; +import type { + ControllerGetStateAction, + ControllerStateChangeEvent, + RestrictedControllerMessenger, +} from '@metamask/base-controller'; +import { + query, + safelyExecute, + ChainId, + isSafeDynamicKey, +} from '@metamask/controller-utils'; import EthQuery from '@metamask/eth-query'; import type { NetworkClientId, - NetworkController, - NetworkState, + NetworkControllerGetNetworkClientByIdAction, + NetworkControllerStateChangeEvent, } from '@metamask/network-controller'; -import { StaticIntervalPollingControllerV1 } from '@metamask/polling-controller'; -import type { TransactionMeta } from '@metamask/transaction-controller'; +import { StaticIntervalPollingController } from '@metamask/polling-controller'; +import type { + TransactionController, + TransactionMeta, + TransactionParams, +} from '@metamask/transaction-controller'; import { TransactionStatus } from '@metamask/transaction-controller'; import { BigNumber } from 'bignumber.js'; -// eslint-disable-next-line import/no-nodejs-modules -import EventEmitter from 'events'; import cloneDeep from 'lodash/cloneDeep'; import { MetaMetricsEventCategory, MetaMetricsEventName } from './constants'; @@ -47,160 +57,242 @@ import { const SECOND = 1000; export const DEFAULT_INTERVAL = SECOND * 5; +const DEFAULT_CLIENT_ID = 'default'; const ETH_QUERY_ERROR_MSG = '`ethQuery` is not defined on SmartTransactionsController'; -export type SmartTransactionsControllerConfig = BaseConfig & { - interval: number; - clientId: string; - chainId: Hex; - supportedChainIds: Hex[]; +/** + * The name of the {@link SmartTransactionsController} + */ +const controllerName = 'SmartTransactionsController'; + +const controllerMetadata = { + smartTransactionsState: { + persist: false, + anonymous: true, + }, }; type FeeEstimates = { - approvalTxFees: IndividualTxFees | undefined; - tradeTxFees: IndividualTxFees | undefined; + approvalTxFees: IndividualTxFees | null; + tradeTxFees: IndividualTxFees | null; }; -export type SmartTransactionsControllerState = BaseState & { +export type SmartTransactionsControllerState = { smartTransactionsState: { smartTransactions: Record; - userOptIn: boolean | undefined; - userOptInV2: boolean | undefined; - liveness: boolean | undefined; + userOptIn: boolean | null; + userOptInV2: boolean | null; + liveness: boolean | null; fees: FeeEstimates; feesByChainId: Record; livenessByChainId: Record; }; }; -export default class SmartTransactionsController extends StaticIntervalPollingControllerV1< - SmartTransactionsControllerConfig, - SmartTransactionsControllerState +/** + * Get the default {@link SmartTransactionsController} state. + * + * @returns The default {@link SmartTransactionsController} state. + */ +export function getDefaultSmartTransactionsControllerState(): SmartTransactionsControllerState { + return { + smartTransactionsState: { + smartTransactions: {}, + userOptIn: null, + userOptInV2: null, + fees: { + approvalTxFees: null, + tradeTxFees: null, + }, + liveness: true, + livenessByChainId: { + [ChainId.mainnet]: true, + [ChainId.sepolia]: true, + }, + feesByChainId: { + [ChainId.mainnet]: { + approvalTxFees: null, + tradeTxFees: null, + }, + [ChainId.sepolia]: { + approvalTxFees: null, + tradeTxFees: null, + }, + }, + }, + }; +} + +export type SmartTransactionsControllerGetStateAction = + ControllerGetStateAction< + typeof controllerName, + SmartTransactionsControllerState + >; + +/** + * The actions that can be performed using the {@link SmartTransactionsController}. + */ +export type SmartTransactionsControllerActions = + SmartTransactionsControllerGetStateAction; + +export type AllowedActions = NetworkControllerGetNetworkClientByIdAction; + +export type SmartTransactionsControllerStateChangeEvent = + ControllerStateChangeEvent< + typeof controllerName, + SmartTransactionsControllerState + >; + +export type SmartTransactionsControllerSmartTransactionEvent = { + type: 'SmartTransactionsController:smartTransaction'; + payload: [SmartTransaction]; +}; + +/** + * The events that {@link SmartTransactionsController} can emit. + */ +export type SmartTransactionsControllerEvents = + | SmartTransactionsControllerStateChangeEvent + | SmartTransactionsControllerSmartTransactionEvent; + +export type AllowedEvents = NetworkControllerStateChangeEvent; + +/** + * The messenger of the {@link SmartTransactionsController}. + */ +export type SmartTransactionsControllerMessenger = + RestrictedControllerMessenger< + typeof controllerName, + SmartTransactionsControllerActions | AllowedActions, + SmartTransactionsControllerEvents | AllowedEvents, + AllowedActions['type'], + AllowedEvents['type'] + >; + +type SmartTransactionsControllerOptions = { + interval?: number; + clientId?: string; + chainId?: Hex; + supportedChainIds?: Hex[]; + getNonceLock: TransactionController['getNonceLock']; + confirmExternalTransaction: TransactionController['confirmExternalTransaction']; + trackMetaMetricsEvent: ( + event: { + event: MetaMetricsEventName; + category: MetaMetricsEventCategory; + properties?: ReturnType; + sensitiveProperties?: ReturnType< + typeof getSmartTransactionMetricsSensitiveProperties + >; + }, + options?: { metaMetricsId?: string } & Record, + ) => void; + state?: Partial; + messenger: SmartTransactionsControllerMessenger; + getTransactions: (options?: GetTransactionsOptions) => TransactionMeta[]; + getMetaMetricsProps: () => Promise; +}; + +export default class SmartTransactionsController extends StaticIntervalPollingController< + typeof controllerName, + SmartTransactionsControllerState, + SmartTransactionsControllerMessenger > { - /** - * Name of this controller used during composition - */ - override name = 'SmartTransactionsController'; + #interval: number; - public timeoutHandle?: NodeJS.Timeout; + #clientId: string; - private readonly getNonceLock: any; + #chainId: Hex; - private ethQuery: EthQuery | undefined; + #supportedChainIds: Hex[]; - public confirmExternalTransaction: any; + timeoutHandle?: NodeJS.Timeout; - public getRegularTransactions: ( - options?: GetTransactionsOptions, - ) => TransactionMeta[]; + readonly #getNonceLock: SmartTransactionsControllerOptions['getNonceLock']; + + #ethQuery: EthQuery | undefined; - private readonly trackMetaMetricsEvent: any; + #confirmExternalTransaction: SmartTransactionsControllerOptions['confirmExternalTransaction']; - public eventEmitter: EventEmitter; + #getRegularTransactions: ( + options?: GetTransactionsOptions, + ) => TransactionMeta[]; - private readonly getNetworkClientById: NetworkController['getNetworkClientById']; + readonly #trackMetaMetricsEvent: SmartTransactionsControllerOptions['trackMetaMetricsEvent']; - private readonly getMetaMetricsProps: () => Promise; + readonly #getMetaMetricsProps: () => Promise; /* istanbul ignore next */ - private async fetch(request: string, options?: RequestInit) { - const { clientId } = this.config; + async #fetch(request: string, options?: RequestInit) { const fetchOptions = { ...options, headers: { 'Content-Type': 'application/json', - ...(clientId && { 'X-Client-Id': clientId }), + ...(this.#clientId && { 'X-Client-Id': this.#clientId }), }, }; return handleFetch(request, fetchOptions); } - constructor( - { - onNetworkStateChange, - getNonceLock, - confirmExternalTransaction, - getTransactions, - trackMetaMetricsEvent, - getNetworkClientById, - getMetaMetricsProps, - }: { - onNetworkStateChange: ( - listener: (networkState: NetworkState) => void, - ) => void; - getNonceLock: any; - confirmExternalTransaction: any; - getTransactions: (options?: GetTransactionsOptions) => TransactionMeta[]; - trackMetaMetricsEvent: any; - getNetworkClientById: NetworkController['getNetworkClientById']; - getMetaMetricsProps: () => Promise; - }, - config?: Partial, - state?: Partial, - ) { - super(config, state); - - this.defaultConfig = { - interval: DEFAULT_INTERVAL, - chainId: ChainId.mainnet, - clientId: 'default', - supportedChainIds: [ChainId.mainnet, ChainId.sepolia], - }; - - this.defaultState = { - smartTransactionsState: { - smartTransactions: {}, - userOptIn: undefined, - userOptInV2: undefined, - fees: { - approvalTxFees: undefined, - tradeTxFees: undefined, - }, - liveness: true, - livenessByChainId: { - [ChainId.mainnet]: true, - [ChainId.sepolia]: true, - }, - feesByChainId: { - [ChainId.mainnet]: { - approvalTxFees: undefined, - tradeTxFees: undefined, - }, - [ChainId.sepolia]: { - approvalTxFees: undefined, - tradeTxFees: undefined, - }, - }, + constructor({ + interval = DEFAULT_INTERVAL, + clientId = DEFAULT_CLIENT_ID, + chainId: InitialChainId = ChainId.mainnet, + supportedChainIds = [ChainId.mainnet, ChainId.sepolia], + getNonceLock, + confirmExternalTransaction, + trackMetaMetricsEvent, + state = {}, + messenger, + getTransactions, + getMetaMetricsProps, + }: SmartTransactionsControllerOptions) { + super({ + name: controllerName, + metadata: controllerMetadata, + messenger, + state: { + ...getDefaultSmartTransactionsControllerState(), + ...state, }, - }; - - this.initialize(); - this.setIntervalLength(this.config.interval); - this.getNonceLock = getNonceLock; - this.ethQuery = undefined; - this.confirmExternalTransaction = confirmExternalTransaction; - this.getRegularTransactions = getTransactions; - this.trackMetaMetricsEvent = trackMetaMetricsEvent; - this.getNetworkClientById = getNetworkClientById; - this.getMetaMetricsProps = getMetaMetricsProps; + }); + this.#interval = interval; + this.#clientId = clientId; + this.#chainId = InitialChainId; + this.#supportedChainIds = supportedChainIds; + this.setIntervalLength(interval); + this.#getNonceLock = getNonceLock; + this.#ethQuery = undefined; + this.#confirmExternalTransaction = confirmExternalTransaction; + this.#getRegularTransactions = getTransactions; + this.#trackMetaMetricsEvent = trackMetaMetricsEvent; + this.#getMetaMetricsProps = getMetaMetricsProps; this.initializeSmartTransactionsForChainId(); - onNetworkStateChange(({ selectedNetworkClientId }) => { - const { - configuration: { chainId }, - provider, - } = this.getNetworkClientById(selectedNetworkClientId); - this.configure({ chainId }); - this.ethQuery = new EthQuery(provider); - this.initializeSmartTransactionsForChainId(); - this.checkPoll(this.state); - }); + this.messagingSystem.subscribe( + 'NetworkController:stateChange', + ({ selectedNetworkClientId }) => { + const { + configuration: { chainId }, + provider, + } = this.messagingSystem.call( + 'NetworkController:getNetworkClientById', + selectedNetworkClientId, + ); + this.#chainId = chainId; + this.#ethQuery = new EthQuery(provider); + this.initializeSmartTransactionsForChainId(); + this.checkPoll(this.state); + }, + ); - this.subscribe((currentState: any) => this.checkPoll(currentState)); - this.eventEmitter = new EventEmitter(); + this.messagingSystem.subscribe( + `${controllerName}:stateChange`, + (currentState) => this.checkPoll(currentState), + ); } async _executePoll(networkClientId: string): Promise { @@ -209,15 +301,16 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo // wondering if we should add some kind of predicate to the polling controller to check whether // we should poll or not const chainId = this.#getChainId({ networkClientId }); - if (!this.config.supportedChainIds.includes(chainId)) { + if (!this.#supportedChainIds.includes(chainId)) { return Promise.resolve(); } return this.updateSmartTransactions({ networkClientId }); } - checkPoll(state: any) { - const { smartTransactions } = state.smartTransactionsState; - const currentSmartTransactions = smartTransactions[this.config.chainId]; + checkPoll({ + smartTransactionsState: { smartTransactions }, + }: SmartTransactionsControllerState) { + const currentSmartTransactions = smartTransactions[this.#chainId]; const pendingTransactions = currentSmartTransactions?.filter( isSmartTransactionPending, ); @@ -229,32 +322,28 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo } initializeSmartTransactionsForChainId() { - if (this.config.supportedChainIds.includes(this.config.chainId)) { - const { smartTransactionsState } = this.state; - this.update({ - smartTransactionsState: { - ...smartTransactionsState, - smartTransactions: { - ...smartTransactionsState.smartTransactions, - [this.config.chainId]: - smartTransactionsState.smartTransactions[this.config.chainId] ?? - [], - }, - }, + if (this.#supportedChainIds.includes(this.#chainId)) { + this.update((state) => { + state.smartTransactionsState.smartTransactions[this.#chainId] = + state.smartTransactionsState.smartTransactions[this.#chainId] ?? []; }); } } async poll(interval?: number): Promise { - const { chainId, supportedChainIds } = this.config; - interval && this.configure({ interval }, false, false); + if (interval) { + this.#interval = interval; + } + this.timeoutHandle && clearInterval(this.timeoutHandle); - if (!supportedChainIds.includes(chainId)) { + + if (!this.#supportedChainIds.includes(this.#chainId)) { return; } + this.timeoutHandle = setInterval(() => { safelyExecute(async () => this.updateSmartTransactions()); - }, this.config.interval); + }, this.#interval); await safelyExecute(async () => this.updateSmartTransactions()); } @@ -263,12 +352,9 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo this.timeoutHandle = undefined; } - setOptInState(state: boolean | undefined): void { - this.update({ - smartTransactionsState: { - ...this.state.smartTransactionsState, - userOptInV2: state, - }, + setOptInState(optInState: boolean | null): void { + this.update((state) => { + state.smartTransactionsState.userOptInV2 = optInState; }); } @@ -286,7 +372,7 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo return; // If status hasn't changed, don't track it again. } - this.trackMetaMetricsEvent({ + this.#trackMetaMetricsEvent({ event: MetaMetricsEventName.StxStatusUpdated, category: MetaMetricsEventCategory.Transactions, properties: getSmartTransactionMetricsProperties(updatedSmartTransaction), @@ -297,10 +383,10 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo } isNewSmartTransaction(smartTransactionUuid: string): boolean { - const { chainId } = this.config; - const { smartTransactionsState } = this.state; - const { smartTransactions } = smartTransactionsState; - const currentSmartTransactions = smartTransactions[chainId]; + const { + smartTransactionsState: { smartTransactions }, + } = this.state; + const currentSmartTransactions = smartTransactions[this.#chainId]; const currentIndex = currentSmartTransactions?.findIndex( (stx) => stx.uuid === smartTransactionUuid, ); @@ -311,14 +397,15 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo smartTransaction: SmartTransaction, { networkClientId }: { networkClientId?: NetworkClientId } = {}, ) { - let { - ethQuery, - config: { chainId }, - } = this; + let ethQuery = this.#ethQuery; + let chainId = this.#chainId; if (networkClientId) { - const networkClient = this.getNetworkClientById(networkClientId); - chainId = networkClient.configuration.chainId; - ethQuery = new EthQuery(networkClient.provider); + const { configuration, provider } = this.messagingSystem.call( + 'NetworkController:getNetworkClientById', + networkClientId, + ); + chainId = configuration.chainId; + ethQuery = new EthQuery(provider); } this.#createOrUpdateSmartTransaction(smartTransaction, { @@ -330,41 +417,41 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo #updateSmartTransaction( smartTransaction: SmartTransaction, { - chainId = this.config.chainId, + chainId = this.#chainId, }: { chainId: Hex; }, ) { - const { smartTransactionsState } = this.state; - const { smartTransactions } = smartTransactionsState; + const { + smartTransactionsState: { smartTransactions }, + } = this.state; const currentSmartTransactions = smartTransactions[chainId] ?? []; const currentIndex = currentSmartTransactions?.findIndex( (stx) => stx.uuid === smartTransaction.uuid, ); + if (currentIndex === -1) { return; // Smart transaction not found, don't update anything. } - this.update({ - smartTransactionsState: { - ...smartTransactionsState, - smartTransactions: { - ...smartTransactionsState.smartTransactions, - [chainId]: smartTransactionsState.smartTransactions[chainId].map( - (existingSmartTransaction, index) => { - return index === currentIndex - ? { ...existingSmartTransaction, ...smartTransaction } - : existingSmartTransaction; - }, - ), - }, - }, + + if (!isSafeDynamicKey(chainId)) { + return; + } + + this.update((state) => { + state.smartTransactionsState.smartTransactions[chainId][currentIndex] = { + ...state.smartTransactionsState.smartTransactions[chainId][ + currentIndex + ], + ...smartTransaction, + }; }); } async #addMetaMetricsPropsToNewSmartTransaction( smartTransaction: SmartTransaction, ) { - const metaMetricsProps = await this.getMetaMetricsProps(); + const metaMetricsProps = await this.#getMetaMetricsProps(); smartTransaction.accountHardwareType = metaMetricsProps?.accountHardwareType; smartTransaction.accountType = metaMetricsProps?.accountType; @@ -374,15 +461,16 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo async #createOrUpdateSmartTransaction( smartTransaction: SmartTransaction, { - chainId = this.config.chainId, - ethQuery = this.ethQuery, + chainId = this.#chainId, + ethQuery = this.#ethQuery, }: { chainId: Hex; ethQuery: EthQuery | undefined; }, ): Promise { - const { smartTransactionsState } = this.state; - const { smartTransactions } = smartTransactionsState; + const { + smartTransactionsState: { smartTransactions }, + } = this.state; const currentSmartTransactions = smartTransactions[chainId] ?? []; const currentIndex = currentSmartTransactions?.findIndex( (stx) => stx.uuid === smartTransaction.uuid, @@ -390,7 +478,7 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo const isNewSmartTransaction = this.isNewSmartTransaction( smartTransaction.uuid, ); - if (this.ethQuery === undefined) { + if (this.#ethQuery === undefined) { throw new Error(ETH_QUERY_ERROR_MSG); } @@ -422,22 +510,18 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo .concat(currentSmartTransactions.slice(cancelledNonceIndex + 1)) .concat(historifiedSmartTransaction) : currentSmartTransactions.concat(historifiedSmartTransaction); - this.update({ - smartTransactionsState: { - ...smartTransactionsState, - smartTransactions: { - ...smartTransactionsState.smartTransactions, - [chainId]: nextSmartTransactions, - }, - }, + + this.update((state) => { + state.smartTransactionsState.smartTransactions[this.#chainId] = + nextSmartTransactions; }); return; } // We have to emit this event here, because then a txHash is returned to the TransactionController once it's available // and the #doesTransactionNeedConfirmation function will work properly, since it will find the txHash in the regular transactions list. - this.eventEmitter.emit( - `${smartTransaction.uuid}:smartTransaction`, + this.messagingSystem.publish( + `SmartTransactionsController:smartTransaction`, smartTransaction, ); @@ -468,7 +552,9 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo }: { networkClientId?: NetworkClientId; } = {}): Promise { - const { smartTransactions } = this.state.smartTransactionsState; + const { + smartTransactionsState: { smartTransactions }, + } = this.state; const chainId = this.#getChainId({ networkClientId }); const smartTransactionsForChainId = smartTransactions[chainId]; @@ -487,7 +573,7 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo if (!txHash) { return true; } - const transactions = this.getRegularTransactions(); + const transactions = this.#getRegularTransactions(); const foundTransaction = transactions?.find((tx) => { return tx.hash?.toLowerCase() === txHash.toLowerCase(); }); @@ -505,8 +591,8 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo async #confirmSmartTransaction( smartTransaction: SmartTransaction, { - chainId = this.config.chainId, - ethQuery = this.ethQuery, + chainId = this.#chainId, + ethQuery = this.#ethQuery, }: { chainId: Hex; ethQuery: EthQuery | undefined; @@ -530,7 +616,7 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo const maxFeePerGas = transaction?.maxFeePerGas; const maxPriorityFeePerGas = transaction?.maxPriorityFeePerGas; if (transactionReceipt?.blockNumber) { - const blockData: { baseFeePerGas?: string } | null = await query( + const blockData: { baseFeePerGas?: Hex } | null = await query( ethQuery, 'getBlockByNumber', [transactionReceipt?.blockNumber, false], @@ -568,13 +654,15 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo : originalTxMeta; if (this.#doesTransactionNeedConfirmation(txHash)) { - this.confirmExternalTransaction( - txMeta, + this.#confirmExternalTransaction( + // TODO: Replace 'as' assertion with correct typing for `txMeta` + txMeta as TransactionMeta, transactionReceipt, - baseFeePerGas, + // TODO: Replace 'as' assertion with correct typing for `baseFeePerGas` + baseFeePerGas as Hex, ); } - this.trackMetaMetricsEvent({ + this.#trackMetaMetricsEvent({ event: MetaMetricsEventName.StxConfirmed, category: MetaMetricsEventCategory.Transactions, properties: getSmartTransactionMetricsProperties(smartTransaction), @@ -589,7 +677,7 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo ); } } catch (error) { - this.trackMetaMetricsEvent({ + this.#trackMetaMetricsEvent({ event: MetaMetricsEventName.StxConfirmationFailed, category: MetaMetricsEventCategory.Transactions, }); @@ -612,7 +700,7 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo chainId, )}?${params.toString()}`; - const data = (await this.fetch(url)) as Record< + const data = (await this.#fetch(url)) as Record< string, SmartTransactionsStatus >; @@ -636,7 +724,7 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo async addNonceToTransaction( transaction: UnsignedTransaction, ): Promise { - const nonceLock = await this.getNonceLock(transaction.from); + const nonceLock = await this.#getNonceLock(transaction.from); const nonce = nonceLock.nextNonce; nonceLock.releaseLock(); return { @@ -647,15 +735,13 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo clearFees(): Fees { const fees = { - approvalTxFees: undefined, - tradeTxFees: undefined, + approvalTxFees: null, + tradeTxFees: null, }; - this.update({ - smartTransactionsState: { - ...this.state.smartTransactionsState, - fees, - }, + this.update((state) => { + state.smartTransactionsState.fees = fees; }); + return fees; } @@ -684,38 +770,36 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo ); } transactions.push(unsignedTradeTransactionWithNonce); - const data = await this.fetch(getAPIRequestURL(APIType.GET_FEES, chainId), { - method: 'POST', - body: JSON.stringify({ - txs: transactions, - }), - }); - let approvalTxFees; - let tradeTxFees; + const data = await this.#fetch( + getAPIRequestURL(APIType.GET_FEES, chainId), + { + method: 'POST', + body: JSON.stringify({ + txs: transactions, + }), + }, + ); + let approvalTxFees: IndividualTxFees | null; + let tradeTxFees: IndividualTxFees | null; if (approvalTx) { approvalTxFees = data?.txs[0]; tradeTxFees = data?.txs[1]; } else { + approvalTxFees = null; tradeTxFees = data?.txs[0]; } - this.update({ - smartTransactionsState: { - ...this.state.smartTransactionsState, - ...(chainId === this.config.chainId && { - fees: { - approvalTxFees, - tradeTxFees, - }, - }), - feesByChainId: { - ...this.state.smartTransactionsState.feesByChainId, - [chainId]: { - approvalTxFees, - tradeTxFees, - }, - }, - }, + this.update((state) => { + if (chainId === this.#chainId) { + state.smartTransactionsState.fees = { + approvalTxFees, + tradeTxFees, + }; + } + state.smartTransactionsState.feesByChainId[chainId] = { + approvalTxFees, + tradeTxFees, + }; }); return { @@ -735,13 +819,13 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo }: { signedTransactions: SignedTransaction[]; signedCanceledTransactions: SignedCanceledTransaction[]; - transactionMeta?: any; - txParams?: any; + transactionMeta?: TransactionMeta; + txParams?: TransactionParams; networkClientId?: NetworkClientId; }) { const chainId = this.#getChainId({ networkClientId }); const ethQuery = this.#getEthQuery({ networkClientId }); - const data = await this.fetch( + const data = await this.#fetch( getAPIRequestURL(APIType.SUBMIT_TRANSACTIONS, chainId), { method: 'POST', @@ -762,18 +846,16 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo console.error('provider error', error); } - const requiresNonce = !txParams.nonce; + const requiresNonce = txParams && !txParams.nonce; let nonce; let nonceLock; let nonceDetails = {}; if (requiresNonce) { - nonceLock = await this.getNonceLock(txParams?.from); + nonceLock = await this.#getNonceLock(txParams.from); nonce = hexlify(nonceLock.nextNonce); nonceDetails = nonceLock.nonceDetails; - if (txParams) { - txParams.nonce ??= nonce; - } + txParams.nonce ??= nonce; } const submitTransactionResponse = { ...data, @@ -792,7 +874,7 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo uuid: submitTransactionResponse.uuid, txHash: submitTransactionResponse.txHash, cancellable: true, - type: transactionMeta?.type || 'swap', + type: transactionMeta?.type ?? 'swap', }, { chainId, ethQuery }, ); @@ -806,9 +888,14 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo #getChainId({ networkClientId, }: { networkClientId?: NetworkClientId } = {}): Hex { - return networkClientId - ? this.getNetworkClientById(networkClientId).configuration.chainId - : this.config.chainId; + if (networkClientId) { + return this.messagingSystem.call( + 'NetworkController:getNetworkClientById', + networkClientId, + ).configuration.chainId; + } + + return this.#chainId; } #getEthQuery({ @@ -817,14 +904,18 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo networkClientId?: NetworkClientId; } = {}): EthQuery { if (networkClientId) { - return new EthQuery(this.getNetworkClientById(networkClientId).provider); + const { provider } = this.messagingSystem.call( + 'NetworkController:getNetworkClientById', + networkClientId, + ); + return new EthQuery(provider); } - if (this.ethQuery === undefined) { + if (this.#ethQuery === undefined) { throw new Error(ETH_QUERY_ERROR_MSG); } - return this.ethQuery; + return this.#ethQuery; } // TODO: This should return if the cancellation was on chain or not (for nonce management) @@ -839,7 +930,7 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo } = {}, ): Promise { const chainId = this.#getChainId({ networkClientId }); - await this.fetch(getAPIRequestURL(APIType.CANCEL, chainId), { + await this.#fetch(getAPIRequestURL(APIType.CANCEL, chainId), { method: 'POST', body: JSON.stringify({ uuid }), }); @@ -853,7 +944,7 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo const chainId = this.#getChainId({ networkClientId }); let liveness = false; try { - const response = await this.fetch( + const response = await this.#fetch( getAPIRequestURL(APIType.LIVENESS, chainId), ); liveness = Boolean(response.lastBlock); @@ -861,30 +952,27 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo console.log('"fetchLiveness" API call failed'); } - this.update({ - smartTransactionsState: { - ...this.state.smartTransactionsState, - ...(chainId === this.config.chainId && { liveness }), - livenessByChainId: { - ...this.state.smartTransactionsState.livenessByChainId, - [chainId]: liveness, - }, - }, + this.update((state) => { + if (chainId === this.#chainId) { + state.smartTransactionsState.liveness = liveness; + } + state.smartTransactionsState.livenessByChainId[chainId] = liveness; }); return liveness; } async setStatusRefreshInterval(interval: number): Promise { - if (interval !== this.config.interval) { - this.configure({ interval }, false, false); + if (interval !== this.#interval) { + this.#interval = interval; } } #getCurrentSmartTransactions(): SmartTransaction[] { - const { smartTransactions } = this.state.smartTransactionsState; - const { chainId } = this.config; - const currentSmartTransactions = smartTransactions?.[chainId]; + const { + smartTransactionsState: { smartTransactions }, + } = this.state; + const currentSmartTransactions = smartTransactions?.[this.#chainId]; if (!currentSmartTransactions || currentSmartTransactions.length === 0) { return []; } @@ -931,17 +1019,18 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo } const addressLowerCase = address.toLowerCase(); if (ignoreNetwork) { - const { smartTransactions } = this.state.smartTransactionsState; - Object.keys(smartTransactions).forEach((chainId) => { - const chainIdHex: Hex = chainId as Hex; + const { + smartTransactionsState: { smartTransactions }, + } = this.state; + (Object.keys(smartTransactions) as Hex[]).forEach((chainId) => { this.#wipeSmartTransactionsPerChainId({ - chainId: chainIdHex, + chainId, addressLowerCase, }); }); } else { this.#wipeSmartTransactionsPerChainId({ - chainId: this.config.chainId, + chainId: this.#chainId, addressLowerCase, }); } @@ -954,7 +1043,9 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo chainId: Hex; addressLowerCase: string; }): void { - const { smartTransactions } = this.state.smartTransactionsState; + const { + smartTransactionsState: { smartTransactions }, + } = this.state; const smartTransactionsForSelectedChain: SmartTransaction[] = smartTransactions?.[chainId]; if ( @@ -968,14 +1059,9 @@ export default class SmartTransactionsController extends StaticIntervalPollingCo (smartTransaction: SmartTransaction) => smartTransaction.txParams?.from !== addressLowerCase, ); - this.update({ - smartTransactionsState: { - ...this.state.smartTransactionsState, - smartTransactions: { - ...smartTransactions, - [chainId]: newSmartTransactionsForSelectedChain, - }, - }, + this.update((state) => { + state.smartTransactionsState.smartTransactions[chainId] = + newSmartTransactionsForSelectedChain; }); } } diff --git a/src/index.test.ts b/src/index.test.ts index 803c94ca..34901851 100644 --- a/src/index.test.ts +++ b/src/index.test.ts @@ -1,16 +1,32 @@ -import DefaultExport from '.'; -import SmartTransactionsController from './SmartTransactionsController'; +import { ControllerMessenger } from '@metamask/base-controller'; + +import DefaultExport, { + type SmartTransactionsControllerActions, + type SmartTransactionsControllerEvents, +} from '.'; +import SmartTransactionsController, { + type AllowedActions, + type AllowedEvents, +} from './SmartTransactionsController'; describe('default export', () => { it('exports SmartTransactionsController', () => { jest.useFakeTimers(); + const controllerMessenger = new ControllerMessenger< + SmartTransactionsControllerActions | AllowedActions, + SmartTransactionsControllerEvents | AllowedEvents + >(); + const messenger = controllerMessenger.getRestricted({ + name: 'SmartTransactionsController', + allowedActions: ['NetworkController:getNetworkClientById'], + allowedEvents: ['NetworkController:stateChange'], + }); const controller = new DefaultExport({ - onNetworkStateChange: jest.fn(), - getNonceLock: null, + messenger, + getNonceLock: jest.fn(), confirmExternalTransaction: jest.fn(), getTransactions: jest.fn(), trackMetaMetricsEvent: jest.fn(), - getNetworkClientById: jest.fn(), getMetaMetricsProps: jest.fn(async () => { return Promise.resolve({}); }), diff --git a/src/index.ts b/src/index.ts index 6bf8815f..f34e9b33 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,3 +1,10 @@ -import SmartTransactionsController from './SmartTransactionsController'; - -export default SmartTransactionsController; +export { default } from './SmartTransactionsController'; +export type { + SmartTransactionsControllerMessenger, + SmartTransactionsControllerState, + SmartTransactionsControllerGetStateAction, + SmartTransactionsControllerActions, + SmartTransactionsControllerStateChangeEvent, + SmartTransactionsControllerSmartTransactionEvent, + SmartTransactionsControllerEvents, +} from './SmartTransactionsController'; diff --git a/src/types.ts b/src/types.ts index 861a581a..d8b774ea 100644 --- a/src/types.ts +++ b/src/types.ts @@ -113,8 +113,8 @@ export type IndividualTxFees = { }; export type Fees = { - approvalTxFees: IndividualTxFees | undefined; - tradeTxFees: IndividualTxFees | undefined; + approvalTxFees: IndividualTxFees | null; + tradeTxFees: IndividualTxFees | null; }; // TODO