Wednesday, January 20, 2021

Split an array into m subarrays to minimize the largest sum

Problem: Given an array nums of non-negative integers and an integer m, split the array into m non-empty contiguous subarrays. Write an algorithm to minimize the largest sum among these m subarrays.

Example (taken from LeetCode):

Input: nums = [7,2,5,10,8], m = 2
Output: 18
Explanation:
There are four ways to split nums into two subarrays.
The best way is to split it into [7,2,5] and [10,8],
where the largest sum among the two subarrays is only 18.
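
To make the "four ways" concrete, here is a small brute-force sketch that tries every cut position for the m = 2 case only. The names TwoWaySplitDemo and BestTwoWaySplit are made up for this illustration and are not part of the solution further down.

```csharp
using System;

public static class TwoWaySplitDemo
{
    // Tries every position where the array can be cut into two non-empty
    // contiguous parts and returns the smallest "largest part sum" found.
    // Illustration only: handles m = 2 and assumes nums has at least 2 elements.
    public static int BestTwoWaySplit(int[] nums)
    {
        int total = 0;
        foreach (int x in nums)
        {
            total += x;
        }

        int best = int.MaxValue;
        int left = 0;
        // A cut after index i gives nums[0..i] and nums[i+1..end].
        for (int i = 0; i < nums.Length - 1; ++i)
        {
            left += nums[i];
            int right = total - left;
            best = Math.Min(best, Math.Max(left, right));
        }
        return best;
    }
}
```

For nums = [7,2,5,10,8] the four cuts give largest sums of 25, 23, 18 and 24, so the method returns 18.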


Approach: We can apply binary search here. But how? Binary search needs a sorted sequence, and at first glance we don't have one. Let's look more closely.

What is the smallest value the largest sum can take? We get it when we split the array into n subarrays of one element each, where n is the number of elements. The largest of those n sums is then simply the maximum element of the array.

And when is the largest sum at its biggest? When we don't split at all and have just one subarray, the original array itself. In that case the largest sum is the sum of all the elements.

So no matter how many valid splits you make on the original array, the answer always lies between MAX(nums) and SUM(nums). This range of candidate values, [MAX(nums) ... SUM(nums)], is our sorted sequence.
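
As a quick check of these two bounds, here is a minimal sketch that computes them. SearchRangeDemo and GetSearchRange are hypothetical names used only for this example.

```csharp
using System.Linq;

public static class SearchRangeDemo
{
    // Returns the smallest and largest value the answer can take:
    // Low  = largest element (reached when every element is its own subarray),
    // High = sum of all elements (reached when the array is not split at all).
    // Assumes nums is non-empty.
    public static (int Low, int High) GetSearchRange(int[] nums)
    {
        return (nums.Max(), nums.Sum());
    }
}
```

For nums = [7,2,5,10,8] this gives the range [10 ... 32], and the answer 18 from the example does fall inside it.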

In a binary search we move low and high depending on a comparison. What is that comparison here?

For a candidate sum 'mid', we count how many subarrays are needed if no subarray's sum may exceed mid. Call this count num_splits. Note that it counts the pieces themselves, not the cuts, so it starts at 1. A larger mid never needs more subarrays, so num_splits can only stay the same or decrease as mid grows, and this is what makes binary search valid. If num_splits > m, then mid is too small to fit the array into m subarrays, so we raise low to mid + 1. Otherwise mid is feasible: if fewer than m subarrays were needed, we can always cut some of them further without increasing the largest sum, so we set high to mid. Here is a snapshot of the algorithm:

  • low = MAX(nums), high = SUM(nums)
  • WHILE low < high
    • mid = (low + high) / 2
    • num_splits = GetNumOfSplits(nums, mid)
    • IF num_splits > m
      • low = mid + 1
    • ELSE
      • high = mid
  • RETURN low

GetNumOfSplits (named Split in the code below) is a simple greedy pass: keep adding elements to the current subarray, and start a new subarray whenever the next element would push its sum above mid.
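
To see how the search narrows down on the example, here is a sketch that runs the same loop and records every probe. The names SearchTraceDemo, CountSubarrays and Trace are made up for this illustration; the greedy count follows the description above.

```csharp
using System;
using System.Collections.Generic;

public static class SearchTraceDemo
{
    // Greedy count: how many contiguous subarrays are needed so that
    // no subarray sum exceeds cap. Assumes cap >= every element.
    public static int CountSubarrays(int[] nums, int cap)
    {
        int count = 1;
        int running = 0;
        foreach (int x in nums)
        {
            if (running + x > cap)
            {
                ++count;       // close the current subarray
                running = x;   // x starts the next one
            }
            else
            {
                running += x;
            }
        }
        return count;
    }

    // Runs the binary search and records each probe as (mid, count).
    public static List<(int Mid, int Count)> Trace(int[] nums, int m)
    {
        int low = 0, high = 0;
        foreach (int x in nums)
        {
            low = Math.Max(low, x);
            high += x;
        }

        var probes = new List<(int Mid, int Count)>();
        while (low < high)
        {
            int mid = low + (high - low) / 2;
            int count = CountSubarrays(nums, mid);
            probes.Add((mid, count));
            if (count > m)
            {
                low = mid + 1;   // mid is too small
            }
            else
            {
                high = mid;      // mid is feasible, try smaller
            }
        }
        return probes;
    }
}
```

For nums = [7,2,5,10,8] and m = 2 the probes are (21, 2), (15, 3), (18, 2) and (17, 3), after which low and high meet at 18.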

Note that we could also solve this with DP, but the time complexity would be O(m * n^2), which is expensive by comparison.
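
For comparison, here is one way the DP alternative could look. This is a sketch of a common formulation, not something taken from the original solution. The class name SplitArrayDp and the dp table layout are assumptions made for this example.

```csharp
using System;

public static class SplitArrayDp
{
    // dp[k, i] = smallest possible largest-subarray-sum when the first i
    // elements are split into exactly k non-empty contiguous subarrays.
    // Assumes 1 <= m <= nums.Length and that all sums fit in an int.
    public static int SplitArray(int[] nums, int m)
    {
        int n = nums.Length;

        // prefix[i] = sum of the first i elements.
        int[] prefix = new int[n + 1];
        for (int i = 0; i < n; ++i)
        {
            prefix[i + 1] = prefix[i] + nums[i];
        }

        int[,] dp = new int[m + 1, n + 1];
        for (int k = 0; k <= m; ++k)
        {
            for (int i = 0; i <= n; ++i)
            {
                dp[k, i] = int.MaxValue;   // marks "not reachable"
            }
        }
        dp[0, 0] = 0;

        for (int k = 1; k <= m; ++k)
        {
            for (int i = k; i <= n; ++i)
            {
                // The last subarray is nums[j..i-1]; the first j elements
                // form the remaining k - 1 subarrays.
                for (int j = k - 1; j < i; ++j)
                {
                    if (dp[k - 1, j] == int.MaxValue)
                    {
                        continue;
                    }
                    int candidate = Math.Max(dp[k - 1, j], prefix[i] - prefix[j]);
                    dp[k, i] = Math.Min(dp[k, i], candidate);
                }
            }
        }
        return dp[m, n];
    }
}
```

The three nested loops are where the O(m * n^2) time comes from, and the table itself takes O(m * n) space.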

Implementation in C#:

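        public int SplitArray(int[] nums, int m)
        {
            // The answer lies between the largest element and the total sum.
            int sum = 0;
            int max = 0;
            for (int i = 0; i < nums.Length; ++i)
            {
                max = Math.Max(max, nums[i]);
                sum += nums[i];
            }

            // No split at all: the only subarray is the whole array.
            if (m == 1)
            {
                return sum;
            }

            // Every element is its own subarray.
            if (m == nums.Length)
            {
                return max;
            }

            int low = max, high = sum;
            while (low < high)
            {
                int mid = low + (high - low) / 2;
                int numOfSplits = this.Split(nums, mid);
                if (numOfSplits > m)
                {
                    // mid is too small: more than m subarrays are needed.
                    low = mid + 1;
                }
                else
                {
                    // mid is feasible: try to find something smaller.
                    high = mid;
                }
            }

            return low;
        }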


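        // Counts the subarrays needed so that no subarray sum exceeds mid.
        private int Split(int[] nums, int mid)
        {
            int sum = 0;
            int numOfSplits = 1;
            for (int i = 0; i < nums.Length; ++i)
            {
                if (sum + nums[i] > mid)
                {
                    // nums[i] does not fit, so it starts a new subarray.
                    sum = nums[i];
                    ++numOfSplits;
                }
                else
                {
                    sum += nums[i];
                }
            }

            return numOfSplits;
        }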


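Complexity: O(n * log(sum - max)), where n is the number of elements, sum is the sum of all the elements, and max is the largest element. The binary search takes O(log(sum - max)) steps and each step makes one O(n) pass over the array. Extra space is O(1).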
