树状数组求逆序对

2019年10月3日


逆序对问题的基础算法

在一个数组中,如果一对数的前后位置与大小顺序相反,即前面的数大于后面的数,那么它们就称为一个逆序对。例如[4,3,5,9,2,9]中,(4,3), (4,2), (3,2),(5,2),(9,2)是逆序对,所以这个数组中有5个逆序对。在(a,b)中,我们称a为逆序对的前件,b为逆序对的后件。

我们可以很容易想到一个复杂度为[latex]O(n^2)[/latex]的算法,就是用两个循环,外层循环从索引1开始,内层循环统计从数组开头到当前索引的逆序对数量。

下面我们尝试对问题变形,看看能否导出复杂度更低的算法。

我们把数组arr=[4, 3, 5, 9, 2, 9]视作一个(无序的)集合。原数组里有1个数字4,我们就设cArr[4]=1;原数组里有2个数字9,我们就设cArr[9]=2;以此类推。

int[] cArr=[0,0,1,1,1,1,0,0,0,2]
     index  0 1 2 3 4 5 6 7 8 9

现在我们一步一步来构建cArr,并同时求逆序对数目。

第1步,初始化int[] cArr=[0,0,0,0,0,0,0,0,0,0]。

  • 0
  • 0
  • 0
  • 0
  • 1
  • 0
  • 0
  • 0
  • 0
  • 0
  • 0
  • 0
  • 0
  • 1
  • 1
  • 0
  • 0
  • 0
  • 0
  • 0
  • 0
  • 0
  • 0
  • 1
  • 1
  • 1
  • 0
  • 0
  • 0
  • 0
  • 0
  • 0
  • 0
  • 1
  • 1
  • 1
  • 0
  • 0
  • 0
  • 1
  • 0
  • 0
  • 1
  • 1
  • 1
  • 1
  • 0
  • 0
  • 0
  • 1
  • 0
  • 0
  • 1
  • 1
  • 1
  • 1
  • 0
  • 0
  • 0
  • 2
  • 于是,[4,3,5,9,2,9]的逆序对数量为0+1+0+0+0+4+0=5。

    这个算法的时间复杂度要多少?

    我们需要循环n次,每个循环里,我们要做两件事:1、设置值,即cArr[index]+=1,2、求和,即cArr.GetSum(startIndex,length)。如果cArr是普通的数组,设置值的复杂度为o(1),求和的复杂度是o(n)。于是算法整体的复杂度为[latex]O(n(1+n))=O(n^2)[/latex]。

    使用树状数组

    现在终于可以介绍树状数组(Binary Indexed Tree)。树状数组这个中文名称比英文Binary Indexed Tree好很多,因为树状数组用起来跟普通数组差不多,而对tree进行索引访问很奇怪,如tree[2]=10。树状数组有两个独特性质:

    1. 树状数组设置值的复杂度为O(log n)。
    2. 对树状数组任意区域求和的复杂度为O(log n)。

    基于以上特性,我一般在代码中按其功能写成QuickSumArray,而不是Binary Indexed Tree或中文树状数组。

    上述算法的C#代码如下。

    public int FindReversePairs(int[] nums)
    {
    	QuickSumArray arr = new QuickSumArray(length: nums.Max() + 1);
    
    	int count = 0;
    	foreach (var n in nums)
    	{
    		arr[n] += 1; //arr.Update(n,1)速度更快,但复杂度相同。
    		count += arr.GetSum(n + 1, arr.Length - n - 1);
    	}
    
    	return count;
    }

    读者想想这个算法有什么问题或者什么限制?

    1、这个算法的空间复杂度跟nums里的最大值相关。2、nums里不能有负数。

    扩大应用范围

    下面我们讲改进算法,让它能应用于负数和大数,该方法叫做离散化。离散化有三步:排序、去重、索引。

    考虑数组[98998988,-32434234,433234556,-32434234,8384733]。

    先把它从小到大排序并去重:[-324342340, 83847331, 989989882, 433245563],用索引代表一个数。

    于是排序后的数组被离散化为[0, 0, 1, 2, 3],原数组就变成[2,0,3,0,1]。

    一个不良的实现是,没有去重,对于重复的数,取其第一个重复数的索引。这么做的后果是,如果有100个相同的数,离散化的数组可能是[0, 100, ..],数值没有非常紧密。

    离散化有一个性质,设原数组为N,对N的每个元素乘以常数c,得到新数组N’,N和N’的离散化结果相同。

    public void Discretize(int[] nums)
    {
    	SortedList<int, int> list = new SortedList<int, int>();
    
    	foreach (var t in nums)
    		list[t] = 0;
    
    	list[list.Keys[0]] = 0;
    	for (int i = 1; i < list.Keys.Count; i++)
    	{
    		if (list.Keys[i - 1] != list.Keys[i])
    			list[list.Keys[i]] = i;
    	}
    
    	for (int i = 0; i < nums.Length; i++)
    		nums[i] = list[nums[i]];
    }

    FindReversePairs()里面只要先调用Discretize()即可。

    最后优化

    将树状数组下标转置,GetSum(int length)这个重载对数组从索引0开始求和。原来的代码使用GetSum(int startIndex, int length)会进行两次求和,然后相减。

    public int FindReversePairs(int[] nums)
    {
    	Discretize(nums);
    	QuickSumArray arr = new QuickSumArray(length: nums.Max() + 1);
    
    	int count = 0;
    	foreach (var n in nums)
    	{
    		count += arr.GetSum(length: arr.Length - 1 - n);
    		arr.Update(arr.Length - 1 - n, 1);
    	}
    
    	return count;
    }

    LeetCode 493. (Important) Reverse Pairs

    原题链接:https://leetcode.com/problems/reverse-pairs/

    Given an array nums, we call (i, j) an important reverse pair if i < j and nums[i] > 2*nums[j].

    You need to return the number of important reverse pairs in the given array.

    Example1:

    Input: [1,3,2,3,1]
    Output: 2
    Example2:

    Input: [2,4,3,5,1]
    Output: 3
    Note:
    The length of the given array will not exceed 50,000.
    All the numbers in the input array are in the range of 32-bit integer.

    public int ReversePairs(int[] nums)
    {
        int[] sorted = new int[nums.Length];
        nums.CopyTo(sorted, 0);
        Array.Sort(sorted);
    
        QuickSumArray arr = new QuickSumArray(nums.Length);
    
        int count = 0;
        foreach (var n in nums)
        {
            //查找2n在nums里的位置,然后对应到离散化后nums
            long newValue = 2L * n + 1;
    
            int label2 = BinarySearchFirstIndex(sorted, newValue);
            count += arr.GetSum(arr.Length - label2);
    
            int label = BinarySearchFirstIndex(sorted, n);
            arr.Update(arr.Length - 1 - label, 1); //arr[n]+=1
        }
    
        return count;
    }
    
    /// <summary>
    /// 给定一个递增数组,返回第一个大于等于val的元素的索引。
    /// 如果所有元素都小于val,则返回arr.Length。
    /// </summary>
    static int BinarySearchFirstIndex(int[] arr, long val)
    {
        Debug.Assert(arr[0] <= arr[arr.Length - 1], "数组必须是递增的。");
    
        int i;
        if (val > Int32.MaxValue)
        {
            if (arr[arr.Length - 1] == Int32.MaxValue)
                i = arr.Length - 1;
            else
                i = Array.BinarySearch(arr, Int32.MaxValue);
        }
        else if (val < int.MinValue)
        {
            return 0;
        }
        else
            i = Array.BinarySearch(arr, (int)val);
    
        if (i < 0)
            return ~i;
    
        while (i >= 0 && arr[i] == val)
        {
            i--;
        }
    
        return i + 1;
    }

    参考资料