Ccmmutty logo
Commutty IT
0 pv16 min read

AtCoder Beginner Contest 261 の A 〜 F を解く

https://cdn.magicode.io/media/notebox/blob_TgSiGYY

はじめに

この記事は 2022-07-23(土) に行われた AtCoder Beginner Contest 261 の解説です。
私は ABCE 4 完でした。
コンテスト中に考えたことの整理を兼ねて忘備録として書いておこうと思います。
(D、F 問題は後から解きましたのでその時の考察を書いてます。)
この記事が一緒に AtCoder を楽しんでいる人の役に立てたら嬉しいです。
説明不足・間違いがありましたらコメントをいただけるとありがたいです。

A - Intersection

Twitter を眺めた感じだとこの問題で引っかかった人が多かった印象です。
以下の解法は嘘かもしれませんが一応 AC 取れたので載せておきます。
間違っているとわかったら修正入れます。

考察

  • サンプル 2 のような L1 == R2 の場合が厄介です。
  • 半開区間で処理すれば良さそう?
  • O(1) で処理しようとするとバグらせそうなので、n を 0 ~ 100 まで動かして、L1 <= n < R1 and L2 <= n < R2 を満たすか?で判定しよう。

実装

def resolve():
  L1, R1, L2, R2 = map(int, input().split(" "))
  count = 0
  for n in range(101):
    if L1 <= n < R1 and L2 <= n < R2:
      count += 1

  print(count)

resolve()

B - Tournament Result

Twitter で i == j の時は判定しないというコードを見たのですが、A[i][i] == "D" が含まれるケースで落ちるのでは?と思いました。
ちゃんと制約に書いてあるって教えてもらいました
記事を修正してあります。
そんなんだから誤読するんだぞ!!

考察

  • 愚直に判定していけば良さそう。
  • 矛盾しないケースが (A[i][j], A[j][i]) == ("W", "L"), ("L", "W"), ("D", "D") であるパターンのみなので、矛盾するケースを見つけるより楽そう。

実装

def resolve():
  N = int(input())
  A = [list(input()) for _ in range(N)]

  for i in range(N-1):
    for j in range(i+1, N):
      # 勝負をした時に矛盾がない組み合わせ 3 パターンをチェックする。
      if A[i][j] == "W" and A[j][i] == "L": continue
      if A[i][j] == "L" and A[j][i] == "W": continue
      if A[i][j] == "D" and A[j][i] == "D": continue

      print("incorrect")
      return

  print("correct")

resolve()

C - NewFolder(1)

考察

  • 「今までに Si が何回出てきたのか?」を管理する必要がある。
  • Si を見た時に、今までに Si と同じ文字列が出現している場合とそうでない場合で場合分けする必要がある。

実装

def resolve():
  from collections import defaultdict
  N = int(input())
  
  # count: {S: <S が今までに何個出現したか>} を管理する dict
  count = defaultdict(int)
  for _ in range(N):
    s = input()
    ans = s
    # S[:i-1] の中に既に s が出現している場合は (<出現した個数>) を足す
    if s in count:
      ans += "({})".format(count[s])
    count[s] += 1
    print(ans)

resolve()

D - Flipping and Bonus

私はコンテスト中に解くことができませんでした。
コンテスト後、DP を使うということを Twitter の TL で知ったので、それをベースに考察して通しました。
これは緑 diff (801) なんですね。
DP の中でも少し捻りを加えた感じだと思ったんですが、皆さん頭が良すぎませんか???

考察

  • DP でやるということは (Twitter 情報で) 掴んでいるので、そこから考察を進めます。
    • 解けた人はどうやって DP に辿り着いたかという話ですが、以下のような人がいるみたいです。
      • 制約が M <= N <= 5000 ということで、O(N^2) でも間に合う => DP で通るだろ!という発想
      • とりあえず DP を考えてみるようにしている
  • DP の状態遷移をどのように行うかを考えます。
    • 今回の問題では一回コイントスをした時に表 or 裏しか出ないのでその 2 パターンの遷移を考えます。
  • 次にどのように状態を持つのかを考えます。
    • 最終的に必要なのは「最大何円もらえるか?」です。
    • 金額に影響を与えるのは「現在何回コイントスを行ったのか?」と「カウンターに現在表示されている数値」だけです。
    • よって、dp[i][c] := i 回コイントスを行い、カウンターに c と表示されている時にもらえる最大の金額 というように状態を持ちます。
  • 最終的に下図のような感じにします。
  • 状態遷移は以下のようになります。
    • 表が出た時 dp[i+1][c+1] = max(dp[i+1][c+1], dp[i][c] + <トスを i+1 回行った時に表が出たら貰える金額> + <カウンターが c+1 になった時にもらえる金額>)
    • 裏が出た時はお金が一切もらえないので、 dp[i+1][0] = max(dp[i+1][0], dp[i][c]) となります。
  • 最終的に N 回のコイントスが終わった時にもらえる金額の最大値 max(dp[N]) が答えになります。

