DID : PS 1문제 (2023/04/09)

DID : PS 1문제 (2023/04/09)

와 시험 전날. 오늘은 어려운 문제를 풀고 싶은 기분이어서 수쿼 하나 잡고 열심히 조졌다. 30분 정도 코딩하고 넣으니까 TLE떠서 결국 풀이를 보고 말았다. 참고한 블로그는 여기이다.

BOJ 13546 수열과 쿼리 4
개어렵다. 처음 생각했던 풀이는 쿼리 구간 내 sub배열의 값에 따른 index를 정렬된 형태로 저장하는 deque를 관리하고, 각 k(배열 값)에 대해 Max(j-i)를 가지고 Maximum 세그먼트 트리를 관리한다. 이거를 Mo's 알고리즘으로 조지면
(1) 구간 확장/축소 시에 O(logK)
(2) 총 구간 확장/축소 횟수 : O((N+Q)sqrtN)
즉, O((N+Q) sqrtN logK)이 걸린다.
그렇게 짠 소스는 다음과 같다.

#include<bits/stdc++.h>
using namespace std;
typedef vector<pair<int, pair<int,int> > > vipii;
typedef pair<int, pair<int,int> > pipii;
typedef pair<int, int > pii;

deque<int> idx[(int)1e5+7];
int sqrtn,k, arr[(int)1e5+7];
int ans[(int)1e5+7], ansn[(int)1e5+7];
vipii q;
bool cmp(pipii a, pipii b)
{
    if(a.second.second/sqrtn == b.second.second/sqrtn)
        return a.second.first < b.second.first;
    return a.second.second/sqrtn < b.second.second/sqrtn;
}

struct Node
{
    Node *l, *r;
    int n,v; //n : 어떤 수 , v : max(j-i) 값
    Node(){l = r = NULL; n = v = 0;}

} root;
void init(Node *node, int s,int e)
{
    if(s == e)
    {
        node->v = 0;
        node->n = s;
        return;
    }
    int mid = (s+e)/2;
    node->l = new Node(); init(node->l,s,mid);
    node->r = new Node(); init(node->r,mid+1, e);
    if(node->l->v < node->r->v)
        node->v = node->r->v, node->n = node->r->n;
    else node->v = node->l->v, node->n = node->l->n;
}
void update(Node* node, int s,int e, int x)
{
    if(s == e)
    {
        if(idx[x].empty()) node->v = 0;
        else node->v = idx[x].back() - idx[x].front();
        return;
    }
    int mid = (s+e)/2;
    if(x <= mid) update(node->l, s,mid,x);
    else update(node->r, mid+1, e,x);


    if(node->l->v < node->r->v)
        node->v = node->r->v, node->n = node->r->n;
    else node->v = node->l->v, node->n = node->l->n;
}
void transition(pii s, pii e)
{
    //printf("--%d %d->%d %d--\n", s.first, s.second, e.first, e.second);
    while(e.first < s.first)
    {
        s.first--;
        idx[arr[s.first]].push_front(s.first);
        update(&root, 1, k, arr[s.first]);
    }
    while(s.second < e.second)
    {
        s.second++;
        idx[arr[s.second]].push_back(s.second);
        update(&root, 1, k, arr[s.second]);
    }
    while(e.first > s.first)
    {
        idx[arr[s.first]].pop_front();
        update(&root, 1, k, arr[s.first]);
        s.first++;
    }
    while(s.second > e.second)
    {
        idx[arr[s.second]].pop_back();
        update(&root, 1, k, arr[s.second]);
        s.second--;
    }
}

