最大权匹配问题,KM模板

chayi / 2023-08-03 / 原文

class KM
{
public:
    // MAXN 最大点数       oo 无穷大
    static const int MAXN = 405, oo = 1000101010;
    int nl, nr, m; // 左边的点数,右边的点数,边数
    int result[MAXN]; // 左边点最大权匹配的匹配
    long long ans;
    KM(int nl, int nr, int m) : nl(nl), nr(nr), m(m) {init(); }
    void init()
    {
        if(nr < nl)
            nr = nl; // necessary for algorithm correctness
        for(int i = 1; i <= nl; i++)
            memset(e[i], 0, (nr + 1) * sizeof(int));
        memset(matchl, 0, (nl + 1) * sizeof(int));
        memset(matchr, 0, (nr + 1) * sizeof(int));
        memset(visitl, 0, (nl + 1) * sizeof(int));
        memset(visitr, 0, (nr + 1) * sizeof(int));
        memset(wl, 0, (nl + 1) * sizeof(int));
        memset(wr, 0, (nr + 1) * sizeof(int));
        vn = 0;
    }
    void insert(int u, int v, int w)
    { // 加边
        e[u][v] = w;
        if(w > wl[u])
            wl[u] = w;
    }
    pair<long long, int *> solve()
    { // ans 最大权匹配的权 result 最大权匹配的匹配
        for(int i = 1; i <= nl; i++)
            bfs(i);
        ans = 0;
        for(int i = 1; i <= nl; i++)
            ans += e[i][matchl[i]];
        for(int i = 1; i <= nl; i++)
            result[i] = e[i][matchl[i]] ? matchl[i] : 0;
        return make_pair(ans, result);
    }
private:
    void found(int x)
    {
        while(x)
        {
            int tmp = matchl[prev[x]];
            matchr[x] = prev[x];
            matchl[prev[x]] = x;
            x = tmp;
        }
    }
    void bfs(int st)
    {
        memset(slack, 63, (nr + 1) * sizeof(int));
        while(!q.empty())
            q.pop();
        vn++;
        visitl[st] = vn;
        q.push(st);
        while(1)
        {
            while(!q.empty())
            {
                int now = q.front(), tmp;
                q.pop();
                for(int i = 1; i <= nr; i++)
                    if(visitr[i] != vn && (tmp = wl[now] + wr[i] - e[now][i]) <= slack[i])
                    {
                        prev[i] = now;
                        if(!tmp)
                        {
                            visitr[i] = vn;
                            if(!matchr[i])
                                return found(i);
                            else
                                visitl[matchr[i]] = vn, q.push(matchr[i]);
                        }
                        else
                            slack[i] = tmp;
                    }
            }
            int d = oo;
            for(int i = 1; i <= nr; i++)
                if(visitr[i] != vn)
                    d = min(d, slack[i]);
            if(d == oo)
                return;
            for(int i = 1; i <= nl; i++)
                if(visitl[i] == vn)
                    wl[i] -= d;
            for(int i = 1; i <= nr; i++)
                if(visitr[i] == vn)
                    wr[i] += d;
                else
                    slack[i] -= d;
            for(int i = 1; i <= nr; i++)
                if(visitr[i] != vn && !slack[i])
                {
                    visitr[i] = vn;
                    if(!matchr[i])
                        return found(i);
                    else
                        visitl[matchr[i]] = vn, q.push(matchr[i]);
                }
        }
    }
    queue<int> q;
    int wl[MAXN], wr[MAXN], slack[MAXN], prev[MAXN], matchl[MAXN], matchr[MAXN], visitl[MAXN], visitr[MAXN], vn;
    int e[MAXN][MAXN];
};

 

class KM
{
public:
    // MAXN 最大点数       oo 无穷大
    static const int MAXN = 405, oo = 1000101010;
    int nl, nr, m; // 左边的点数,右边的点数,边数
    int result[MAXN]; // 左边点最大权匹配的匹配
    long long ans;
    KM(int nl, int nr, int m) : nl(nl), nr(nr), m(m) {init(); }
    void init()
    {
        if(nr < nl)
            nr = nl; // necessary for algorithm correctness
        for(int i = 1; i <= nl; i++)
            memset(e[i], 0, (nr + 1) * sizeof(int));
        memset(matchl, 0, (nl + 1) * sizeof(int));
        memset(matchr, 0, (nr + 1) * sizeof(int));
        memset(visitl, 0, (nl + 1) * sizeof(int));
        memset(visitr, 0, (nr + 1) * sizeof(int));
        memset(wl, 0, (nl + 1) * sizeof(int));
        memset(wr, 0, (nr + 1) * sizeof(int));
        vn = 0;
    }
    void insert(int u, int v, int w)
    { // 加边
        e[u][v] = w;
        if(w > wl[u])
            wl[u] = w;
    }
    pair<long long, int *> solve()
    { // ans 最大权匹配的权 result 最大权匹配的匹配
        for(int i = 1; i <= nl; i++)
            bfs(i);
        ans = 0;
        for(int i = 1; i <= nl; i++)
            ans += e[i][matchl[i]];
        for(int i = 1; i <= nl; i++)
            result[i] = e[i][matchl[i]] ? matchl[i] : 0;
        return make_pair(ans, result);
    }
private:
    void found(int x)
    {
        while(x)
        {
            int tmp = matchl[prev[x]];
            matchr[x] = prev[x];
            matchl[prev[x]] = x;
            x = tmp;
        }
    }
    void bfs(int st)
    {
        memset(slack, 63, (nr + 1) * sizeof(int));
        while(!q.empty())
            q.pop();
        vn++;
        visitl[st] = vn;
        q.push(st);
        while(1)
        {
            while(!q.empty())
            {
                int now = q.front(), tmp;
                q.pop();
                for(int i = 1; i <= nr; i++)
                    if(visitr[i] != vn && (tmp = wl[now] + wr[i] - e[now][i]) <= slack[i])
                    {
                        prev[i] = now;
                        if(!tmp)
                        {
                            visitr[i] = vn;
                            if(!matchr[i])
                                return found(i);
                            else
                                visitl[matchr[i]] = vn, q.push(matchr[i]);
                        }
                        else
                            slack[i] = tmp;
                    }
            }
            int d = oo;
            for(int i = 1; i <= nr; i++)
                if(visitr[i] != vn)
                    d = min(d, slack[i]);
            if(d == oo)
                return;
            for(int i = 1; i <= nl; i++)
                if(visitl[i] == vn)
                    wl[i] -= d;
            for(int i = 1; i <= nr; i++)
                if(visitr[i] == vn)
                    wr[i] += d;
                else
                    slack[i] -= d;
            for(int i = 1; i <= nr; i++)
                if(visitr[i] != vn && !slack[i])
                {
                    visitr[i] = vn;
                    if(!matchr[i])
                        return found(i);
                    else
                        visitl[matchr[i]] = vn, q.push(matchr[i]);
                }
        }
    }
    queue<int> q;
    int wl[MAXN], wr[MAXN], slack[MAXN], prev[MAXN], matchl[MAXN], matchr[MAXN], visitl[MAXN], visitr[MAXN], vn;
    int e[MAXN][MAXN];
};