実装

def resolve():
  N, M = map(int, input().split(" "))
  X = [int(x) for x in input().split(" ")]
  BONUS = [0]*(N+1)
  for _ in range(M):
    c, y = [int(x) for x in input().split(" ")]
    BONUS[c] = y

  # dp[i][c] := i 回コインをトスした後、カウンタの値が c の時に最も多くもらえるお金。
  dp = [[0]*(N+1) for _ in range(N+1)]
  for i in range(N):
    for c in range(i+1):
      # 表が出た時の処理
      dp[i+1][c+1] = max(dp[i+1][c+1], dp[i][c]+X[i]+BONUS[c+1])
      # 裏が出た時の処理
      dp[i+1][0] = max(dp[i+1][0], dp[i][c])

  print(max(dp[-1]))

resolve()

E - Many Operations

考察

  • 問題文通りの計算を素直に行うと、O(N^2) の計算量がかかるので今回の制約だと TLE する。
  • 1 ~ i 番目の処理をまとめて O(1) で行うことができれば全体で O(N) で処理できるので間に合う。
  • X に 1 ~ i 番目の処理を行った結果の値が何になるかは X が決まっていれば O(1) で求めることができる。
    • 全てのとりうる X を列挙してその結果を記録していると、2^30 の空間が必要なのでメモリが足りない。これを圧縮しなければいけない。
    • 処理は全て桁同士が独立したビット演算なので、桁毎に考えれば良さそう。
    • ある桁が 0 or 1 である時に 1 ~ i 番目を全て行なった結果を記録しておけば、X の二進数における各桁の更新をそれぞれ O(1) で行える。

実装

def resolve():
  N, C = map(int, input().split(" "))
  # digit: 処理する桁数(二進数)
  digit = 31

  # X を二進数にして桁毎に配列化した状態で持っておく。
  X = [1 if (C>>i)&1 else 0 for i in range(digit)]

  # filters[i][x] := フィルターをかける前に X の i 桁目が x だった時、フィルターをかけた結果が 0 or 1 のどっちになるか
  filters = [[0, 1] for _ in range(digit)]
  for _ in range(N):
    T, A = [int(x) for x in input().split(" ")]
    # A_: 二進数にして桁毎に配列化したもの
    A_ = [1 if (A>>i)&1 else 0 for i in range(digit)]

    # filters を更新する。
    for i in range(digit):
      if T == 1:
        filters[i][0] &= A_[i]
        filters[i][1] &= A_[i]
      elif T == 2:
        filters[i][0] |= A_[i]
        filters[i][1] |= A_[i]
      else:
        filters[i][0] ^= A_[i]
        filters[i][1] ^= A_[i]
    
    # X を更新
    X = [filters[i][X[i]] for i in range(digit)]
    # 二進数の桁毎の配列なので、それを十進数に直して出力する。
    print(sum(X[i]*pow(2, i) for i in range(digit)))

resolve()

F - Sorting Color Balls

考察

  • ボールの色を考えない場合(ボールが全部別の色だった場合)、転倒数をただ求めればいいだけ。
    • ソートの操作回数を求めるのに転倒数を使う問題をやった記憶がある。
  • ボールの色が同じ時、コストがかからないのを考慮しなければいけないのが厄介。
  • 転倒数を求める時に、「X[:i-1] の内 X[i] よりも大きい数字の個数」を求めているが、その中で同じ色の数字を除外すれば良さそう。
    • 入れ替え操作を行う時、X[:i-1] の内 X[i] よりも大きい数字 1 個と X[i] がすれ違う(問題文中の入れ替え操作を行う)のは 1 回なので、上記の数字の個数 == 削減できるコストになる。
  • 「X[:i-1] の内 X[i] よりも大きくて色が違う数字の個数」は「X[:i-1] の内 X[i] よりも大きい数字の個数」から「X[:i-1] の内 X[i] よりも大きくて同じ色の数字の個数」を引いた値になる。
    • 「X[:i-1] の内 X[i] よりも大きくて同じ色の数字の個数」を求めるには、今まで見てきた数字を色毎にソートした状態で持っておいて、X[i] よりも大きい数字の個数を調べればいいんだから tatyam さんの SortedMultiset を上手く使えばできそう。
      • 調べたらそのものズバリなメソッドがあったのでこれを利用する。

