Problem


Key Idea

This is a typical dynamic programming problem.

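Before we get to the main point, suppose the number of transactions is unlimited (\(k = \infty\)). In fact, \(k \ge \frac{n}{2}\) is already large enough, because each transaction needs a buy day and a later sell day, so at most \(\frac{n}{2}\) useful transactions fit into \(n\) days.
In that case we can take the gain whenever the price is higher than the day before, and the sum of those gains is the maximum profit.
See the quickSolve method in the solution.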
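
As a small worked illustration of the unlimited-transaction case, the sketch below sums every day-over-day increase on a sample price array. The class name UnlimitedTransactions, the method name maxProfitUnlimited, and the sample prices are assumptions made for this example; the logic mirrors quickSolve.

```java
// Illustrative sketch; class and method names are hypothetical, not from the solution.
class UnlimitedTransactions {
    // Sums every day-over-day price increase. With no limit on the number
    // of transactions, this sum is the best achievable profit.
    static int maxProfitUnlimited(int[] prices) {
        int profit = 0;
        for (int i = 1; i < prices.length; ++i) {
            if (prices[i] > prices[i - 1]) profit += prices[i] - prices[i - 1];
        }
        return profit;
    }

    public static void main(String[] args) {
        int[] prices = {3, 2, 6, 5, 0, 3};
        // Gains taken: 2 -> 6 (+4) and 0 -> 3 (+3).
        System.out.println(maxProfitUnlimited(prices)); // prints 7
    }
}
```

Consecutive rising days collapse into a single buy-and-sell, which is why the number of transactions actually needed never exceeds half the number of days.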

Now let’s get into our main discussion.

We define a recursive helper function findMaxProfit, which finds the maximum profit using at most k transactions in the interval [0, endIdx], inclusive.

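private int findMaxProfit(final int[] prices, int endIdx, int k) {
    // No days left or no transactions left: nothing to gain.
    if (endIdx == -1 || k == 0) return 0;

    int maxProfit= 0;
    for (int i=0; i < endIdx; ++i) {
        if (prices[i] < prices[endIdx]) {
            // Buy on day i, sell on day endIdx, and use k-1 transactions before day i.
            int spotProfit= prices[endIdx] - prices[i];
            maxProfit= Math.max(maxProfit, findMaxProfit(prices, i-1, k-1) + spotProfit);
        } else {
            // Selling on day endIdx cannot beat day i: fall back to the interval [0, i].
            maxProfit= Math.max(maxProfit, findMaxProfit(prices, i, k));
        }
    }
    return maxProfit;
}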
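
Read as a recurrence, the helper above can be written as follows. This is only a restatement of the code, under the assumed notation that \(p_i\) stands for prices[i], \(f(e, k)\) stands for findMaxProfit(prices, e, k), and a maximum over an empty set counts as 0.

```latex
\[
f(e, k) =
\begin{cases}
0, & e = -1 \text{ or } k = 0, \\[4pt]
\max\Bigl(0,\;
  \max\limits_{\substack{0 \le i < e \\ p_i < p_e}}
    \bigl[f(i-1,\,k-1) + p_e - p_i\bigr],\;
  \max\limits_{\substack{0 \le i < e \\ p_i \ge p_e}}
    f(i,\,k)
\Bigr), & \text{otherwise.}
\end{cases}
\]
```

The first inner maximum covers selling on day \(e\) after buying on some cheaper day \(i\); the second covers not selling on day \(e\) at all.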


But that’s not good enough.
The same (endIdx, k) pair is evaluated over and over, which is known as overlapping subproblems in DP.

Below is the memoized version of findMaxProfit. It stores each result in cache[endIdx][k], where -1 marks an entry that has not been computed yet.

private int findMaxProfit(final int[] prices, int endIdx, int k) {
    if (endIdx == -1 || k == 0) return 0;

    if (cache[endIdx][k] != -1) return cache[endIdx][k];
    cache[endIdx][k]= 0;
    for (int i=0; i < endIdx; ++i) {
        if (prices[i] < prices[endIdx]) {
            int spotProfit= prices[endIdx] - prices[i];
            cache[endIdx][k]= Math.max(cache[endIdx][k], findMaxProfit(prices, i-1, k-1) + spotProfit);
        } else {
            cache[endIdx][k]= Math.max(cache[endIdx][k], findMaxProfit(prices, i, k));
        }
    }
    return cache[endIdx][k];
}

We can now simply return findMaxProfit(prices, n-1, k).

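  • Time: \(O(n^2{\cdot}k)\), since there are \(O(n{\cdot}k)\) cached states and each one scans up to \(n\) earlier days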
  • Space: \(O(n{\cdot}k)\)
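
To get a rough feel for what memoization saves, the sketch below runs the plain recursion and a cached variant side by side and counts how many calls each makes. The class CallCounter, its counters, and the sample prices are assumptions made for this illustration; the recursion itself follows findMaxProfit.

```java
import java.util.Arrays;

// Illustrative sketch; CallCounter and its members are hypothetical names.
class CallCounter {
    static long naiveCalls = 0;
    static long memoCalls = 0;
    static int[][] cache;

    // Plain recursion: the same (endIdx, k) pair may be recomputed many times.
    static int naive(int[] prices, int endIdx, int k) {
        ++naiveCalls;
        if (endIdx == -1 || k == 0) return 0;
        int best = 0;
        for (int i = 0; i < endIdx; ++i) {
            if (prices[i] < prices[endIdx]) {
                best = Math.max(best, naive(prices, i - 1, k - 1) + prices[endIdx] - prices[i]);
            } else {
                best = Math.max(best, naive(prices, i, k));
            }
        }
        return best;
    }

    // Same recursion, but each (endIdx, k) pair is computed only once.
    static int memo(int[] prices, int endIdx, int k) {
        ++memoCalls;
        if (endIdx == -1 || k == 0) return 0;
        if (cache[endIdx][k] != -1) return cache[endIdx][k];
        int best = 0;
        for (int i = 0; i < endIdx; ++i) {
            if (prices[i] < prices[endIdx]) {
                best = Math.max(best, memo(prices, i - 1, k - 1) + prices[endIdx] - prices[i]);
            } else {
                best = Math.max(best, memo(prices, i, k));
            }
        }
        cache[endIdx][k] = best;
        return best;
    }

    // Allocates a fresh cache filled with -1 and runs the memoized recursion.
    static int solveMemo(int[] prices, int k) {
        cache = new int[prices.length][k + 1];
        for (int[] row : cache) Arrays.fill(row, -1);
        return memo(prices, prices.length - 1, k);
    }

    public static void main(String[] args) {
        int[] prices = {3, 2, 6, 5, 0, 3, 1, 4, 2, 7, 5, 8};
        int k = 3;
        System.out.println("naive result: " + naive(prices, prices.length - 1, k));
        System.out.println("memo result:  " + solveMemo(prices, k));
        System.out.println("naive calls:  " + naiveCalls);
        System.out.println("memo calls:   " + memoCalls);
    }
}
```

Both variants return the same profit. The call counts differ because the cached variant stops descending as soon as it meets a pair it has already solved.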


Implementation

/**
 * author: jooncco
 * written: 2022. 9. 11. Tue. 02:04:14 [UTC+9]
 **/

import java.util.Arrays;

class Solution {
    private int n;
    private int[][] cache;

    public int maxProfit(int k, int[] prices) {
        n= prices.length;
        if (k >= n/2) return quickSolve(prices, k);

        cache= new int[n+1][k+1];
        for (int[] row : cache) Arrays.fill(row, -1);
        return findMaxProfit(prices, n-1, k);
    }

    private int findMaxProfit(final int[] prices, int endIdx, int k) {
        if (endIdx == -1 || k == 0) return 0;

        if (cache[endIdx][k] != -1) return cache[endIdx][k];
        cache[endIdx][k]= 0;
        for (int i=0; i < endIdx; ++i) {
            if (prices[i] < prices[endIdx]) {
                int spotProfit= prices[endIdx] - prices[i];
                cache[endIdx][k]= Math.max(cache[endIdx][k], findMaxProfit(prices, i-1, k-1) + spotProfit);
            } else {
                cache[endIdx][k]= Math.max(cache[endIdx][k], findMaxProfit(prices, i, k));
            }
        }
        return cache[endIdx][k];
    }

    private int quickSolve(final int[] prices, int k) {
        int maxProfit= 0;
        for (int i=1; i < n; ++i) {
            if (prices[i] > prices[i-1]) maxProfit += prices[i] - prices[i-1];
        }
        return maxProfit;
    }
}