int main()
{
    int n; scanf("%d %d",&n,&k);
    sqrtn = (int) sqrt(n);
    for(int i = 1; i <= n; i++) scanf("%d", arr+i);
    int Q; scanf("%d", &Q);
    for(int i = 1; i <= Q; i++)
    {
        int a,b; scanf("%d %d", &a,&b);
        q.push_back({i,{a,b}});
    }
    sort(q.begin(), q.end(), cmp);

    pii bef = {1,1};
    idx[arr[1]].push_front(1);
    init(&root, 1, k);

    for(auto i : q)
    {
        transition(bef, i.second);
        ans[i.first] = root.v;
        ansn[i.first] = root.n;
        bef = i.second;
    }
    for(int i = 1; i <= Q; i++) printf("%d\n", ans[i]);
}

이거 TLE 받고 어떻게 풀지 방향을 잃어서 막 검색하다 위에서 소개한 블로그에서 c[j] 배열을 도입해서 풀었는데, 진짜 미친 아이디어다 생각하고 그대로 다시 풀었다. 결국 세그트리는 필요 없던 거였다.

#include<bits/stdc++.h>
using namespace std;
typedef vector<pair<int, pair<int,int> > > vipii;
typedef pair<int, pair<int,int> > pipii;
typedef pair<int, int > pii;

deque<int> idx[(int)1e5+7];
int sqrtn,k, arr[(int)1e5+7];
int ans[(int)1e5+7], ansn[(int)1e5+7];
vipii q;
bool cmp(pipii a, pipii b)
{
    if(a.second.second/sqrtn == b.second.second/sqrtn)
        return a.second.first < b.second.first;
    return a.second.second/sqrtn < b.second.second/sqrtn;
}

int c[(int)1e5+7], searchcmax = 0;

void transition(pii s, pii e)
{
    while(e.first < s.first)
    {
        s.first--;
        if(!idx[arr[s.first]].empty())
            c[idx[arr[s.first]].back() - idx[arr[s.first]].front()]--;

        idx[arr[s.first]].push_front(s.first);

        c[idx[arr[s.first]].back() - idx[arr[s.first]].front()]++;

        searchcmax = max(searchcmax, idx[arr[s.first]].back() - idx[arr[s.first]].front());
    }
    while(s.second < e.second)
    {
        s.second++;
        if(!idx[arr[s.second]].empty())
            c[idx[arr[s.second]].back() - idx[arr[s.second]].front()]--;

        idx[arr[s.second]].push_back(s.second);

        c[idx[arr[s.second]].back() - idx[arr[s.second]].front()]++;

        searchcmax = max(searchcmax, idx[arr[s.second]].back() - idx[arr[s.second]].front());
    }
    while(e.first > s.first)
    {
        c[idx[arr[s.first]].back() - idx[arr[s.first]].front()]--;

        idx[arr[s.first]].pop_front();

        if(!idx[arr[s.first]].empty())
            c[idx[arr[s.first]].back() - idx[arr[s.first]].front()]++;

        s.first++;
    }
    while(s.second > e.second)
    {
        c[idx[arr[s.second]].back() - idx[arr[s.second]].front()]--;

        idx[arr[s.second]].pop_back();

        if(!idx[arr[s.second]].empty())
            c[idx[arr[s.second]].back() - idx[arr[s.second]].front()]++;
        s.second--;
    }
}

int main()
{
    int n; scanf("%d %d",&n,&k);
    sqrtn = (int) sqrt(n);
    for(int i = 1; i <= n; i++) scanf("%d", arr+i);
    int Q; scanf("%d", &Q);
    for(int i = 1; i <= Q; i++)
    {
        int a,b; scanf("%d %d", &a,&b);
        q.push_back({i,{a,b}});
    }
    sort(q.begin(), q.end(), cmp);

    pii bef = {1,1};
    idx[arr[1]].push_front(1);
    c[0]++;

    for(auto i : q)
    {
        transition(bef, i.second);
        while(c[searchcmax] == 0) searchcmax--;
        ans[i.first] = searchcmax;
        bef = i.second;
    }
    for(int i = 1; i <= Q; i++) printf("%d\n", ans[i]);
}

위에 블로그 아이디어 미쳤으니까 한번씩 보고 오길 바란다.