実装

少し長いです。
def resolve(): から私の実装です。
# tatyam さんの SortedMultiset 
# https://github.com/tatyam-prime/SortedSet
# https://github.com/tatyam-prime/SortedSet/blob/main/SortedMultiset.py
# multiset
import math
from bisect import bisect_left, bisect_right, insort
from typing import Generic, Iterable, Iterator, TypeVar, Union, List
T = TypeVar('T')

class SortedMultiset(Generic[T]):
  BUCKET_RATIO = 50
  REBUILD_RATIO = 170

  def _build(self, a=None) -> None:
    "Evenly divide `a` into buckets."
    if a is None: a = list(self)
    size = self.size = len(a)
    bucket_size = int(math.ceil(math.sqrt(size / self.BUCKET_RATIO)))
    self.a = [a[size * i // bucket_size : size * (i + 1) // bucket_size] for i in range(bucket_size)]
  
  def __init__(self, a: Iterable[T] = []) -> None:
    "Make a new SortedMultiset from iterable. / O(N) if sorted / O(N log N)"
    a = list(a)
    if not all(a[i] <= a[i + 1] for i in range(len(a) - 1)):
        a = sorted(a)
    self._build(a)

  def __iter__(self) -> Iterator[T]:
    for i in self.a:
      for j in i: yield j

  def __reversed__(self) -> Iterator[T]:
    for i in reversed(self.a):
      for j in reversed(i): yield j
  
  def __len__(self) -> int:
    return self.size
  
  def __repr__(self) -> str:
    return "SortedMultiset" + str(self.a)
  
  def __str__(self) -> str:
    s = str(list(self))
    return "{" + s[1 : len(s) - 1] + "}"

  def _find_bucket(self, x: T) -> List[T]:
    "Find the bucket which should contain x. self must not be empty."
    for a in self.a:
      if x <= a[-1]: return a
    return a

  def __contains__(self, x: T) -> bool:
    if self.size == 0: return False
    a = self._find_bucket(x)
    i = bisect_left(a, x)
    return i != len(a) and a[i] == x

  def count(self, x: T) -> int:
    "Count the number of x."
    return self.index_right(x) - self.index(x)

  def add(self, x: T) -> None:
    "Add an element. / O(√N)"
    if self.size == 0:
      self.a = [[x]]
      self.size = 1
      return
    a = self._find_bucket(x)
    insort(a, x)
    self.size += 1
    if len(a) > len(self.a) * self.REBUILD_RATIO:
      self._build()

  def discard(self, x: T) -> bool:
    "Remove an element and return True if removed. / O(√N)"
    if self.size == 0: return False
    a = self._find_bucket(x)
    i = bisect_left(a, x)
    if i == len(a) or a[i] != x: return False
    a.pop(i)
    self.size -= 1
    if len(a) == 0: self._build()
    return True

  def lt(self, x: T) -> Union[T, None]:
    "Find the largest element < x, or None if it doesn't exist."
    for a in reversed(self.a):
      if a[0] < x:
        return a[bisect_left(a, x) - 1]

  def le(self, x: T) -> Union[T, None]:
    "Find the largest element <= x, or None if it doesn't exist."
    for a in reversed(self.a):
      if a[0] <= x:
        return a[bisect_right(a, x) - 1]

  def gt(self, x: T) -> Union[T, None]:
    "Find the smallest element > x, or None if it doesn't exist."
    for a in self.a:
      if a[-1] > x:
        return a[bisect_right(a, x)]

  def ge(self, x: T) -> Union[T, None]:
    "Find the smallest element >= x, or None if it doesn't exist."
    for a in self.a:
      if a[-1] >= x:
        return a[bisect_left(a, x)]

  def __getitem__(self, x: int) -> T:
    "Return the x-th element, or IndexError if it doesn't exist."
    if x < 0: x += self.size
    if x < 0: raise IndexError
    for a in self.a:
      if x < len(a): return a[x]
      x -= len(a)
    raise IndexError

  def index(self, x: T) -> int:
    "Count the number of elements < x."
    ans = 0
    for a in self.a:
      if a[-1] >= x:
        return ans + bisect_left(a, x)
      ans += len(a)
    return ans

  def index_right(self, x: T) -> int:
    "Count the number of elements <= x."
    ans = 0
    for a in self.a:
      if a[-1] > x:
        return ans + bisect_right(a, x)
      ans += len(a)
    return ans

# takayg1 さんの SegTree https://qiita.com/takayg1/items/c811bd07c21923d7ec69
#####segfunc#####
def segfunc(x, y):
  return x+y
#################

#####ide_ele#####
ide_ele = 0 # 区間和、最大公約数
#################

class SegTree:
  """
  init(init_val, ide_ele): 配列init_valで初期化 O(N)
  update(k, x): k番目の値をxに更新 O(logN)
  query(l, r): 区間[l, r)をsegfuncしたものを返す O(logN)
  """
  def __init__(self, init_val, segfunc, ide_ele):
    """
    init_val: 配列の初期値
    segfunc: 区間にしたい操作
    ide_ele: 単位元
    n: 要素数
    num: n以上の最小の2のべき乗
    tree: セグメント木(1-index)
    """
    n = len(init_val)
    self.segfunc = segfunc
    self.ide_ele = ide_ele
    self.num = 1 << (n - 1).bit_length()
    self.tree = [ide_ele] * 2 * self.num
    # 配列の値を葉にセット
    for i in range(n):
      self.tree[self.num + i] = init_val[i]
    # 構築していく
    for i in range(self.num - 1, 0, -1):
      self.tree[i] = self.segfunc(self.tree[2 * i], self.tree[2 * i + 1])

  def update(self, k, x):
    """
    k番目の値をxに更新
    k: index(0-index)
    x: update value
    """
    # 葉の部分が前半に入っている。+= self.numをすることで元になる配列要素に移動
    k += self.num
    self.tree[k] = x
    while k > 1:
      # k >> 1 == k // 2 (index k の親の index)
      # k ^ 1 : 末尾を xor する。偶数だったら +1、奇数だったら -1 する。k インデックス(子)の片割れ
      self.tree[k >> 1] = self.segfunc(self.tree[k], self.tree[k ^ 1])
      k >>= 1

  def query(self, l, r):
    """
    [l, r)のsegfuncしたものを得る
    l: index(0-index)
    r: index(0-index)
    """
    res = self.ide_ele

    l += self.num
    r += self.num
    while l < r:
      # l & 1 => l が 奇数 (ペアの右側) だったら 1:
      if l & 1:
        res = self.segfunc(res, self.tree[l])
        l += 1
      # r が奇数 = r-1 が偶数。なので、子のペアの左を見ることになる。
      if r & 1:
        res = self.segfunc(res, self.tree[r - 1])
      # l // 2
      l >>= 1
      r >>= 1
    return res

def resolve():
  N = int(input())
  C = [int(x)-1 for x in input().split(" ")]
  X = [int(x)-1 for x in input().split(" ")]
  # recent_balles[c]: 今まで見てきたボールの内、c 色のもの。SortedMultiset で管理されている。
  recent_balles = [SortedMultiset([]) for _ in range(N)]

  # seg: 今まで見てきたボールに書かれた数字の個数を記録してある。転倒数を求めるのに使う。
  seg = SegTree([0]*(N+1), segfunc, ide_ele)

  ans = 0
  for i in range(N):
    # 答えに転倒数の一部(X[:i]より左にある X[i] よりも大きな数の個数)を足す。
    ans += seg.query(X[i]+1, N)

    # ↑ で足した値から、色が一緒でコストがかからなかったケースを除外する。
    ans -= recent_balles[C[i]].size - recent_balles[C[i]].index_right(X[i])

    # アップデート
    seg.update(X[i], seg.query(X[i], X[i]+1)+1)
    recent_balles[C[i]].add(X[i])

  print(ans)

resolve()

Discussion

コメントにはログインが必要です。